update for torch 0.3

This commit is contained in:
Morvan Zhou
2018-01-02 15:05:18 +11:00
parent 5df56c99dd
commit f65e27f71e

View File

@ -4,7 +4,7 @@ My Youtube Channel: https://www.youtube.com/user/MorvanZhou
More about Reinforcement learning: https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/
Dependencies:
torch: 0.2
torch: 0.3
gym: 0.8.1
numpy
"""
@ -60,7 +60,7 @@ class DQN(object):
if np.random.uniform() < EPSILON: # greedy
actions_value = self.eval_net.forward(x)
action = torch.max(actions_value, 1)[1].data.numpy()
action = action[0, 0] if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE) # return the argmax index
action = action[0] if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE) # return the argmax index
else: # random
action = np.random.randint(0, N_ACTIONS)
action = action if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE)