update
This commit is contained in:
@ -45,7 +45,7 @@ class RNN(nn.Module):
|
||||
def forward(self, x, h_state):
|
||||
# x (batch, time_step, input_size)
|
||||
# h_state (n_layers, batch, hidden_size)
|
||||
# r_out (batch, time_step, output_size)
|
||||
# r_out (batch, time_step, hidden_size)
|
||||
r_out, h_state = self.rnn(x, h_state)
|
||||
|
||||
outs = [] # save all predictions
|
||||
|
||||
Reference in New Issue
Block a user