edit
This commit is contained in:
@ -51,7 +51,7 @@ plt.show()
|
|||||||
|
|
||||||
for t in range(100):
|
for t in range(100):
|
||||||
out = net(x) # input x and predict based on x
|
out = net(x) # input x and predict based on x
|
||||||
loss = loss_func(out, y) # must be (1. nn output, 2. target), the target label is not one-hotted
|
loss = loss_func(out, y) # must be (1. nn output, 2. target), the target label is NOT one-hotted
|
||||||
|
|
||||||
optimizer.zero_grad() # clear gradients for next train
|
optimizer.zero_grad() # clear gradients for next train
|
||||||
loss.backward() # backpropagation, compute gradients
|
loss.backward() # backpropagation, compute gradients
|
||||||
|
|||||||
Reference in New Issue
Block a user