From 3b23f6bced16eab141c2f2eed7b8849ef1c31719 Mon Sep 17 00:00:00 2001
From: Morvan Zhou
Date: Tue, 5 Sep 2017 12:57:25 +1000
Subject: [PATCH] update for torch 0.2

---
 tutorial-contents/405_DQN_Reinforcement_learning.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tutorial-contents/405_DQN_Reinforcement_learning.py b/tutorial-contents/405_DQN_Reinforcement_learning.py
index 13e456a..3526cd8 100644
--- a/tutorial-contents/405_DQN_Reinforcement_learning.py
+++ b/tutorial-contents/405_DQN_Reinforcement_learning.py
@@ -58,7 +58,7 @@ class DQN(object):
         # input only one sample
         if np.random.uniform() < EPSILON:   # greedy
             actions_value = self.eval_net.forward(x)
-            action = torch.max(actions_value, 1)[1].data.numpy()[0, 0]     # return the argmax
+            action = torch.max(actions_value, 1)[1].data.numpy()[0]     # return the argmax
         else:   # random
             action = np.random.randint(0, N_ACTIONS)
         return action
@@ -87,7 +87,7 @@ class DQN(object):
         # q_eval w.r.t the action in experience
         q_eval = self.eval_net(b_s).gather(1, b_a)  # shape (batch, 1)
         q_next = self.target_net(b_s_).detach()     # detach from graph, don't backpropagate
-        q_target = b_r + GAMMA * q_next.max(1)[0]   # shape (batch, 1)
+        q_target = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1)   # shape (batch, 1)
         loss = self.loss_func(q_eval, q_target)
 
         self.optimizer.zero_grad()
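
A minimal sketch of the shape change behind both hunks (tensor sizes here are
illustrative, not from the repo): torch 0.2 changed reductions like max() to
default to keepdim=False, so max(1) returns 1-D results where 0.1.x returned
2-D ones. Without the .view(BATCH_SIZE, 1), b_r + GAMMA * q_next.max(1)[0]
would broadcast to a (batch, batch) matrix instead of the intended (batch, 1)
target; likewise the argmax in choose_action is now 1-D, hence [0] replacing
[0, 0].

    import torch

    BATCH_SIZE, N_ACTIONS, GAMMA = 4, 2, 0.9
    b_r = torch.rand(BATCH_SIZE, 1)              # rewards, shape (batch, 1)
    q_next = torch.rand(BATCH_SIZE, N_ACTIONS)   # target-net output, shape (batch, n_actions)

    # max(1)[0] is now shape (batch,), so plain addition broadcasts badly:
    print((b_r + GAMMA * q_next.max(1)[0]).shape)                      # torch.Size([4, 4]) -- wrong
    print((b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1)).shape)  # torch.Size([4, 1]) -- intended

    # max(...)[1] (the argmax) is also 1-D now, so a single index suffices:
    actions_value = torch.rand(1, N_ACTIONS)
    print(torch.max(actions_value, 1)[1].data.numpy()[0])              # scalar action index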