update to torch 0.4
This commit is contained in:
@ -37,12 +37,12 @@ class Net(torch.nn.Module):
|
|||||||
net = Net(n_feature=1, n_hidden=10, n_output=1) # define the network
|
net = Net(n_feature=1, n_hidden=10, n_output=1) # define the network
|
||||||
print(net) # net architecture
|
print(net) # net architecture
|
||||||
|
|
||||||
optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
|
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)
|
||||||
loss_func = torch.nn.MSELoss() # this is for regression mean squared loss
|
loss_func = torch.nn.MSELoss() # this is for regression mean squared loss
|
||||||
|
|
||||||
plt.ion() # something about plotting
|
plt.ion() # something about plotting
|
||||||
|
|
||||||
for t in range(100):
|
for t in range(200):
|
||||||
prediction = net(x) # input x and predict based on x
|
prediction = net(x) # input x and predict based on x
|
||||||
|
|
||||||
loss = loss_func(prediction, y) # must be (1. nn output, 2. target)
|
loss = loss_func(prediction, y) # must be (1. nn output, 2. target)
|
||||||
|
|||||||
Reference in New Issue
Block a user