Merge pull request #17 from okdolly/patch-1

torch.dot does not broadcast
This commit is contained in:
Morvan
2018-02-26 09:39:11 +08:00
committed by GitHub

View File

@ -279,10 +279,11 @@
"source": [ "source": [
"# incorrect method\n", "# incorrect method\n",
"data = np.array(data)\n", "data = np.array(data)\n",
"tensor = torch.Tensor([1,2,3,4]\n",
"print(\n", "print(\n",
" '\\nmatrix multiplication (dot)',\n", " '\\nmatrix multiplication (dot)',\n",
" '\\nnumpy: ', data.dot(data), # [[7, 10], [15, 22]]\n", " '\\nnumpy: ', data.dot(data), # [[7, 10], [15, 22]]\n",
" '\\ntorch: ', tensor.dot(tensor) # this will convert tensor to [1,2,3,4], you'll get 30.0\n", " '\\ntorch: ', torch.dot(tensor.dot(tensor) # 30.0. Beware that torch.dot does not broadcast, only works for 1-dimensional tensor\n",
")" ")"
] ]
}, },
@ -360,7 +361,8 @@
} }
], ],
"source": [ "source": [
"tensor.dot(tensor)" "torch.dot(torch.Tensor([2, 3]), torch.Tensor([2, 1]))
7.0"
] ]
}, },
{ {