This commit is contained in:
Morvan Zhou
2017-05-10 20:18:33 +10:00
parent e2f10e0ce1
commit 3d8df0c297

View File

@ -56,7 +56,7 @@ class RNN(nn.Module):
super(RNN, self).__init__()
self.rnn = nn.LSTM( # if use nn.RNN(), it hardly learns
input_size=28,
input_size=INPUT_SIZE,
hidden_size=64, # rnn hidden unit
num_layers=1, # number of rnn layer
batch_first=True, # input & output will has batch size as 1s dimension. e.g. (batch, time_step, input_size)