[Bug fix] 401_CNN.py
1. fix typo 2. remove .squeeze for 1-D tensor
This commit is contained in:
@ -65,7 +65,7 @@ class CNN(nn.Module):
|
|||||||
out_channels=16, # n_filters
|
out_channels=16, # n_filters
|
||||||
kernel_size=5, # filter size
|
kernel_size=5, # filter size
|
||||||
stride=1, # filter movement/step
|
stride=1, # filter movement/step
|
||||||
padding=2, # if want same width and length of this image after con2d, padding=(kernel_size-1)/2 if stride=1
|
padding=2, # if want same width and length of this image after Conv2d, padding=(kernel_size-1)/2 if stride=1
|
||||||
), # output shape (16, 28, 28)
|
), # output shape (16, 28, 28)
|
||||||
nn.ReLU(), # activation
|
nn.ReLU(), # activation
|
||||||
nn.MaxPool2d(kernel_size=2), # choose max value in 2x2 area, output shape (16, 14, 14)
|
nn.MaxPool2d(kernel_size=2), # choose max value in 2x2 area, output shape (16, 14, 14)
|
||||||
@ -115,7 +115,7 @@ for epoch in range(EPOCH):
|
|||||||
|
|
||||||
if step % 50 == 0:
|
if step % 50 == 0:
|
||||||
test_output, last_layer = cnn(test_x)
|
test_output, last_layer = cnn(test_x)
|
||||||
pred_y = torch.max(test_output, 1)[1].data.squeeze().numpy()
|
pred_y = torch.max(test_output, 1)[1].data.numpy()
|
||||||
accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
|
accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
|
||||||
print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy)
|
print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy)
|
||||||
if HAS_SK:
|
if HAS_SK:
|
||||||
@ -129,6 +129,6 @@ plt.ioff()
|
|||||||
|
|
||||||
# print 10 predictions from test data
|
# print 10 predictions from test data
|
||||||
test_output, _ = cnn(test_x[:10])
|
test_output, _ = cnn(test_x[:10])
|
||||||
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
|
pred_y = torch.max(test_output, 1)[1].data.numpy()
|
||||||
print(pred_y, 'prediction number')
|
print(pred_y, 'prediction number')
|
||||||
print(test_y[:10].numpy(), 'real number')
|
print(test_y[:10].numpy(), 'real number')
|
||||||
|
|||||||
Reference in New Issue
Block a user