diff --git a/tutorial-contents/401_CNN.py b/tutorial-contents/401_CNN.py index 00775d0..f765dbb 100644 --- a/tutorial-contents/401_CNN.py +++ b/tutorial-contents/401_CNN.py @@ -122,7 +122,7 @@ for epoch in range(EPOCH): plt.ioff() # print 10 predictions from test data -test_output = cnn(test_x[:10]) +test_output, _ = cnn(test_x[:10]) pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze() print(pred_y, 'prediction number') print(test_y[:10].numpy(), 'real number')