update for torch 0.2
@@ -58,7 +58,7 @@ class DQN(object):
         # input only one sample
         if np.random.uniform() < EPSILON:   # greedy
             actions_value = self.eval_net.forward(x)
-            action = torch.max(actions_value, 1)[1].data.numpy()[0, 0]     # return the argmax
+            action = torch.max(actions_value, 1)[1].data.numpy()[0]     # return the argmax
         else:   # random
             action = np.random.randint(0, N_ACTIONS)
         return action
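The only change in choose_action is the final indexing: torch 0.2 made keepdim=False the default for reductions such as torch.max, so for a single-sample input the argmax indices come back with shape (1,) rather than (1, 1), and their NumPy view is one-dimensional. A minimal shape sketch of the new behaviour, using a hypothetical random Q-value row (N_ACTIONS set to 4 purely for illustration):

import torch
from torch.autograd import Variable

# Hypothetical single-sample Q-values; only the shapes matter here.
N_ACTIONS = 4
actions_value = Variable(torch.randn(1, N_ACTIONS))   # shape (1, N_ACTIONS), one observation
indices = torch.max(actions_value, 1)[1]               # shape (1,) under torch 0.2 (keepdim=False)
action = indices.data.numpy()[0]                       # scalar index; [0, 0] would now raise IndexError
print(action)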
@@ -87,7 +87,7 @@ class DQN(object):
         # q_eval w.r.t the action in experience
         q_eval = self.eval_net(b_s).gather(1, b_a)  # shape (batch, 1)
         q_next = self.target_net(b_s_).detach()     # detach from graph, don't backpropagate
-        q_target = b_r + GAMMA * q_next.max(1)[0]   # shape (batch, 1)
+        q_target = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1)   # shape (batch, 1)
         loss = self.loss_func(q_eval, q_target)

         self.optimizer.zero_grad()
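The added .view(BATCH_SIZE, 1) compensates for the same keepdim change: q_next.max(1)[0] is now shape (BATCH_SIZE,) rather than (BATCH_SIZE, 1). Since torch 0.2 also introduced NumPy-style broadcasting, adding it to b_r of shape (BATCH_SIZE, 1) would silently produce a (BATCH_SIZE, BATCH_SIZE) matrix and a meaningless loss against q_eval. A shape-only sketch, with BATCH_SIZE, GAMMA, and the reward/Q tensors filled in with made-up values:

import torch
from torch.autograd import Variable

# Made-up sizes and values; only the resulting shapes matter.
BATCH_SIZE, N_ACTIONS, GAMMA = 32, 4, 0.9
b_r = Variable(torch.ones(BATCH_SIZE, 1))              # batch of rewards, column vector
q_next = Variable(torch.randn(BATCH_SIZE, N_ACTIONS))  # stand-in for target_net(b_s_).detach()
old = b_r + GAMMA * q_next.max(1)[0]                        # broadcasts to (BATCH_SIZE, BATCH_SIZE)
new = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1)    # stays (BATCH_SIZE, 1), matches q_eval
print(old.size(), new.size())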