Merge pull request #17 from okdolly/patch-1
torch.dot does not broadcast
This commit is contained in:
@ -279,10 +279,11 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"# incorrect method\n",
|
"# incorrect method\n",
|
||||||
"data = np.array(data)\n",
|
"data = np.array(data)\n",
|
||||||
|
"tensor = torch.Tensor([1,2,3,4]\n",
|
||||||
"print(\n",
|
"print(\n",
|
||||||
" '\\nmatrix multiplication (dot)',\n",
|
" '\\nmatrix multiplication (dot)',\n",
|
||||||
" '\\nnumpy: ', data.dot(data), # [[7, 10], [15, 22]]\n",
|
" '\\nnumpy: ', data.dot(data), # [[7, 10], [15, 22]]\n",
|
||||||
" '\\ntorch: ', tensor.dot(tensor) # this will convert tensor to [1,2,3,4], you'll get 30.0\n",
|
" '\\ntorch: ', torch.dot(tensor.dot(tensor) # 30.0. Beware that torch.dot does not broadcast, only works for 1-dimensional tensor\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -360,7 +361,8 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"tensor.dot(tensor)"
|
"torch.dot(torch.Tensor([2, 3]), torch.Tensor([2, 1]))
|
||||||
|
7.0"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
Reference in New Issue
Block a user