diff --git a/tutorial-contents/403_RNN_regressor.py b/tutorial-contents/403_RNN_regressor.py index c8b19ca..2010ded 100644 --- a/tutorial-contents/403_RNN_regressor.py +++ b/tutorial-contents/403_RNN_regressor.py @@ -55,7 +55,13 @@ class RNN(nn.Module): # instead, for simplicity, you can replace above codes by follows # r_out = r_out.view(-1, 32) # outs = self.out(r_out) + # outs = outs.view(-1, TIME_STEP, 1) # return outs, h_state + + # or even simpler, since nn.Linear can accept inputs of any dimension + # and returns outputs with same dimension except for the last + # outs = self.out(r_out) + # return outs rnn = RNN() print(rnn)