This commit is contained in:
morvanzhou
2018-11-07 15:58:12 +08:00
parent 889455575a
commit ce55cc9446

View File

@ -183,7 +183,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 2,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -208,29 +208,29 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 3,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "ename": "RuntimeError",
"output_type": "stream", "evalue": "dot: Expected 1-D argument self, but got 2-D",
"text": [ "traceback": [
"\n", "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"matrix multiplication (dot) \n", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"numpy: [[ 7 10]\n", "\u001b[0;32m<ipython-input-3-a29f9258176b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m'\\nmatrix multiplication (dot)'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;34m'\\nnumpy: '\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;31m# [[7, 10], [15, 22]]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0;34m'\\ntorch: '\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# 30.0. Beware that torch.dot does not broadcast, only works for 1-dimensional tensor\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m )\n",
" [15 22]] \n", "\u001b[0;31mRuntimeError\u001b[0m: dot: Expected 1-D argument self, but got 2-D"
"torch: 30.0\n" ],
] "output_type": "error"
} }
], ],
"source": [ "source": [
"# incorrect method\n", "# incorrect method\n",
"data = np.array(data)\n", "data = np.array(data)\n",
"tensor = torch.Tensor([1,2,3,4]\n", "tensor = torch.Tensor(data)\n",
"print(\n", "print(\n",
" '\\nmatrix multiplication (dot)',\n", " '\\nmatrix multiplication (dot)',\n",
" '\\nnumpy: ', data.dot(data), # [[7, 10], [15, 22]]\n", " '\\nnumpy: ', data.dot(data), # [[7, 10], [15, 22]]\n",
" '\\ntorch: ', torch.dot(tensor.dot(tensor) # 30.0. Beware that torch.dot does not broadcast, only works for 1-dimensional tensor\n", " '\\ntorch: ', torch.dot(tensor.dot(tensor)) # NOT WORKING! Beware that torch.dot does not broadcast, only works for 1-dimensional tensor\n",
")" ")"
] ]
}, },