update
This commit is contained in:
@ -20,7 +20,7 @@ torch.manual_seed(1) # reproducible
|
||||
EPOCH = 1 # train the training data n times; to save time, we train for just 1 epoch
|
||||
BATCH_SIZE = 50
|
||||
LR = 0.001 # learning rate
|
||||
DOWNLOAD_MNIST = False
|
||||
DOWNLOAD_MNIST = True # set to False if you have downloaded
|
||||
|
||||
|
||||
# Mnist digits dataset
|
||||
|
||||
@ -23,7 +23,7 @@ BATCH_SIZE = 64
|
||||
TIME_STEP = 28 # rnn time step / image height
|
||||
INPUT_SIZE = 28 # rnn input size / image width
|
||||
LR = 0.01 # learning rate
|
||||
DOWNLOAD_MNIST = False # set to True if you haven't downloaded the data
|
||||
DOWNLOAD_MNIST = True # set to True if you haven't downloaded the data
|
||||
|
||||
|
||||
# Mnist digits dataset
|
||||
|
||||
@ -39,9 +39,9 @@ train_data = torchvision.datasets.MNIST(
|
||||
# plot one example
|
||||
print(train_data.train_data.size()) # (60000, 28, 28)
|
||||
print(train_data.train_labels.size()) # (60000)
|
||||
# plt.imshow(train_data.train_data[2].numpy(), cmap='gray')
|
||||
# plt.title('%i' % train_data.train_labels[2])
|
||||
# plt.show()
|
||||
plt.imshow(train_data.train_data[2].numpy(), cmap='gray')
|
||||
plt.title('%i' % train_data.train_labels[2])
|
||||
plt.show()
|
||||
|
||||
# Data Loader for easy mini-batch return in training, the image batch shape will be (50, 1, 28, 28)
|
||||
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
|
||||
|
||||
@ -46,7 +46,7 @@ class DQN(object):
|
||||
def __init__(self):
|
||||
self.eval_net, self.target_net = Net(), Net()
|
||||
|
||||
self.learn_step_counter = 0 # for target updateing
|
||||
self.learn_step_counter = 0 # for target updating
|
||||
self.memory_counter = 0 # for storing memory
|
||||
self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2)) # initialize memory
|
||||
self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)
|
||||
@ -100,7 +100,6 @@ for i_episode in range(400):
|
||||
ep_r = 0
|
||||
while True:
|
||||
env.render()
|
||||
|
||||
a = dqn.choose_action(s)
|
||||
|
||||
# take action
|
||||
@ -112,7 +111,6 @@ for i_episode in range(400):
|
||||
r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
|
||||
r = r1 + r2
|
||||
|
||||
# store experience
|
||||
dqn.store_transition(s, a, r, s_)
|
||||
|
||||
ep_r += r
|
||||
@ -120,10 +118,8 @@ for i_episode in range(400):
|
||||
dqn.learn()
|
||||
if done:
|
||||
print('Ep: ', i_episode,
|
||||
'| Ep_r: ', round(ep_r, 2),
|
||||
)
|
||||
'| Ep_r: ', round(ep_r, 2))
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
s = s_
|
||||
@ -20,7 +20,7 @@ BATCH_SIZE = 64
|
||||
TIME_STEP = 5 # rnn time step / image height
|
||||
INPUT_SIZE = 1 # rnn input size / image width
|
||||
LR = 0.02 # learning rate
|
||||
DOWNLOAD_MNIST = False # set to True if you haven't downloaded the data
|
||||
DOWNLOAD_MNIST = True # set to False if you have downloaded the data
|
||||
|
||||
|
||||
class RNN(nn.Module):
|
||||
|
||||
@ -147,7 +147,7 @@ for epoch in range(EPOCH):
|
||||
loss = loss_func(pred, b_y)
|
||||
opt.zero_grad()
|
||||
loss.backward()
|
||||
opt.step() # it will also learn the parameters in Batch Normalization
|
||||
opt.step() # it will also learn the parameters in Batch Normalization
|
||||
|
||||
|
||||
plt.ioff()
|
||||
|
||||
Reference in New Issue
Block a user