update to torch 0.4
This commit is contained in:
@ -3,13 +3,12 @@ View more, visit my tutorial page: https://morvanzhou.github.io/tutorials/
|
||||
My Youtube Channel: https://www.youtube.com/user/MorvanZhou
|
||||
|
||||
Dependencies:
|
||||
torch: 0.1.11
|
||||
torch: 0.4
|
||||
matplotlib
|
||||
numpy
|
||||
"""
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Variable
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
@ -69,19 +68,19 @@ h_state = None # for initial hidden state
|
||||
plt.figure(1, figsize=(12, 5))
|
||||
plt.ion() # continuously plot
|
||||
|
||||
for step in range(60):
|
||||
for step in range(100):
|
||||
start, end = step * np.pi, (step+1)*np.pi # time range
|
||||
# use sin predicts cos
|
||||
steps = np.linspace(start, end, TIME_STEP, dtype=np.float32)
|
||||
x_np = np.sin(steps) # float32 for converting torch FloatTensor
|
||||
y_np = np.cos(steps)
|
||||
|
||||
x = Variable(torch.from_numpy(x_np[np.newaxis, :, np.newaxis])) # shape (batch, time_step, input_size)
|
||||
y = Variable(torch.from_numpy(y_np[np.newaxis, :, np.newaxis]))
|
||||
x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis]) # shape (batch, time_step, input_size)
|
||||
y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])
|
||||
|
||||
prediction, h_state = rnn(x, h_state) # rnn output
|
||||
# !! next step is important !!
|
||||
h_state = Variable(h_state.data) # repack the hidden state, break the connection from last iteration
|
||||
h_state = h_state.data # repack the hidden state, break the connection from last iteration
|
||||
|
||||
loss = loss_func(prediction, y) # cross entropy loss
|
||||
optimizer.zero_grad() # clear gradients for this training step
|
||||
|
||||
Reference in New Issue
Block a user