update
This commit is contained in:
@ -58,7 +58,7 @@ class Net(nn.Module):
|
||||
self.bns = []
|
||||
self.bn_input = nn.BatchNorm1d(1, momentum=0.5) # for input data
|
||||
|
||||
for i in range(N_HIDDEN): # build hidden layers and BN layers
|
||||
for i in range(N_HIDDEN): # build hidden layers and BN layers
|
||||
input_size = 1 if i == 0 else 10
|
||||
fc = nn.Linear(input_size, 10)
|
||||
setattr(self, 'fc%i' % i, fc) # IMPORTANT set layer to the Module
|
||||
@ -83,7 +83,7 @@ class Net(nn.Module):
|
||||
for i in range(N_HIDDEN):
|
||||
x = self.fcs[i](x)
|
||||
pre_activation.append(x)
|
||||
if self.do_bn: x = self.bns[i](x) # batch normalization
|
||||
if self.do_bn: x = self.bns[i](x) # batch normalization
|
||||
x = ACTIVATION(x)
|
||||
layer_input.append(x)
|
||||
out = self.predict(x)
|
||||
@ -147,7 +147,7 @@ for epoch in range(EPOCH):
|
||||
loss = loss_func(pred, b_y)
|
||||
opt.zero_grad()
|
||||
loss.backward()
|
||||
opt.step() # it will also learn the parameters in Batch Normalization
|
||||
opt.step() # it will also learns the parameters in Batch Normalization
|
||||
|
||||
|
||||
plt.ioff()
|
||||
|
||||
Reference in New Issue
Block a user