update for torch 0.3
This commit is contained in:
@ -59,7 +59,7 @@ for t in range(100):
|
||||
if t % 2 == 0:
|
||||
# plot and show learning process
|
||||
plt.cla()
|
||||
prediction = torch.max(F.softmax(out), 1)[1]
|
||||
prediction = torch.max(out, 1)[1]
|
||||
pred_y = prediction.data.numpy().squeeze()
|
||||
target_y = y.data.numpy()
|
||||
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')
|
||||
|
||||
Reference in New Issue
Block a user