edit
This commit is contained in:
@ -51,7 +51,7 @@ plt.show()
|
||||
|
||||
for t in range(100):
|
||||
out = net(x) # input x and predict based on x
|
||||
loss = loss_func(out, y) # must be (1. nn output, 2. target), the target label is not one-hotted
|
||||
loss = loss_func(out, y) # must be (1. nn output, 2. target), the target label is NOT one-hotted
|
||||
|
||||
optimizer.zero_grad() # clear gradients for next train
|
||||
loss.backward() # backpropagation, compute gradients
|
||||
|
||||
Reference in New Issue
Block a user