update
This commit is contained in:
@ -16,7 +16,7 @@ import matplotlib.pyplot as plt
|
|||||||
torch.manual_seed(1) # reproducible
|
torch.manual_seed(1) # reproducible
|
||||||
|
|
||||||
# Hyper Parameters
|
# Hyper Parameters
|
||||||
TIME_STEP = 5 # rnn time step
|
TIME_STEP = 10 # rnn time step
|
||||||
INPUT_SIZE = 1 # rnn input size
|
INPUT_SIZE = 1 # rnn input size
|
||||||
LR = 0.02 # learning rate
|
LR = 0.02 # learning rate
|
||||||
|
|
||||||
@ -69,7 +69,7 @@ plt.show()
|
|||||||
for step in range(60):
|
for step in range(60):
|
||||||
start, end = step * np.pi, (step+1)*np.pi # time steps
|
start, end = step * np.pi, (step+1)*np.pi # time steps
|
||||||
# use sin predicts cos
|
# use sin predicts cos
|
||||||
steps = np.linspace(start, end, 10, dtype=np.float32)
|
steps = np.linspace(start, end, TIME_STEP, dtype=np.float32)
|
||||||
x_np = np.sin(steps) # float32 for converting torch FloatTensor
|
x_np = np.sin(steps) # float32 for converting torch FloatTensor
|
||||||
y_np = np.cos(steps)
|
y_np = np.cos(steps)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user