update
@@ -20,7 +20,7 @@ torch.manual_seed(1)    # reproducible
 EPOCH = 1               # train the training data n times, to save time, we just train 1 epoch
 BATCH_SIZE = 50
 LR = 0.001              # learning rate
-DOWNLOAD_MNIST = False
+DOWNLOAD_MNIST = True   # set to False if you have downloaded
 
 
 # Mnist digits dataset
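Side note on the flipped flag above: rather than editing DOWNLOAD_MNIST by hand on each machine, it can be derived from whether the data is already present. A minimal sketch, assuming the usual ./mnist/ data directory (not part of this commit):

import os

# Download only when the data directory is missing or empty (assumed path).
DOWNLOAD_MNIST = not (os.path.exists('./mnist/') and os.listdir('./mnist/'))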
@@ -23,7 +23,7 @@ BATCH_SIZE = 64
 TIME_STEP = 28          # rnn time step / image height
 INPUT_SIZE = 28         # rnn input size / image width
 LR = 0.01               # learning rate
-DOWNLOAD_MNIST = False  # set to True if haven't download the data
+DOWNLOAD_MNIST = True   # set to True if haven't download the data
 
 
 # Mnist digital dataset
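For context, the DOWNLOAD_MNIST flag is only consumed by the dataset constructor. A hedged sketch of that call; the root path and transform are assumptions drawn from the usual MNIST setup, not from this diff:

import torchvision

train_data = torchvision.datasets.MNIST(
    root='./mnist/',                                  # assumed cache directory
    train=True,                                       # training split (60000 images)
    transform=torchvision.transforms.ToTensor(),      # uint8 (28, 28) -> float (1, 28, 28) in [0, 1]
    download=DOWNLOAD_MNIST,                          # touches the network only when True
)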
@@ -39,9 +39,9 @@ train_data = torchvision.datasets.MNIST(
 # plot one example
 print(train_data.train_data.size())                 # (60000, 28, 28)
 print(train_data.train_labels.size())               # (60000)
-# plt.imshow(train_data.train_data[2].numpy(), cmap='gray')
-# plt.title('%i' % train_data.train_labels[2])
-# plt.show()
+plt.imshow(train_data.train_data[2].numpy(), cmap='gray')
+plt.title('%i' % train_data.train_labels[2])
+plt.show()
 
 # Data Loader for easy mini-batch return in training, the image batch shape will be (50, 1, 28, 28)
 train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
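The un-commented preview uses the train_data/train_labels attributes; on recent torchvision releases those are deprecated aliases of data/targets. A hedged equivalent of the preview, plus the batch shape the loader yields (version-dependent, not part of this diff):

import matplotlib.pyplot as plt

plt.imshow(train_data.data[2].numpy(), cmap='gray')   # same image via the non-deprecated attribute
plt.title('%i' % train_data.targets[2])
plt.show()

b_x, b_y = next(iter(train_loader))                   # one mini-batch from the DataLoader
print(b_x.shape, b_y.shape)                           # torch.Size([50, 1, 28, 28]) torch.Size([50])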
@@ -46,7 +46,7 @@ class DQN(object):
     def __init__(self):
         self.eval_net, self.target_net = Net(), Net()
 
-        self.learn_step_counter = 0                                     # for target updateing
+        self.learn_step_counter = 0                                     # for target updating
         self.memory_counter = 0                                         # for storing memory
         self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2))     # initialize memory
         self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)
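The memory width N_STATES * 2 + 2 fits one transition per row: state, action, reward, next state. A minimal sketch of how such a buffer is usually filled, with memory_counter acting as a circular index (this method is not shown in the diff; the signature is an assumption, and np / MEMORY_CAPACITY come from the surrounding file):

def store_transition(self, s, a, r, s_):
    transition = np.hstack((s, [a, r], s_))            # one row: N_STATES + 2 + N_STATES values
    index = self.memory_counter % MEMORY_CAPACITY      # overwrite the oldest row once full
    self.memory[index, :] = transition
    self.memory_counter += 1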
@@ -100,7 +100,6 @@ for i_episode in range(400):
     ep_r = 0
     while True:
         env.render()
-
         a = dqn.choose_action(s)
 
         # take action
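dqn.choose_action(s) above is typically an epsilon-greedy wrapper around the eval net; a hedged sketch only (EPSILON and N_ACTIONS are assumed module-level constants, not visible in this diff):

def choose_action(self, x):
    x = torch.unsqueeze(torch.FloatTensor(x), 0)               # shape (1, N_STATES)
    if np.random.uniform() < EPSILON:                          # exploit: pick the greedy action
        actions_value = self.eval_net(x)
        action = torch.max(actions_value, 1)[1].data.numpy()[0]
    else:                                                      # explore: pick a random action
        action = np.random.randint(0, N_ACTIONS)
    return action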
@@ -112,7 +111,6 @@ for i_episode in range(400):
         r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
         r = r1 + r2
 
-        # store experience
         dqn.store_transition(s, a, r, s_)
 
         ep_r += r
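For context on r = r1 + r2: s_ is the CartPole observation (x, x_dot, theta, theta_dot), and each term rewards staying near the middle of its allowed range. Only the r2 line is visible in this hunk; the unpacking and r1 below are a sketch under that assumption:

x, x_dot, theta, theta_dot = s_
r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8                               # cart-position term (assumed constant)
r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5   # pole-angle term (from the diff)
r = r1 + r2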
@@ -120,10 +118,8 @@ for i_episode in range(400):
             dqn.learn()
             if done:
                 print('Ep: ', i_episode,
-                      '| Ep_r: ', round(ep_r, 2),
-                      )
+                      '| Ep_r: ', round(ep_r, 2))
 
         if done:
             break
-
         s = s_
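dqn.learn() is only invoked once the replay memory is full. A minimal sketch of what that method usually does in this style of DQN, with the learn_step_counter from the __init__ hunk driving the periodic target-network sync (TARGET_REPLACE_ITER, GAMMA, BATCH_SIZE and N_STATES are assumed constants; none of this is part of the diff):

import numpy as np
import torch
import torch.nn.functional as F

def learn(self):
    # copy the online net into the frozen target net every TARGET_REPLACE_ITER steps
    if self.learn_step_counter % TARGET_REPLACE_ITER == 0:
        self.target_net.load_state_dict(self.eval_net.state_dict())
    self.learn_step_counter += 1

    # sample a random mini-batch of stored transitions
    sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)
    b_memory = self.memory[sample_index, :]
    b_s  = torch.FloatTensor(b_memory[:, :N_STATES])
    b_a  = torch.LongTensor(b_memory[:, N_STATES:N_STATES + 1].astype(int))
    b_r  = torch.FloatTensor(b_memory[:, N_STATES + 1:N_STATES + 2])
    b_s_ = torch.FloatTensor(b_memory[:, -N_STATES:])

    q_eval = self.eval_net(b_s).gather(1, b_a)                      # Q(s, a) from the online net
    q_next = self.target_net(b_s_).detach()                         # no gradient through the target net
    q_target = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1)   # one-step TD target
    loss = F.mse_loss(q_eval, q_target)

    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()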
@@ -20,7 +20,7 @@ BATCH_SIZE = 64
 TIME_STEP = 5           # rnn time step / image height
 INPUT_SIZE = 1          # rnn input size / image width
 LR = 0.02               # learning rate
-DOWNLOAD_MNIST = False  # set to True if haven't download the data
+DOWNLOAD_MNIST = True   # set to False if have downloaded the data
 
 
 class RNN(nn.Module):
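In this file TIME_STEP = 5 and INPUT_SIZE = 1 describe the (batch, time_step, input_size) tensor that an nn.RNN built with batch_first=True expects. A short sketch of how a sine-wave window would be shaped for it (assumed usage, not shown in the diff):

import numpy as np
import torch

steps = np.linspace(0, np.pi, 5, dtype=np.float32)       # TIME_STEP = 5 sample points
x_np = np.sin(steps)                                      # shape (5,)
x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis])     # shape (batch=1, time_step=5, input_size=1)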
@@ -147,7 +147,7 @@ for epoch in range(EPOCH):
             loss = loss_func(pred, b_y)
             opt.zero_grad()
             loss.backward()
-            opt.step()    # it will also learn the parameters in Batch Normalization
+            opt.step()    # it will also learns the parameters in Batch Normalization
 
 
 plt.ioff()
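On the comment in the last hunk: opt.step() updates Batch Normalization because its affine scale and shift (weight/bias, i.e. gamma/beta) are nn.Parameters and therefore appear in net.parameters(), whereas running_mean/running_var are buffers refreshed during the forward pass, not by the optimizer. A small illustration:

import torch.nn as nn

bn = nn.BatchNorm1d(8)
print([name for name, _ in bn.named_parameters()])   # ['weight', 'bias'] -> updated by opt.step()
print([name for name, _ in bn.named_buffers()])      # ['running_mean', 'running_var', 'num_batches_tracked']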