This commit is contained in:
Morvan Zhou
2017-05-08 12:48:29 +10:00
committed by Morvan Zhou
parent 468039f49c
commit b212b3e026
7 changed files with 80 additions and 84 deletions

View File

@@ -46,8 +46,8 @@ class DQN(object):
def __init__(self):
self.eval_net, self.target_net = Net(), Net()
self.learn_step_counter = 0 # for target updateing
self.memory_counter = 0 # for storing memory
self.learn_step_counter = 0 # for target updating
self.memory_counter = 0 # for storing memory
self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2)) # initialize memory
self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)
self.loss_func = nn.MSELoss()
@@ -100,7 +100,6 @@ for i_episode in range(400):
ep_r = 0
while True:
env.render()
a = dqn.choose_action(s)
# take action
@@ -112,7 +111,6 @@ for i_episode in range(400):
r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
r = r1 + r2
# store experience
dqn.store_transition(s, a, r, s_)
ep_r += r
@@ -120,10 +118,8 @@ for i_episode in range(400):
dqn.learn()
if done:
print('Ep: ', i_episode,
'| Ep_r: ', round(ep_r, 2),
)
'| Ep_r: ', round(ep_r, 2))
if done:
break
s = s_