update for new version of torch
@@ -55,23 +55,21 @@ opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)
 plt.ion()   # something about continuous plotting
 
 for step in range(10000):
-    artist_paintings = artist_works()           # real painting from artist
-    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)  # random ideas
+    artist_paintings = artist_works()           # real painting from artist
+    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS, requires_grad=True)  # random ideas\n
     G_paintings = G(G_ideas)                    # fake painting from G (random ideas)
-
-    prob_artist0 = D(artist_paintings)          # D try to increase this prob
-    prob_artist1 = D(G_paintings)               # D try to reduce this prob
-
-    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
-    G_loss = torch.mean(torch.log(1. - prob_artist1))
-
-    opt_D.zero_grad()
-    D_loss.backward(retain_graph=True)          # reusing computational graph
-    opt_D.step()
-
+    prob_artist1 = D(G_paintings)               # D try to reduce this prob
+    G_loss = torch.mean(torch.log(1. - prob_artist1))
     opt_G.zero_grad()
     G_loss.backward()
     opt_G.step()
 
+    prob_artist0 = D(artist_paintings)          # D try to increase this prob
+    prob_artist1 = D(G_paintings.detach())      # D try to reduce this prob
+    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
+    opt_D.zero_grad()
+    D_loss.backward(retain_graph=True)          # reusing computational graph
+    opt_D.step()
 
     if step % 50 == 0:  # plotting
         plt.cla()
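For context, a condensed, self-contained sketch of the pattern this diff moves to. The generator now updates first, and its backward pass consumes the graph that produced G_paintings; the discriminator then scores G_paintings.detach(), so its backward pass touches only D and never re-walks the freed generator graph. This also avoids the old ordering, where G_loss.backward() ran after opt_D.step() had already modified D's parameters in place, which newer torch versions reject as an in-place autograd error; presumably that is the motivation behind the commit title. The layer sizes, learning rates, and the random stand-in for artist_works() below are placeholders, not taken from the commit.

import torch
import torch.nn as nn

# Hypothetical stand-ins for the tutorial's setup; sizes and lr are
# placeholders, not values from this diff.
BATCH_SIZE, N_IDEAS, ART_COMPONENTS = 64, 5, 15
G = nn.Sequential(nn.Linear(N_IDEAS, 128), nn.ReLU(), nn.Linear(128, ART_COMPONENTS))
D = nn.Sequential(nn.Linear(ART_COMPONENTS, 128), nn.ReLU(), nn.Linear(128, 1), nn.Sigmoid())
opt_G = torch.optim.Adam(G.parameters(), lr=0.0001)
opt_D = torch.optim.Adam(D.parameters(), lr=0.0001)

for step in range(100):
    artist_paintings = torch.randn(BATCH_SIZE, ART_COMPONENTS)  # placeholder for artist_works()
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)
    G_paintings = G(G_ideas)

    # G step first: the loss backpropagates through D into G, but only G's
    # parameters move, because opt_G was built from G.parameters() alone.
    G_loss = torch.mean(torch.log(1. - D(G_paintings)))
    opt_G.zero_grad()
    G_loss.backward()                # frees the graph that produced G_paintings
    opt_G.step()

    # D step second: detach() cuts the fake batch out of the (already freed)
    # generator graph, so this backward pass only ever touches D. zero_grad()
    # also clears the gradients the G step deposited into D.
    prob_artist0 = D(artist_paintings)
    prob_artist1 = D(G_paintings.detach())
    D_loss = -torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
    opt_D.zero_grad()
    D_loss.backward()                # no retain_graph needed once the fakes are detached
    opt_D.step()

Note that the upstream file keeps retain_graph=True on D_loss.backward(); after the detach nothing reuses that graph within an iteration, so in this sketch it appears defensive rather than required.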