Merge pull request #53 from keineahnung2345/302-squeeze
302 - remove .squeeze() from 1-D numpy array
This commit is contained in:
@ -59,7 +59,7 @@ for t in range(100):
|
||||
# plot and show learning process
|
||||
plt.cla()
|
||||
prediction = torch.max(out, 1)[1]
|
||||
pred_y = prediction.data.numpy().squeeze()
|
||||
pred_y = prediction.data.numpy()
|
||||
target_y = y.data.numpy()
|
||||
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')
|
||||
accuracy = float((pred_y == target_y).astype(int).sum()) / float(target_y.size)
|
||||
@ -67,4 +67,4 @@ for t in range(100):
|
||||
plt.pause(0.1)
|
||||
|
||||
plt.ioff()
|
||||
plt.show()
|
||||
plt.show()
|
||||
|
||||
Reference in New Issue
Block a user