Fix up test code
This commit is contained in:
@ -122,7 +122,7 @@ for epoch in range(EPOCH):
|
||||
plt.ioff()
|
||||
|
||||
# print 10 predictions from test data
|
||||
test_output = cnn(test_x[:10])
|
||||
test_output, _ = cnn(test_x[:10])
|
||||
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
|
||||
print(pred_y, 'prediction number')
|
||||
print(test_y[:10].numpy(), 'real number')
|
||||
|
||||
Reference in New Issue
Block a user