This commit is contained in:
Morvan Zhou
2017-05-11 11:47:28 +10:00
parent 8b3004daba
commit 29da380ec2

View File

@ -16,7 +16,7 @@ import matplotlib.pyplot as plt
torch.manual_seed(1) # reproducible torch.manual_seed(1) # reproducible
# Hyper Parameters # Hyper Parameters
TIME_STEP = 5 # rnn time step TIME_STEP = 10 # rnn time step
INPUT_SIZE = 1 # rnn input size INPUT_SIZE = 1 # rnn input size
LR = 0.02 # learning rate LR = 0.02 # learning rate
@ -69,7 +69,7 @@ plt.show()
for step in range(60): for step in range(60):
start, end = step * np.pi, (step+1)*np.pi # time steps start, end = step * np.pi, (step+1)*np.pi # time steps
# use sin predicts cos # use sin predicts cos
steps = np.linspace(start, end, 10, dtype=np.float32) steps = np.linspace(start, end, TIME_STEP, dtype=np.float32)
x_np = np.sin(steps) # float32 for converting torch FloatTensor x_np = np.sin(steps) # float32 for converting torch FloatTensor
y_np = np.cos(steps) y_np = np.cos(steps)