diff --git a/tutorial-contents/402_RNN_classifier.py b/tutorial-contents/402_RNN_classifier.py index 679a255..1909387 100644 --- a/tutorial-contents/402_RNN_classifier.py +++ b/tutorial-contents/402_RNN_classifier.py @@ -56,7 +56,7 @@ class RNN(nn.Module): super(RNN, self).__init__() self.rnn = nn.LSTM( # if use nn.RNN(), it hardly learns - input_size=28, + input_size=INPUT_SIZE, hidden_size=64, # rnn hidden unit num_layers=1, # number of rnn layer batch_first=True, # input & output will has batch size as 1s dimension. e.g. (batch, time_step, input_size)