diff --git a/tutorial-contents/301_regression.py b/tutorial-contents/301_regression.py index 4c67c90..3e28134 100644 --- a/tutorial-contents/301_regression.py +++ b/tutorial-contents/301_regression.py @@ -37,12 +37,12 @@ class Net(torch.nn.Module): net = Net(n_feature=1, n_hidden=10, n_output=1) # define the network print(net) # net architecture -optimizer = torch.optim.SGD(net.parameters(), lr=0.5) +optimizer = torch.optim.SGD(net.parameters(), lr=0.2) loss_func = torch.nn.MSELoss() # this is for regression mean squared loss plt.ion() # something about plotting -for t in range(100): +for t in range(200): prediction = net(x) # input x and predict based on x loss = loss_func(prediction, y) # must be (1. nn output, 2. target)