[Bug fix] 401_CNN.py
1. fix typo 2. remove .squeeze for 1-D tensor
This commit is contained in:
@ -65,7 +65,7 @@ class CNN(nn.Module):
|
||||
out_channels=16, # n_filters
|
||||
kernel_size=5, # filter size
|
||||
stride=1, # filter movement/step
|
||||
padding=2, # if want same width and length of this image after con2d, padding=(kernel_size-1)/2 if stride=1
|
||||
padding=2, # if want same width and length of this image after Conv2d, padding=(kernel_size-1)/2 if stride=1
|
||||
), # output shape (16, 28, 28)
|
||||
nn.ReLU(), # activation
|
||||
nn.MaxPool2d(kernel_size=2), # choose max value in 2x2 area, output shape (16, 14, 14)
|
||||
@ -115,7 +115,7 @@ for epoch in range(EPOCH):
|
||||
|
||||
if step % 50 == 0:
|
||||
test_output, last_layer = cnn(test_x)
|
||||
pred_y = torch.max(test_output, 1)[1].data.squeeze().numpy()
|
||||
pred_y = torch.max(test_output, 1)[1].data.numpy()
|
||||
accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
|
||||
print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy)
|
||||
if HAS_SK:
|
||||
@ -129,6 +129,6 @@ plt.ioff()
|
||||
|
||||
# print 10 predictions from test data
|
||||
test_output, _ = cnn(test_x[:10])
|
||||
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
|
||||
pred_y = torch.max(test_output, 1)[1].data.numpy()
|
||||
print(pred_y, 'prediction number')
|
||||
print(test_y[:10].numpy(), 'real number')
|
||||
|
||||
Reference in New Issue
Block a user