From a9ef65ed9ac8ea1786425b8ea99060793162a212 Mon Sep 17 00:00:00 2001 From: keineahnung2345 Date: Fri, 9 Nov 2018 17:26:27 +0800 Subject: [PATCH] [Bug fix] 401_CNN.py 1. fix typo 2. remove .squeeze for 1-D tensor --- tutorial-contents/401_CNN.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tutorial-contents/401_CNN.py b/tutorial-contents/401_CNN.py index 1f9daea..8016158 100644 --- a/tutorial-contents/401_CNN.py +++ b/tutorial-contents/401_CNN.py @@ -65,7 +65,7 @@ class CNN(nn.Module): out_channels=16, # n_filters kernel_size=5, # filter size stride=1, # filter movement/step - padding=2, # if want same width and length of this image after con2d, padding=(kernel_size-1)/2 if stride=1 + padding=2, # if want same width and length of this image after Conv2d, padding=(kernel_size-1)/2 if stride=1 ), # output shape (16, 28, 28) nn.ReLU(), # activation nn.MaxPool2d(kernel_size=2), # choose max value in 2x2 area, output shape (16, 14, 14) @@ -115,7 +115,7 @@ for epoch in range(EPOCH): if step % 50 == 0: test_output, last_layer = cnn(test_x) - pred_y = torch.max(test_output, 1)[1].data.squeeze().numpy() + pred_y = torch.max(test_output, 1)[1].data.numpy() accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0)) print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy) if HAS_SK: @@ -129,6 +129,6 @@ plt.ioff() # print 10 predictions from test data test_output, _ = cnn(test_x[:10]) -pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze() +pred_y = torch.max(test_output, 1)[1].data.numpy() print(pred_y, 'prediction number') print(test_y[:10].numpy(), 'real number')