update to torch 0.4
This commit is contained in:
@ -37,12 +37,12 @@ class Net(torch.nn.Module):
|
||||
net = Net(n_feature=1, n_hidden=10, n_output=1) # define the network
|
||||
print(net) # net architecture
|
||||
|
||||
optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
|
||||
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)
|
||||
loss_func = torch.nn.MSELoss() # this is for regression mean squared loss
|
||||
|
||||
plt.ion() # something about plotting
|
||||
|
||||
for t in range(100):
|
||||
for t in range(200):
|
||||
prediction = net(x) # input x and predict based on x
|
||||
|
||||
loss = loss_func(prediction, y) # must be (1. nn output, 2. target)
|
||||
|
||||
Reference in New Issue
Block a user