diff --git a/tutorial-contents/405_DQN_Reinforcement_learning.py b/tutorial-contents/405_DQN_Reinforcement_learning.py
index 20fa80c..f81c6b8 100644
--- a/tutorial-contents/405_DQN_Reinforcement_learning.py
+++ b/tutorial-contents/405_DQN_Reinforcement_learning.py
@@ -26,14 +26,15 @@ env = gym.make('CartPole-v0')
 env = env.unwrapped
 N_ACTIONS = env.action_space.n
 N_STATES = env.observation_space.shape[0]
+ENV_A_SHAPE = 0 if isinstance(env.action_space.sample(), int) else env.action_space.sample().shape     # to confirm the shape
 
 
 class Net(nn.Module):
     def __init__(self, ):
         super(Net, self).__init__()
-        self.fc1 = nn.Linear(N_STATES, 10)
+        self.fc1 = nn.Linear(N_STATES, 50)
         self.fc1.weight.data.normal_(0, 0.1)   # initialization
-        self.out = nn.Linear(10, N_ACTIONS)
+        self.out = nn.Linear(50, N_ACTIONS)
         self.out.weight.data.normal_(0, 0.1)   # initialization
 
     def forward(self, x):
@@ -58,9 +59,11 @@ class DQN(object):
         # input only one sample
         if np.random.uniform() < EPSILON:   # greedy
             actions_value = self.eval_net.forward(x)
-            action = torch.max(actions_value, 1)[1].data.numpy()[0, 0]     # return the argmax
+            action = torch.max(actions_value, 1)[1].data.numpy()
+            action = action[0, 0] if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE)  # return the argmax index
         else:   # random
             action = np.random.randint(0, N_ACTIONS)
+            action = action if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE)
         return action
 
     def store_transition(self, s, a, r, s_):
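
Note (not part of the diff): a minimal sketch of why the new ENV_A_SHAPE check matters, assuming the classic gym API in which Discrete.sample() returns a Python int while Box.sample() returns a numpy ndarray with a shape. Pendulum-v0 is used here only as an illustrative non-int action space; it is an assumption for this sketch, not an environment the tutorial itself trains on.

# Sketch (assumption: classic gym API; Discrete.sample() -> int,
# Box.sample() -> numpy ndarray).
import gym

for env_id in ('CartPole-v0', 'Pendulum-v0'):   # discrete vs. continuous actions
    env = gym.make(env_id).unwrapped
    sample = env.action_space.sample()
    # Same check as the diff: 0 flags scalar int actions, otherwise
    # the sampled array's shape, which the chosen action must match
    # before being passed to env.step().
    env_a_shape = 0 if isinstance(sample, int) else sample.shape
    print(env_id, type(sample).__name__, env_a_shape)
    # expected: CartPole-v0 int 0
    #           Pendulum-v0 ndarray (1,)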