update
This commit is contained in:
@ -45,7 +45,7 @@ class RNN(nn.Module):
|
|||||||
def forward(self, x, h_state):
|
def forward(self, x, h_state):
|
||||||
# x (batch, time_step, input_size)
|
# x (batch, time_step, input_size)
|
||||||
# h_state (n_layers, batch, hidden_size)
|
# h_state (n_layers, batch, hidden_size)
|
||||||
# r_out (batch, time_step, output_size)
|
# r_out (batch, time_step, hidden_size)
|
||||||
r_out, h_state = self.rnn(x, h_state)
|
r_out, h_state = self.rnn(x, h_state)
|
||||||
|
|
||||||
outs = [] # save all predictions
|
outs = [] # save all predictions
|
||||||
|
|||||||
Reference in New Issue
Block a user