302 - remove .squeeze() from 1-D numpy array

prediction.data.numpy() is already 1-D, so the .squeeze() is unnecessary.
This commit is contained in:
keineahnung2345
2018-11-08 14:14:35 +08:00
committed by GitHub
parent 906cf71b6f
commit 5b1d191946

View File

@ -59,7 +59,7 @@ for t in range(100):
# plot and show learning process
plt.cla()
prediction = torch.max(out, 1)[1]
pred_y = prediction.data.numpy().squeeze()
pred_y = prediction.data.numpy()
target_y = y.data.numpy()
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')
accuracy = float((pred_y == target_y).astype(int).sum()) / float(target_y.size)
@ -67,4 +67,4 @@ for t in range(100):
plt.pause(0.1)
plt.ioff()
plt.show()
plt.show()