From 528d7ccd41addf72c01fcfb85b205e14f51a44af Mon Sep 17 00:00:00 2001 From: Fuyang Liu Date: Mon, 19 Jun 2017 21:49:39 +0200 Subject: [PATCH] Fix up test code --- tutorial-contents/401_CNN.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorial-contents/401_CNN.py b/tutorial-contents/401_CNN.py index 00775d0..f765dbb 100644 --- a/tutorial-contents/401_CNN.py +++ b/tutorial-contents/401_CNN.py @@ -122,7 +122,7 @@ for epoch in range(EPOCH): plt.ioff() # print 10 predictions from test data -test_output = cnn(test_x[:10]) +test_output, _ = cnn(test_x[:10]) pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze() print(pred_y, 'prediction number') print(test_y[:10].numpy(), 'real number')