This commit is contained in:
Morvan Zhou
2017-08-09 01:51:16 +10:00
parent 1b5fa9f7aa
commit e7599c0f14

View File

@ -53,6 +53,10 @@ class RNN(nn.Module):
outs.append(self.out(r_out[:, time_step, :])) outs.append(self.out(r_out[:, time_step, :]))
return torch.stack(outs, dim=1), h_state return torch.stack(outs, dim=1), h_state
# instead, for simplicity, you can replace above codes by follows
# r_out = r_out.view(-1, 32)
# outs = self.out(r_out)
# return outs, h_state
rnn = RNN() rnn = RNN()
print(rnn) print(rnn)