This commit is contained in:
Morvan Zhou
2017-05-11 16:53:32 +10:00
parent 09ba66ff32
commit 3e3b0e9d10

View File

@ -45,7 +45,7 @@ class RNN(nn.Module):
def forward(self, x, h_state):
# x (batch, time_step, input_size)
# h_state (n_layers, batch, hidden_size)
# r_out (batch, time_step, output_size)
# r_out (batch, time_step, hidden_size)
r_out, h_state = self.rnn(x, h_state)
outs = [] # save all predictions