fix action shape

This commit is contained in:
Morvan Zhou
2017-10-27 09:06:10 +11:00
committed by Morvan Zhou
parent ce4a8286fd
commit a7b14b8091

View File

@ -58,7 +58,7 @@ class DQN(object):
# input only one sample
if np.random.uniform() < EPSILON: # greedy
actions_value = self.eval_net.forward(x)
action = torch.max(actions_value, 1)[1].data.numpy()[0] # return the argmax
action = torch.max(actions_value, 1)[1].data.numpy()[0, 0] # return the argmax
else: # random
action = np.random.randint(0, N_ACTIONS)
return action