This commit is contained in:
Morvan Zhou
2017-05-11 13:56:13 +10:00
parent 2523e0c1b1
commit 53b13ff245
2 changed files with 5 additions and 6 deletions

View File

@ -8,6 +8,9 @@
<br> <br>
** If you'd like to use Tensorflow, no worries, I made a new Tensorflow Tutorial just like PyTorch. Here is the link:
[https://github.com/MorvanZhou/Tensorflow-Tutorial](https://github.com/MorvanZhou/Tensorflow-Tutorial)**
# pyTorch Tutorials # pyTorch Tutorials
In these tutorials for pyTorch, we will build our first Neural Network and try to build some advanced Neural Network architectures developed recent years. In these tutorials for pyTorch, we will build our first Neural Network and try to build some advanced Neural Network architectures developed recent years.

View File

@ -40,11 +40,6 @@ def artist_works_with_labels(): # painting from the famous artist (real targ
return Variable(paintings), Variable(labels) return Variable(paintings), Variable(labels)
def G_ideas(): # the random ideas for generator to draw something
z = torch.randn(BATCH_SIZE, N_IDEAS)
return Variable(z)
G = nn.Sequential( # Generator G = nn.Sequential( # Generator
nn.Linear(N_IDEAS+1, 128), # random ideas (could from normal distribution) + class label nn.Linear(N_IDEAS+1, 128), # random ideas (could from normal distribution) + class label
nn.ReLU(), nn.ReLU(),
@ -65,7 +60,8 @@ plt.ion() # something about continuous plotting
plt.show() plt.show()
for step in range(10000): for step in range(10000):
artist_paintings, labels = artist_works_with_labels() # real painting, label from artist artist_paintings, labels = artist_works_with_labels() # real painting, label from artist
G_inputs = torch.cat((G_ideas(), labels), 1) G_ideas = Variable(torch.randn(BATCH_SIZE, N_IDEAS)) # random ideas
G_inputs = torch.cat((G_ideas, labels), 1) # ideas with labels
G_paintings = G(G_inputs) # fake painting w.r.t label from G G_paintings = G(G_inputs) # fake painting w.r.t label from G
D_inputs0 = torch.cat((artist_paintings, labels), 1) # all have their labels D_inputs0 = torch.cat((artist_paintings, labels), 1) # all have their labels