diff --git a/tutorial-contents/302_classification.py b/tutorial-contents/302_classification.py index 0b96e2f..62dbcc1 100644 --- a/tutorial-contents/302_classification.py +++ b/tutorial-contents/302_classification.py @@ -59,7 +59,7 @@ for t in range(100): # plot and show learning process plt.cla() prediction = torch.max(out, 1)[1] - pred_y = prediction.data.numpy().squeeze() + pred_y = prediction.data.numpy() target_y = y.data.numpy() plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn') accuracy = float((pred_y == target_y).astype(int).sum()) / float(target_y.size) @@ -67,4 +67,4 @@ for t in range(100): plt.pause(0.1) plt.ioff() -plt.show() \ No newline at end of file +plt.show()