diff --git a/tutorial-contents/405_DQN_Reinforcement_learning.py b/tutorial-contents/405_DQN_Reinforcement_learning.py index f81c6b8..393b41d 100644 --- a/tutorial-contents/405_DQN_Reinforcement_learning.py +++ b/tutorial-contents/405_DQN_Reinforcement_learning.py @@ -4,7 +4,7 @@ My Youtube Channel: https://www.youtube.com/user/MorvanZhou More about Reinforcement learning: https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/ Dependencies: -torch: 0.2 +torch: 0.3 gym: 0.8.1 numpy """ @@ -60,7 +60,7 @@ class DQN(object): if np.random.uniform() < EPSILON: # greedy actions_value = self.eval_net.forward(x) action = torch.max(actions_value, 1)[1].data.numpy() - action = action[0, 0] if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE) # return the argmax index + action = action[0] if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE) # return the argmax index else: # random action = np.random.randint(0, N_ACTIONS) action = action if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE)