fix action shape
This commit is contained in:
@@ -58,7 +58,7 @@ class DQN(object):
|
||||
# input only one sample
|
||||
if np.random.uniform() < EPSILON: # greedy
|
||||
actions_value = self.eval_net.forward(x)
|
||||
- action = torch.max(actions_value, 1)[1].data.numpy()[0] # return the argmax
|
||||
+ action = torch.max(actions_value, 1)[1].data.numpy()[0, 0] # return the argmax
|
||||
else: # random
|
||||
action = np.random.randint(0, N_ACTIONS)
|
||||
return action
|
||||
|
||||
Reference in New Issue
Block a user