Update 406_GAN

This commit is contained in:
Evan Evan.sijia Sijia
2020-06-06 22:53:16 +01:00
parent 9626a06ecb
commit a7bfd6615d
2 changed files with 374 additions and 2 deletions

View File

@ -12,7 +12,9 @@
"Dependencies:\n",
"* torch: 0.1.11\n",
"* numpy\n",
"* matplotlib"
"* matplotlib\n",
"\n",
"Note: Below cells only work for pytorch version lower than 1.5, if you are using pytorch 1.5 or higher, then you will get in place error when you run For loop for Step(10000). Fixed version has been added below."
]
},
{
@ -252,6 +254,47 @@
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# below is for pytorch version => 1.5\n",
"# this fix is done by albanD - https://github.com/pytorch/pytorch/issues/39141\n",
"\n",
"for step in range(10000):\n",
" \n",
" artist_paintings = artist_works() # real painting from artist\n",
" G_ideas = torch.randn(BATCH_SIZE, N_IDEAS, requires_grad=True) # random ideas\n",
" G_paintings = G(G_ideas) # fake painting from G (random ideas)\n",
"\n",
" prob_artist1 = D(G_paintings) # D try to reduce this prob\n",
" \n",
" G_loss = torch.mean(torch.log(1. - prob_artist1)) \n",
" opt_G.zero_grad()\n",
" G_loss.backward()\n",
" opt_G.step()\n",
" \n",
" prob_artist0 = D(artist_paintings) # D try to increase this prob\n",
" prob_artist1 = D(G_paintings.detach()) # D try to reduce this prob\n",
" D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1)) \n",
"\n",
" opt_D.zero_grad()\n",
" D_loss.backward(retain_graph=True) # reusing computational graph\n",
" opt_D.step()\n",
"\n",
" if step % 1000 == 0: # plotting\n",
" plt.cla()\n",
" plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting',)\n",
" plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')\n",
" plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')\n",
" plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size': 15})\n",
" plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 15})\n",
" plt.ylim((0, 3));plt.legend(loc='upper right', fontsize=12);plt.draw();plt.pause(0.01)\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
@ -278,7 +321,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
"version": "3.7.6"
}
},
"nbformat": 4,