update for torch 0.3
This commit is contained in:
@ -4,7 +4,7 @@ My Youtube Channel: https://www.youtube.com/user/MorvanZhou
|
|||||||
More about Reinforcement learning: https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/
|
More about Reinforcement learning: https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/
|
||||||
|
|
||||||
Dependencies:
|
Dependencies:
|
||||||
torch: 0.2
|
torch: 0.3
|
||||||
gym: 0.8.1
|
gym: 0.8.1
|
||||||
numpy
|
numpy
|
||||||
"""
|
"""
|
||||||
@ -60,7 +60,7 @@ class DQN(object):
|
|||||||
if np.random.uniform() < EPSILON: # greedy
|
if np.random.uniform() < EPSILON: # greedy
|
||||||
actions_value = self.eval_net.forward(x)
|
actions_value = self.eval_net.forward(x)
|
||||||
action = torch.max(actions_value, 1)[1].data.numpy()
|
action = torch.max(actions_value, 1)[1].data.numpy()
|
||||||
action = action[0, 0] if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE) # return the argmax index
|
action = action[0] if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE) # return the argmax index
|
||||||
else: # random
|
else: # random
|
||||||
action = np.random.randint(0, N_ACTIONS)
|
action = np.random.randint(0, N_ACTIONS)
|
||||||
action = action if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE)
|
action = action if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE)
|
||||||
|
|||||||
Reference in New Issue
Block a user