Fix up test code

This commit is contained in:
Fuyang Liu
2017-06-19 21:49:39 +02:00
parent 51f1c938f3
commit 528d7ccd41

View File

@ -122,7 +122,7 @@ for epoch in range(EPOCH):
plt.ioff()
# print 10 predictions from test data
test_output = cnn(test_x[:10])
test_output, _ = cnn(test_x[:10])
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
print(pred_y, 'prediction number')
print(test_y[:10].numpy(), 'real number')