Merge pull request #53 from keineahnung2345/302-squeeze

302 - remove .squeeze() from 1-D numpy array
This commit is contained in:
Morvan
2018-11-08 19:43:20 +08:00
committed by GitHub

View File

@ -59,7 +59,7 @@ for t in range(100):
# plot and show learning process
plt.cla()
prediction = torch.max(out, 1)[1]
pred_y = prediction.data.numpy().squeeze()
pred_y = prediction.data.numpy()
target_y = y.data.numpy()
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')
accuracy = float((pred_y == target_y).astype(int).sum()) / float(target_y.size)
@ -67,4 +67,4 @@ for t in range(100):
plt.pause(0.1)
plt.ioff()
plt.show()
plt.show()