update for new version of torch

This commit is contained in:
morvanzhou
2018-11-07 16:08:38 +08:00
parent ce55cc9446
commit 906cf71b6f

View File

@ -11,7 +11,6 @@ import torch
from torch import nn from torch import nn
from torch.nn import init from torch.nn import init
import torch.utils.data as Data import torch.utils.data as Data
import torch.nn.functional as F
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -24,7 +23,7 @@ BATCH_SIZE = 64
EPOCH = 12 EPOCH = 12
LR = 0.03 LR = 0.03
N_HIDDEN = 8 N_HIDDEN = 8
ACTIVATION = F.tanh ACTIVATION = torch.tanh
B_INIT = -0.2 # use a bad bias constant initializer B_INIT = -0.2 # use a bad bias constant initializer
# training data # training data
@ -48,6 +47,7 @@ train_loader = Data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shu
plt.scatter(train_x.numpy(), train_y.numpy(), c='#FF9359', s=50, alpha=0.2, label='train') plt.scatter(train_x.numpy(), train_y.numpy(), c='#FF9359', s=50, alpha=0.2, label='train')
plt.legend(loc='upper left') plt.legend(loc='upper left')
class Net(nn.Module): class Net(nn.Module):
def __init__(self, batch_normalization=False): def __init__(self, batch_normalization=False):
super(Net, self).__init__() super(Net, self).__init__()
@ -89,20 +89,20 @@ class Net(nn.Module):
nets = [Net(batch_normalization=False), Net(batch_normalization=True)] nets = [Net(batch_normalization=False), Net(batch_normalization=True)]
print(*nets) # print net architecture # print(*nets) # print net architecture
opts = [torch.optim.Adam(net.parameters(), lr=LR) for net in nets] opts = [torch.optim.Adam(net.parameters(), lr=LR) for net in nets]
loss_func = torch.nn.MSELoss() loss_func = torch.nn.MSELoss()
f, axs = plt.subplots(4, N_HIDDEN+1, figsize=(10, 5))
plt.ion() # something about plotting
plt.show()
def plot_histogram(l_in, l_in_bn, pre_ac, pre_ac_bn): def plot_histogram(l_in, l_in_bn, pre_ac, pre_ac_bn):
for i, (ax_pa, ax_pa_bn, ax, ax_bn) in enumerate(zip(axs[0, :], axs[1, :], axs[2, :], axs[3, :])): for i, (ax_pa, ax_pa_bn, ax, ax_bn) in enumerate(zip(axs[0, :], axs[1, :], axs[2, :], axs[3, :])):
[a.clear() for a in [ax_pa, ax_pa_bn, ax, ax_bn]] [a.clear() for a in [ax_pa, ax_pa_bn, ax, ax_bn]]
if i == 0: p_range = (-7, 10);the_range = (-7, 10) if i == 0:
else:p_range = (-4, 4);the_range = (-1, 1) p_range = (-7, 10);the_range = (-7, 10)
else:
p_range = (-4, 4);the_range = (-1, 1)
ax_pa.set_title('L' + str(i)) ax_pa.set_title('L' + str(i))
ax_pa.hist(pre_ac[i].data.numpy().ravel(), bins=10, range=p_range, color='#FF9359', alpha=0.5);ax_pa_bn.hist(pre_ac_bn[i].data.numpy().ravel(), bins=10, range=p_range, color='#74BCFF', alpha=0.5) ax_pa.hist(pre_ac[i].data.numpy().ravel(), bins=10, range=p_range, color='#FF9359', alpha=0.5);ax_pa_bn.hist(pre_ac_bn[i].data.numpy().ravel(), bins=10, range=p_range, color='#74BCFF', alpha=0.5)
ax.hist(l_in[i].data.numpy().ravel(), bins=10, range=the_range, color='#FF9359');ax_bn.hist(l_in_bn[i].data.numpy().ravel(), bins=10, range=the_range, color='#74BCFF') ax.hist(l_in[i].data.numpy().ravel(), bins=10, range=the_range, color='#FF9359');ax_bn.hist(l_in_bn[i].data.numpy().ravel(), bins=10, range=the_range, color='#74BCFF')
@ -111,15 +111,22 @@ def plot_histogram(l_in, l_in_bn, pre_ac, pre_ac_bn):
axs[0, 0].set_ylabel('PreAct');axs[1, 0].set_ylabel('BN PreAct');axs[2, 0].set_ylabel('Act');axs[3, 0].set_ylabel('BN Act') axs[0, 0].set_ylabel('PreAct');axs[1, 0].set_ylabel('BN PreAct');axs[2, 0].set_ylabel('Act');axs[3, 0].set_ylabel('BN Act')
plt.pause(0.01) plt.pause(0.01)
if __name__ == "__main__":
f, axs = plt.subplots(4, N_HIDDEN + 1, figsize=(10, 5))
plt.ion() # something about plotting
plt.show()
# training # training
losses = [[], []] # recode loss for two networks losses = [[], []] # recode loss for two networks
for epoch in range(EPOCH): for epoch in range(EPOCH):
print('Epoch: ', epoch) print('Epoch: ', epoch)
layer_inputs, pre_acts = [], [] layer_inputs, pre_acts = [], []
for net, l in zip(nets, losses): for net, l in zip(nets, losses):
net.eval() # set eval mode to fix moving_mean and moving_var net.eval() # set eval mode to fix moving_mean and moving_var
pred, layer_input, pre_act = net(test_x) pred, layer_input, pre_act = net(test_x)
l.append(loss_func(pred, test_y).data[0]) l.append(loss_func(pred, test_y).data.item())
layer_inputs.append(layer_input) layer_inputs.append(layer_input)
pre_acts.append(pre_act) pre_acts.append(pre_act)
net.train() # free moving_mean and moving_var net.train() # free moving_mean and moving_var
@ -133,7 +140,6 @@ for epoch in range(EPOCH):
loss.backward() loss.backward()
opt.step() # it will also learns the parameters in Batch Normalization opt.step() # it will also learns the parameters in Batch Normalization
plt.ioff() plt.ioff()
# plot training loss # plot training loss