This commit is contained in:
Morvan Zhou
2017-05-06 02:31:19 +10:00
parent 42674bdb94
commit 79b928a27f

View File

@ -51,7 +51,7 @@ plt.show()
for t in range(100): for t in range(100):
out = net(x) # input x and predict based on x out = net(x) # input x and predict based on x
loss = loss_func(out, y) # must be (1. nn output, 2. target), the target label is not one-hotted loss = loss_func(out, y) # must be (1. nn output, 2. target), the target label is NOT one-hotted
optimizer.zero_grad() # clear gradients for next train optimizer.zero_grad() # clear gradients for next train
loss.backward() # backpropagation, compute gradients loss.backward() # backpropagation, compute gradients