This commit is contained in:
Morvan Zhou
2017-05-08 12:48:29 +10:00
committed by Morvan Zhou
parent 468039f49c
commit b212b3e026
7 changed files with 80 additions and 84 deletions

View File

@ -20,7 +20,7 @@ torch.manual_seed(1) # reproducible
EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 50 BATCH_SIZE = 50
LR = 0.001 # learning rate LR = 0.001 # learning rate
DOWNLOAD_MNIST = False DOWNLOAD_MNIST = True # set to False if you have already downloaded the data
# Mnist digits dataset # Mnist digits dataset

View File

@ -23,7 +23,7 @@ BATCH_SIZE = 64
TIME_STEP = 28 # rnn time step / image height TIME_STEP = 28 # rnn time step / image height
INPUT_SIZE = 28 # rnn input size / image width INPUT_SIZE = 28 # rnn input size / image width
LR = 0.01 # learning rate LR = 0.01 # learning rate
DOWNLOAD_MNIST = False # set to True if haven't download the data DOWNLOAD_MNIST = True # set to True if haven't download the data
# Mnist digital dataset # Mnist digital dataset

View File

@ -39,9 +39,9 @@ train_data = torchvision.datasets.MNIST(
# plot one example # plot one example
print(train_data.train_data.size()) # (60000, 28, 28) print(train_data.train_data.size()) # (60000, 28, 28)
print(train_data.train_labels.size()) # (60000) print(train_data.train_labels.size()) # (60000)
# plt.imshow(train_data.train_data[2].numpy(), cmap='gray') plt.imshow(train_data.train_data[2].numpy(), cmap='gray')
# plt.title('%i' % train_data.train_labels[2]) plt.title('%i' % train_data.train_labels[2])
# plt.show() plt.show()
# Data Loader for easy mini-batch return in training, the image batch shape will be (50, 1, 28, 28) # Data Loader for easy mini-batch return in training, the image batch shape will be (50, 1, 28, 28)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

View File

@ -46,7 +46,7 @@ class DQN(object):
def __init__(self): def __init__(self):
self.eval_net, self.target_net = Net(), Net() self.eval_net, self.target_net = Net(), Net()
self.learn_step_counter = 0 # for target updateing self.learn_step_counter = 0 # for target updating
self.memory_counter = 0 # for storing memory self.memory_counter = 0 # for storing memory
self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2)) # initialize memory self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2)) # initialize memory
self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR) self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)
@ -100,7 +100,6 @@ for i_episode in range(400):
ep_r = 0 ep_r = 0
while True: while True:
env.render() env.render()
a = dqn.choose_action(s) a = dqn.choose_action(s)
# take action # take action
@ -112,7 +111,6 @@ for i_episode in range(400):
r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5 r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
r = r1 + r2 r = r1 + r2
# store experience
dqn.store_transition(s, a, r, s_) dqn.store_transition(s, a, r, s_)
ep_r += r ep_r += r
@ -120,10 +118,8 @@ for i_episode in range(400):
dqn.learn() dqn.learn()
if done: if done:
print('Ep: ', i_episode, print('Ep: ', i_episode,
'| Ep_r: ', round(ep_r, 2), '| Ep_r: ', round(ep_r, 2))
)
if done: if done:
break break
s = s_ s = s_

View File

@ -20,7 +20,7 @@ BATCH_SIZE = 64
TIME_STEP = 5 # rnn time step / image height TIME_STEP = 5 # rnn time step / image height
INPUT_SIZE = 1 # rnn input size / image width INPUT_SIZE = 1 # rnn input size / image width
LR = 0.02 # learning rate LR = 0.02 # learning rate
DOWNLOAD_MNIST = False # set to True if haven't download the data DOWNLOAD_MNIST = True # set to False if you have already downloaded the data
class RNN(nn.Module): class RNN(nn.Module):

View File

@ -147,7 +147,7 @@ for epoch in range(EPOCH):
loss = loss_func(pred, b_y) loss = loss_func(pred, b_y)
opt.zero_grad() opt.zero_grad()
loss.backward() loss.backward()
opt.step() # it will also learn the parameters in Batch Normalization opt.step() # it also learns the parameters in Batch Normalization
plt.ioff() plt.ioff()