update to torch 0.4

This commit is contained in:
Morvan Zhou
2018-05-30 00:46:22 +08:00
parent 538ba18975
commit 7e7c9bb383

View File

@ -37,12 +37,12 @@ class Net(torch.nn.Module):
net = Net(n_feature=1, n_hidden=10, n_output=1) # define the network
print(net) # net architecture
optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)
loss_func = torch.nn.MSELoss() # this is for regression mean squared loss
plt.ion() # something about plotting
for t in range(100):
for t in range(200):
prediction = net(x) # input x and predict based on x
loss = loss_func(prediction, y) # must be (1. nn output, 2. target)