diff --git a/tutorial-contents/504_batch_normalization.py b/tutorial-contents/504_batch_normalization.py
index ee30ca6..c25844e 100644
--- a/tutorial-contents/504_batch_normalization.py
+++ b/tutorial-contents/504_batch_normalization.py
@@ -11,7 +11,6 @@ import torch
 from torch import nn
 from torch.nn import init
 import torch.utils.data as Data
-import torch.nn.functional as F
 import matplotlib.pyplot as plt
 import numpy as np
 
@@ -24,7 +23,7 @@ BATCH_SIZE = 64
 EPOCH = 12
 LR = 0.03
 N_HIDDEN = 8
-ACTIVATION = F.tanh
+ACTIVATION = torch.tanh
 B_INIT = -0.2 # use a bad bias constant initializer
 
 # training data
@@ -48,6 +47,7 @@ train_loader = Data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shu
 plt.scatter(train_x.numpy(), train_y.numpy(), c='#FF9359', s=50, alpha=0.2, label='train')
 plt.legend(loc='upper left')
 
+
 class Net(nn.Module):
     def __init__(self, batch_normalization=False):
         super(Net, self).__init__()
@@ -89,20 +89,20 @@ class Net(nn.Module):
 
 nets = [Net(batch_normalization=False), Net(batch_normalization=True)]
 
-print(*nets) # print net architecture
+# print(*nets) # print net architecture
 
 opts = [torch.optim.Adam(net.parameters(), lr=LR) for net in nets]
 
 loss_func = torch.nn.MSELoss()
 
-f, axs = plt.subplots(4, N_HIDDEN+1, figsize=(10, 5))
-plt.ion() # something about plotting
-plt.show()
+
 def plot_histogram(l_in, l_in_bn, pre_ac, pre_ac_bn):
-    for i, (ax_pa, ax_pa_bn, ax, ax_bn) in enumerate(zip(axs[0, :], axs[1, :], axs[2, :], axs[3, :])):
+    for i, (ax_pa, ax_pa_bn, ax, ax_bn) in enumerate(zip(axs[0, :], axs[1, :], axs[2, :], axs[3, :])):
         [a.clear() for a in [ax_pa, ax_pa_bn, ax, ax_bn]]
-        if i == 0: p_range = (-7, 10);the_range = (-7, 10)
-        else:p_range = (-4, 4);the_range = (-1, 1)
+        if i == 0:
+            p_range = (-7, 10);the_range = (-7, 10)
+        else:
+            p_range = (-4, 4);the_range = (-1, 1)
         ax_pa.set_title('L' + str(i))
         ax_pa.hist(pre_ac[i].data.numpy().ravel(), bins=10, range=p_range, color='#FF9359', alpha=0.5);ax_pa_bn.hist(pre_ac_bn[i].data.numpy().ravel(), bins=10, range=p_range, color='#74BCFF', alpha=0.5)
         ax.hist(l_in[i].data.numpy().ravel(), bins=10, range=the_range, color='#FF9359');ax_bn.hist(l_in_bn[i].data.numpy().ravel(), bins=10, range=the_range, color='#74BCFF')
@@ -111,44 +111,50 @@ def plot_histogram(l_in, l_in_bn, pre_ac, pre_ac_bn):
     axs[0, 0].set_ylabel('PreAct');axs[1, 0].set_ylabel('BN PreAct');axs[2, 0].set_ylabel('Act');axs[3, 0].set_ylabel('BN Act')
     plt.pause(0.01)
 
 
-# training
-losses = [[], []] # recode loss for two networks
-for epoch in range(EPOCH):
-    print('Epoch: ', epoch)
-    layer_inputs, pre_acts = [], []
-    for net, l in zip(nets, losses):
-        net.eval() # set eval mode to fix moving_mean and moving_var
-        pred, layer_input, pre_act = net(test_x)
-        l.append(loss_func(pred, test_y).data[0])
-        layer_inputs.append(layer_input)
-        pre_acts.append(pre_act)
-        net.train() # free moving_mean and moving_var
-    plot_histogram(*layer_inputs, *pre_acts) # plot histogram
-    for step, (b_x, b_y) in enumerate(train_loader):
-        for net, opt in zip(nets, opts): # train for each network
-            pred, _, _ = net(b_x)
-            loss = loss_func(pred, b_y)
-            opt.zero_grad()
-            loss.backward()
-            opt.step() # it will also learns the parameters in Batch Normalization
+if __name__ == "__main__":
+    f, axs = plt.subplots(4, N_HIDDEN + 1, figsize=(10, 5))
+    plt.ion() # turn on interactive plotting
+    plt.show()
+    # training
+    losses = [[], []] # record test loss for the two networks
 
-plt.ioff()
+    for epoch in range(EPOCH):
+        print('Epoch: ', epoch)
+        layer_inputs, pre_acts = [], []
+        for net, l in zip(nets, losses):
+            net.eval() # set eval mode to freeze moving_mean and moving_var
+            pred, layer_input, pre_act = net(test_x)
+            l.append(loss_func(pred, test_y).data.item())
+            layer_inputs.append(layer_input)
+            pre_acts.append(pre_act)
+            net.train() # back to train mode so moving_mean and moving_var keep updating
+        plot_histogram(*layer_inputs, *pre_acts) # plot histogram
 
 
-# plot training loss
-plt.figure(2)
-plt.plot(losses[0], c='#FF9359', lw=3, label='Original')
-plt.plot(losses[1], c='#74BCFF', lw=3, label='Batch Normalization')
-plt.xlabel('step');plt.ylabel('test loss');plt.ylim((0, 2000));plt.legend(loc='best')
+        for step, (b_x, b_y) in enumerate(train_loader):
+            for net, opt in zip(nets, opts): # train each network
+                pred, _, _ = net(b_x)
+                loss = loss_func(pred, b_y)
+                opt.zero_grad()
+                loss.backward()
+                opt.step() # this also updates the learnable parameters in Batch Normalization
 
-# evaluation
-# set net to eval mode to freeze the parameters in batch normalization layers
-[net.eval() for net in nets] # set eval mode to fix moving_mean and moving_var
-preds = [net(test_x)[0] for net in nets]
-plt.figure(3)
-plt.plot(test_x.data.numpy(), preds[0].data.numpy(), c='#FF9359', lw=4, label='Original')
-plt.plot(test_x.data.numpy(), preds[1].data.numpy(), c='#74BCFF', lw=4, label='Batch Normalization')
-plt.scatter(test_x.data.numpy(), test_y.data.numpy(), c='r', s=50, alpha=0.2, label='train')
-plt.legend(loc='best')
-plt.show()
+    plt.ioff()
+
+    # plot training loss
+    plt.figure(2)
+    plt.plot(losses[0], c='#FF9359', lw=3, label='Original')
+    plt.plot(losses[1], c='#74BCFF', lw=3, label='Batch Normalization')
+    plt.xlabel('step');plt.ylabel('test loss');plt.ylim((0, 2000));plt.legend(loc='best')
+
+    # evaluation
+    # set net to eval mode to freeze the parameters in batch normalization layers
+    [net.eval() for net in nets] # set eval mode to freeze moving_mean and moving_var
+    preds = [net(test_x)[0] for net in nets]
+    plt.figure(3)
+    plt.plot(test_x.data.numpy(), preds[0].data.numpy(), c='#FF9359', lw=4, label='Original')
+    plt.plot(test_x.data.numpy(), preds[1].data.numpy(), c='#74BCFF', lw=4, label='Batch Normalization')
+    plt.scatter(test_x.data.numpy(), test_y.data.numpy(), c='r', s=50, alpha=0.2, label='train')
+    plt.legend(loc='best')
+    plt.show()
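
Note for reviewers: the patch leans on the net.eval()/net.train() toggle around the test-set pass so that the BatchNorm layers' moving_mean and moving_var stay frozen while test loss is measured and resume updating during training. The snippet below is not part of the patch; it is a minimal standalone sketch (assuming a reasonably recent PyTorch) of how train() and eval() gate those running statistics:

    import torch
    from torch import nn

    bn = nn.BatchNorm1d(1)                    # one feature, like a single BN'd hidden unit
    x = torch.randn(64, 1) * 3 + 5            # batch with non-zero mean, non-unit std

    bn.train()                                # train mode: forward passes update running stats
    bn(x)
    print(bn.running_mean, bn.running_var)    # nudged toward the batch statistics

    frozen = bn.running_mean.clone()
    bn.eval()                                 # eval mode: running stats are frozen
    bn(torch.randn(64, 1) * 10 + 50)          # a very different batch, but...
    print(torch.equal(bn.running_mean, frozen))   # True: no update happened in eval mode

The same distinction explains the comment on opt.step(): the BN layers' affine parameters (weight/gamma and bias/beta) are ordinary nn.Parameters updated by the optimizer, whereas the running statistics are buffers updated only by forward passes in train mode.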