From b048c4dd84653dc703038481fa5fd92a1fbea37c Mon Sep 17 00:00:00 2001 From: Dolly Ye <1375373964@qq.com> Date: Sat, 17 Feb 2018 22:09:20 -0800 Subject: [PATCH] torch.dot does not broadcast torch.dot() can only work for 1 dimension tensor. http://pytorch.org/docs/master/torch.html https://github.com/pytorch/pytorch/issues/2313 --- tutorial-contents-notebooks/201_torch_numpy.ipynb | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tutorial-contents-notebooks/201_torch_numpy.ipynb b/tutorial-contents-notebooks/201_torch_numpy.ipynb index d2cec63..ecf445a 100644 --- a/tutorial-contents-notebooks/201_torch_numpy.ipynb +++ b/tutorial-contents-notebooks/201_torch_numpy.ipynb @@ -279,10 +279,11 @@ "source": [ "# incorrect method\n", "data = np.array(data)\n", + "tensor = torch.Tensor([1,2,3,4]\n", "print(\n", " '\\nmatrix multiplication (dot)',\n", " '\\nnumpy: ', data.dot(data), # [[7, 10], [15, 22]]\n", - " '\\ntorch: ', tensor.dot(tensor) # this will convert tensor to [1,2,3,4], you'll get 30.0\n", + " '\\ntorch: ', torch.dot(tensor.dot(tensor) # 30.0. Beware that torch.dot does not broadcast, only works for 1-dimensional tensor\n", ")" ] }, @@ -360,7 +361,8 @@ } ], "source": [ - "tensor.dot(tensor)" + "torch.dot(torch.Tensor([2, 3]), torch.Tensor([2, 1])) +7.0" ] }, {