move files
This commit is contained in:
88
tutorial-contents/304_save_reload.py
Normal file
88
tutorial-contents/304_save_reload.py
Normal file
@ -0,0 +1,88 @@
|
||||
"""
|
||||
Know more, visit 莫烦Python: https://morvanzhou.github.io/tutorials/
|
||||
My Youtube Channel: https://www.youtube.com/user/MorvanZhou
|
||||
|
||||
Dependencies:
|
||||
torch: 0.1.11
|
||||
matplotlib
|
||||
"""
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
torch.manual_seed(1) # reproducible
|
||||
|
||||
# fake data
|
||||
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
|
||||
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1)
|
||||
x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)
|
||||
|
||||
|
||||
def save():
|
||||
# save net1
|
||||
net1 = torch.nn.Sequential(
|
||||
torch.nn.Linear(1, 10),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(10, 1)
|
||||
)
|
||||
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
|
||||
loss_func = torch.nn.MSELoss()
|
||||
|
||||
for t in range(100):
|
||||
prediction = net1(x)
|
||||
loss = loss_func(prediction, y)
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# plot result
|
||||
plt.figure(1, figsize=(10, 3))
|
||||
plt.subplot(131)
|
||||
plt.title('Net1')
|
||||
plt.scatter(x.data.numpy(), y.data.numpy())
|
||||
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
|
||||
|
||||
# 2 ways to save the net
|
||||
torch.save(net1, 'net.pkl') # save entire net
|
||||
torch.save(net1.state_dict(), 'net_params.pkl') # save only the parameters
|
||||
|
||||
|
||||
def restore_net():
|
||||
# restore entire net1 to net2
|
||||
net2 = torch.load('net.pkl')
|
||||
prediction = net2(x)
|
||||
|
||||
# plot result
|
||||
plt.subplot(132)
|
||||
plt.title('Net2')
|
||||
plt.scatter(x.data.numpy(), y.data.numpy())
|
||||
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
|
||||
|
||||
|
||||
def restore_params():
|
||||
# restore only the parameters in net1 to net3
|
||||
net3 = torch.nn.Sequential(
|
||||
torch.nn.Linear(1, 10),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(10, 1)
|
||||
)
|
||||
|
||||
# copy net1's parameters into net3
|
||||
net3.load_state_dict(torch.load('net_params.pkl'))
|
||||
prediction = net3(x)
|
||||
|
||||
# plot result
|
||||
plt.subplot(133)
|
||||
plt.title('Net3')
|
||||
plt.scatter(x.data.numpy(), y.data.numpy())
|
||||
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
|
||||
plt.show()
|
||||
|
||||
# save net1
|
||||
save()
|
||||
|
||||
# restore entire net (may slow)
|
||||
restore_net()
|
||||
|
||||
# restore only the net parameters
|
||||
restore_params()
|
||||
Reference in New Issue
Block a user