diff --git a/tutorial-contents/302_classification.py b/tutorial-contents/302_classification.py index 72ed82e..8baa102 100644 --- a/tutorial-contents/302_classification.py +++ b/tutorial-contents/302_classification.py @@ -59,7 +59,7 @@ for t in range(100): if t % 2 == 0: # plot and show learning process plt.cla() - prediction = torch.max(F.softmax(out), 1)[1] + prediction = torch.max(out, 1)[1] pred_y = prediction.data.numpy().squeeze() target_y = y.data.numpy() plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')