update for torch 0.3

This commit is contained in:
Morvan Zhou
2018-01-04 16:56:37 +11:00
parent f65e27f71e
commit ebf9bfd250

View File

@ -59,7 +59,7 @@ for t in range(100):
if t % 2 == 0:
# plot and show learning process
plt.cla()
prediction = torch.max(F.softmax(out), 1)[1]
prediction = torch.max(out, 1)[1]
pred_y = prediction.data.numpy().squeeze()
target_y = y.data.numpy()
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')