@ -122,7 +122,7 @@ for epoch in range(EPOCH):
|
|||||||
plt.ioff()
|
plt.ioff()
|
||||||
|
|
||||||
# print 10 predictions from test data
|
# print 10 predictions from test data
|
||||||
test_output = cnn(test_x[:10])
|
test_output, _ = cnn(test_x[:10])
|
||||||
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
|
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
|
||||||
print(pred_y, 'prediction number')
|
print(pred_y, 'prediction number')
|
||||||
print(test_y[:10].numpy(), 'real number')
|
print(test_y[:10].numpy(), 'real number')
|
||||||
|
|||||||
Reference in New Issue
Block a user