From ebf9bfd25033e1bc4d0c9cd7f32cba1e69ddf251 Mon Sep 17 00:00:00 2001 From: Morvan Zhou Date: Thu, 4 Jan 2018 16:56:37 +1100 Subject: [PATCH] update for torch 0.3 --- tutorial-contents/302_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorial-contents/302_classification.py b/tutorial-contents/302_classification.py index 72ed82e..8baa102 100644 --- a/tutorial-contents/302_classification.py +++ b/tutorial-contents/302_classification.py @@ -59,7 +59,7 @@ for t in range(100): if t % 2 == 0: # plot and show learning process plt.cla() - prediction = torch.max(F.softmax(out), 1)[1] + prediction = torch.max(out, 1)[1] pred_y = prediction.data.numpy().squeeze() target_y = y.data.numpy() plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')