update
@@ -18,26 +18,26 @@ import matplotlib.pyplot as plt
 torch.manual_seed(1)    # reproducible
 
 # Hyper Parameters
 EPOCH = 1               # train the training data n times; to save time, we just train 1 epoch
 BATCH_SIZE = 64
 TIME_STEP = 28          # rnn time step / image height
 INPUT_SIZE = 28         # rnn input size / image width
 LR = 0.01               # learning rate
-DOWNLOAD_MNIST = False  # set to True if you haven't downloaded the data
+DOWNLOAD_MNIST = True   # set to True if you haven't downloaded the data
 
 
 # Mnist digital dataset
 train_data = dsets.MNIST(
     root='./mnist/',
     train=True,                         # this is training data
     transform=transforms.ToTensor(),    # converts a PIL.Image or numpy.ndarray to a
                                         # torch.FloatTensor of shape (C x H x W) and normalizes to the range [0.0, 1.0]
     download=DOWNLOAD_MNIST,            # download it if you don't have it
 )
 
 # plot one example
 print(train_data.train_data.size())     # (60000, 28, 28)
 print(train_data.train_labels.size())   # (60000)
 plt.imshow(train_data.train_data[0].numpy(), cmap='gray')
 plt.title('%i' % train_data.train_labels[0])
 plt.show()
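Note: newer torchvision releases (roughly 0.4 onward) deprecate the train_data and train_labels attributes used above in favor of data and targets. A minimal sketch of the equivalent inspection code, assuming such a version is installed:

# assumes torchvision >= 0.4, where the dataset attributes were renamed
print(train_data.data.size())      # torch.Size([60000, 28, 28])
print(train_data.targets.size())   # torch.Size([60000])
plt.imshow(train_data.data[0].numpy(), cmap='gray')
plt.title('%i' % train_data.targets[0])
plt.show()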
@@ -55,11 +55,11 @@ class RNN(nn.Module):
     def __init__(self):
         super(RNN, self).__init__()
 
         self.rnn = nn.LSTM(         # if you use nn.RNN(), it hardly learns
             input_size=28,
             hidden_size=64,         # rnn hidden unit
             num_layers=1,           # number of rnn layers
             batch_first=True,       # input & output will have batch size as the first dimension, e.g. (batch, time_step, input_size)
         )
 
         self.out = nn.Linear(64, 10)
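The forward method is outside this hunk. For context, a minimal sketch of how a batch_first LSTM classifier like this is typically completed; classifying on the last time step's output is an assumption here, not something shown in this diff:

    def forward(self, x):
        # x shape (batch, time_step, input_size)
        # r_out shape (batch, time_step, hidden_size)
        r_out, (h_n, h_c) = self.rnn(x, None)   # None = zero initial hidden and cell state
        out = self.out(r_out[:, -1, :])         # feed the last time step's output to the classifier
        return out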
@@ -80,22 +80,22 @@ rnn = RNN()
 print(rnn)
 
 optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)   # optimize all rnn parameters
 loss_func = nn.CrossEntropyLoss()                       # the target label is not one-hotted
 
 # training and testing
 for epoch in range(EPOCH):
     for step, (x, y) in enumerate(train_loader):        # gives batch data
         b_x = Variable(x.view(-1, 28, 28))              # reshape x to (batch, time_step, input_size)
         b_y = Variable(y)                               # batch y
 
         output = rnn(b_x)                               # rnn output
         loss = loss_func(output, b_y)                   # cross entropy loss
         optimizer.zero_grad()                           # clear gradients for this training step
         loss.backward()                                 # backpropagation, compute gradients
         optimizer.step()                                # apply gradients
 
         if step % 50 == 0:
             test_output = rnn(test_x)                   # (samples, time_step, input_size)
             pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
             accuracy = sum(pred_y == test_y) / test_y.size
             print('Epoch: ', epoch, '| train loss: %.4f' % loss.data[0], '| test accuracy: %.2f' % accuracy)
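Note: this code targets a pre-0.4 PyTorch. On PyTorch 0.4 and later the Variable wrapper is redundant and indexing a zero-dimensional loss with loss.data[0] raises an error. A minimal sketch of the updated loop body, assuming such a version:

        b_x = x.view(-1, 28, 28)         # tensors track gradients directly; no Variable needed
        b_y = y

        output = rnn(b_x)
        loss = loss_func(output, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 50 == 0:
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item())   # loss.item() replaces loss.data[0]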