Fix action shape problem: detect the environment's action shape (ENV_A_SHAPE) and reshape chosen actions to match it
This commit is contained in:
@ -26,14 +26,15 @@ env = gym.make('CartPole-v0')
|
|||||||
env = env.unwrapped
|
env = env.unwrapped
|
||||||
N_ACTIONS = env.action_space.n
|
N_ACTIONS = env.action_space.n
|
||||||
N_STATES = env.observation_space.shape[0]
|
N_STATES = env.observation_space.shape[0]
|
||||||
|
ENV_A_SHAPE = 0 if isinstance(env.action_space.sample(), int) else env.action_space.sample().shape # to confirm the shape
|
||||||
|
|
||||||
|
|
||||||
class Net(nn.Module):
|
class Net(nn.Module):
|
||||||
def __init__(self, ):
|
def __init__(self, ):
|
||||||
super(Net, self).__init__()
|
super(Net, self).__init__()
|
||||||
self.fc1 = nn.Linear(N_STATES, 10)
|
self.fc1 = nn.Linear(N_STATES, 50)
|
||||||
self.fc1.weight.data.normal_(0, 0.1) # initialization
|
self.fc1.weight.data.normal_(0, 0.1) # initialization
|
||||||
self.out = nn.Linear(10, N_ACTIONS)
|
self.out = nn.Linear(50, N_ACTIONS)
|
||||||
self.out.weight.data.normal_(0, 0.1) # initialization
|
self.out.weight.data.normal_(0, 0.1) # initialization
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -58,9 +59,11 @@ class DQN(object):
|
|||||||
# input only one sample
|
# input only one sample
|
||||||
if np.random.uniform() < EPSILON: # greedy
|
if np.random.uniform() < EPSILON: # greedy
|
||||||
actions_value = self.eval_net.forward(x)
|
actions_value = self.eval_net.forward(x)
|
||||||
action = torch.max(actions_value, 1)[1].data.numpy()[0, 0] # return the argmax
|
action = torch.max(actions_value, 1)[1].data.numpy()
|
||||||
|
action = action[0, 0] if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE) # return the argmax index
|
||||||
else: # random
|
else: # random
|
||||||
action = np.random.randint(0, N_ACTIONS)
|
action = np.random.randint(0, N_ACTIONS)
|
||||||
|
action = action if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE)
|
||||||
return action
|
return action
|
||||||
|
|
||||||
def store_transition(self, s, a, r, s_):
|
def store_transition(self, s, a, r, s_):
|
||||||
|
|||||||
Reference in New Issue
Block a user