Update 406_GAN
This commit is contained in:
@ -12,7 +12,9 @@
|
||||
"Dependencies:\n",
|
||||
"* torch: 0.1.11\n",
|
||||
"* numpy\n",
|
||||
"* matplotlib"
|
||||
"* matplotlib\n",
|
||||
"\n",
|
||||
"Note: Below cells only work for pytorch version lower than 1.5, if you are using pytorch 1.5 or higher, then you will get in place error when you run For loop for Step(10000). Fixed version has been added below."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -252,6 +254,47 @@
|
||||
" plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# below is for pytorch version => 1.5\n",
|
||||
"# this fix is done by albanD - https://github.com/pytorch/pytorch/issues/39141\n",
|
||||
"\n",
|
||||
"for step in range(10000):\n",
|
||||
" \n",
|
||||
" artist_paintings = artist_works() # real painting from artist\n",
|
||||
" G_ideas = torch.randn(BATCH_SIZE, N_IDEAS, requires_grad=True) # random ideas\n",
|
||||
" G_paintings = G(G_ideas) # fake painting from G (random ideas)\n",
|
||||
"\n",
|
||||
" prob_artist1 = D(G_paintings) # D try to reduce this prob\n",
|
||||
" \n",
|
||||
" G_loss = torch.mean(torch.log(1. - prob_artist1)) \n",
|
||||
" opt_G.zero_grad()\n",
|
||||
" G_loss.backward()\n",
|
||||
" opt_G.step()\n",
|
||||
" \n",
|
||||
" prob_artist0 = D(artist_paintings) # D try to increase this prob\n",
|
||||
" prob_artist1 = D(G_paintings.detach()) # D try to reduce this prob\n",
|
||||
" D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1)) \n",
|
||||
"\n",
|
||||
" opt_D.zero_grad()\n",
|
||||
" D_loss.backward(retain_graph=True) # reusing computational graph\n",
|
||||
" opt_D.step()\n",
|
||||
"\n",
|
||||
" if step % 1000 == 0: # plotting\n",
|
||||
" plt.cla()\n",
|
||||
" plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting',)\n",
|
||||
" plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')\n",
|
||||
" plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')\n",
|
||||
" plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size': 15})\n",
|
||||
" plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 15})\n",
|
||||
" plt.ylim((0, 3));plt.legend(loc='upper right', fontsize=12);plt.draw();plt.pause(0.01)\n",
|
||||
" plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@ -278,7 +321,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.5.2"
|
||||
"version": "3.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
Reference in New Issue
Block a user