update
@@ -99,30 +99,17 @@ loss_func = torch.nn.MSELoss()
 f, axs = plt.subplots(4, N_HIDDEN+1, figsize=(10, 5))
 plt.ion()   # something about plotting
 plt.show()

 def plot_histogram(l_in, l_in_bn, pre_ac, pre_ac_bn):
     for i, (ax_pa, ax_pa_bn, ax, ax_bn) in enumerate(zip(axs[0, :], axs[1, :], axs[2, :], axs[3, :])):
         [a.clear() for a in [ax_pa, ax_pa_bn, ax, ax_bn]]
-        if i == 0:
-            p_range = (-7, 10)
-            the_range = (-7, 10)
-        else:
-            p_range = (-4, 4)
-            the_range = (-1, 1)
+        if i == 0: p_range = (-7, 10);the_range = (-7, 10)
+        else:p_range = (-4, 4);the_range = (-1, 1)
         ax_pa.set_title('L' + str(i))
-        ax_pa.hist(pre_ac[i].data.numpy().ravel(), bins=10, range=p_range, color='#FF9359', alpha=0.5)
-        ax_pa_bn.hist(pre_ac_bn[i].data.numpy().ravel(), bins=10, range=p_range, color='#74BCFF', alpha=0.5)
-        ax.hist(l_in[i].data.numpy().ravel(), bins=10, range=the_range, color='#FF9359')
-        ax_bn.hist(l_in_bn[i].data.numpy().ravel(), bins=10, range=the_range, color='#74BCFF')
-        for a in [ax_pa, ax, ax_pa_bn, ax_bn]:
-            a.set_yticks(())
-            a.set_xticks(())
-        ax_pa_bn.set_xticks(p_range)
-        ax_bn.set_xticks(the_range)
-        axs[0, 0].set_ylabel('PreAct')
-        axs[1, 0].set_ylabel('BN PreAct')
-        axs[2, 0].set_ylabel('Act')
-        axs[3, 0].set_ylabel('BN Act')
+        ax_pa.hist(pre_ac[i].data.numpy().ravel(), bins=10, range=p_range, color='#FF9359', alpha=0.5);ax_pa_bn.hist(pre_ac_bn[i].data.numpy().ravel(), bins=10, range=p_range, color='#74BCFF', alpha=0.5)
+        ax.hist(l_in[i].data.numpy().ravel(), bins=10, range=the_range, color='#FF9359');ax_bn.hist(l_in_bn[i].data.numpy().ravel(), bins=10, range=the_range, color='#74BCFF')
+        for a in [ax_pa, ax, ax_pa_bn, ax_bn]: a.set_yticks(());a.set_xticks(())
+        ax_pa_bn.set_xticks(p_range);ax_bn.set_xticks(the_range)
+        axs[0, 0].set_ylabel('PreAct');axs[1, 0].set_ylabel('BN PreAct');axs[2, 0].set_ylabel('Act');axs[3, 0].set_ylabel('BN Act')
     plt.pause(0.01)

 # training
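plot_histogram above expects four lists indexed by layer: the layer inputs (activations) and the pre-activations, once for the plain network and once for the batch-normalized one, with index 0 holding the raw input (hence the wider (-7, 10) range in the first column). Below is a minimal sketch of how such lists could be collected during a forward pass; the names forward_and_record, fcs, bns and test_x are illustrative assumptions, not code from this commit.

import torch
import torch.nn as nn

ACTIVATION = torch.tanh                    # assumed activation; this commit does not show it
N_HIDDEN = 8                               # must match the N_HIDDEN used for the subplot grid

def forward_and_record(x, fcs, bns=None):
    layer_inputs, pre_acts = [x], [x]      # column 0 of the histogram grid is the raw input
    for i, fc in enumerate(fcs):
        x = fc(x)
        pre_acts.append(x)                 # value before BN and before the non-linearity
        if bns is not None:
            x = bns[i](x)                  # batch-normalize the pre-activation
        x = ACTIVATION(x)
        layer_inputs.append(x)             # activation, i.e. the input to the next layer
    return layer_inputs, pre_acts

fcs = [nn.Linear(1 if i == 0 else 10, 10) for i in range(N_HIDDEN)]
bns = [nn.BatchNorm1d(10) for _ in range(N_HIDDEN)]
test_x = torch.linspace(-7, 10, 200).unsqueeze(1)

l_in, pre_ac = forward_and_record(test_x, fcs)               # plain forward pass
l_in_bn, pre_ac_bn = forward_and_record(test_x, fcs, bns)    # same layers with BN inserted
# plot_histogram(l_in, l_in_bn, pre_ac, pre_ac_bn)           # feeds the function defined above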
@@ -155,10 +142,7 @@ plt.ioff()
 plt.figure(2)
 plt.plot(losses[0], c='#FF9359', lw=3, label='Original')
 plt.plot(losses[1], c='#74BCFF', lw=3, label='Batch Normalization')
-plt.xlabel('step')
-plt.ylabel('test loss')
-plt.ylim((0, 2000))
-plt.legend(loc='best')
+plt.xlabel('step');plt.ylabel('test loss');plt.ylim((0, 2000));plt.legend(loc='best')

 # evaluation
 # set net to eval mode to freeze the parameters in batch normalization layers
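As the comment above notes, evaluation should run with the network in eval mode: in PyTorch, net.eval() switches nn.BatchNorm1d layers to normalize with the running mean/variance accumulated during training and stops updating those estimates (the learnable scale/shift simply stop changing because no optimizer step runs). A minimal sketch of that switch, using hypothetical names (bn_net, test_x) that do not appear in this commit:

import torch
import torch.nn as nn

# tiny stand-in network with one batch-normalized hidden layer (illustrative only)
bn_net = nn.Sequential(nn.Linear(1, 10), nn.BatchNorm1d(10), nn.ReLU(), nn.Linear(10, 1))
test_x = torch.linspace(-7, 10, 200).unsqueeze(1)

bn_net.train()                 # BN uses per-batch statistics and updates its running estimates
_ = bn_net(test_x)

bn_net.eval()                  # BN now normalizes with the stored running mean/var, which stay fixed
with torch.no_grad():          # evaluation also does not need gradient tracking
    pred = bn_net(test_x)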