This commit is contained in:
Morvan Zhou
2017-05-08 12:48:29 +10:00
committed by Morvan Zhou
parent 468039f49c
commit b212b3e026
7 changed files with 80 additions and 84 deletions

View File

@ -58,7 +58,7 @@ class Net(nn.Module):
self.bns = []
self.bn_input = nn.BatchNorm1d(1, momentum=0.5) # for input data
for i in range(N_HIDDEN): # build hidden layers and BN layers
for i in range(N_HIDDEN): # build hidden layers and BN layers
input_size = 1 if i == 0 else 10
fc = nn.Linear(input_size, 10)
setattr(self, 'fc%i' % i, fc) # IMPORTANT set layer to the Module
@ -83,7 +83,7 @@ class Net(nn.Module):
for i in range(N_HIDDEN):
x = self.fcs[i](x)
pre_activation.append(x)
if self.do_bn: x = self.bns[i](x) # batch normalization
if self.do_bn: x = self.bns[i](x) # batch normalization
x = ACTIVATION(x)
layer_input.append(x)
out = self.predict(x)
@ -147,7 +147,7 @@ for epoch in range(EPOCH):
loss = loss_func(pred, b_y)
opt.zero_grad()
loss.backward()
opt.step() # it will also learn the parameters in Batch Normalization
opt.step() # it will also learns the parameters in Batch Normalization
plt.ioff()