update
This commit is contained in:
@ -17,10 +17,10 @@ import matplotlib.pyplot as plt
|
||||
torch.manual_seed(1) # reproducible
|
||||
|
||||
# Hyper Parameters
|
||||
EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch
|
||||
EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch
|
||||
BATCH_SIZE = 50
|
||||
LR = 0.001 # learning rate
|
||||
DOWNLOAD_MNIST = False
|
||||
LR = 0.001 # learning rate
|
||||
DOWNLOAD_MNIST = True # set to False if you have downloaded
|
||||
|
||||
|
||||
# Mnist digits dataset
|
||||
@ -33,8 +33,8 @@ train_data = torchvision.datasets.MNIST(
|
||||
)
|
||||
|
||||
# plot one example
|
||||
print(train_data.train_data.size()) # (60000, 28, 28)
|
||||
print(train_data.train_labels.size()) # (60000)
|
||||
print(train_data.train_data.size()) # (60000, 28, 28)
|
||||
print(train_data.train_labels.size()) # (60000)
|
||||
plt.imshow(train_data.train_data[0].numpy(), cmap='gray')
|
||||
plt.title('%i' % train_data.train_labels[0])
|
||||
plt.show()
|
||||
@ -51,28 +51,28 @@ test_y = test_data.test_labels[:2000]
|
||||
class CNN(nn.Module):
|
||||
def __init__(self):
|
||||
super(CNN, self).__init__()
|
||||
self.conv1 = nn.Sequential( # input shape (1, 28, 28)
|
||||
self.conv1 = nn.Sequential( # input shape (1, 28, 28)
|
||||
nn.Conv2d(
|
||||
in_channels=1, # input height
|
||||
out_channels=16, # n_filters
|
||||
kernel_size=5, # filter size
|
||||
stride=1, # filter movement/step
|
||||
padding=2, # if want same width and length of this image after con2d, padding=(kernel_size-1)/2 if stride=1
|
||||
), # output shape (16, 28, 28)
|
||||
nn.ReLU(), # activation
|
||||
nn.MaxPool2d(kernel_size=2), # choose max value in 2x2 area, output shape (16, 14, 14)
|
||||
in_channels=1, # input height
|
||||
out_channels=16, # n_filters
|
||||
kernel_size=5, # filter size
|
||||
stride=1, # filter movement/step
|
||||
padding=2, # if want same width and length of this image after con2d, padding=(kernel_size-1)/2 if stride=1
|
||||
), # output shape (16, 28, 28)
|
||||
nn.ReLU(), # activation
|
||||
nn.MaxPool2d(kernel_size=2), # choose max value in 2x2 area, output shape (16, 14, 14)
|
||||
)
|
||||
self.conv2 = nn.Sequential( # input shape (1, 28, 28)
|
||||
nn.Conv2d(16, 32, 5, 1, 2), # output shape (32, 14, 14)
|
||||
nn.ReLU(), # activation
|
||||
nn.MaxPool2d(2), # output shape (32, 7, 7)
|
||||
self.conv2 = nn.Sequential( # input shape (1, 28, 28)
|
||||
nn.Conv2d(16, 32, 5, 1, 2), # output shape (32, 14, 14)
|
||||
nn.ReLU(), # activation
|
||||
nn.MaxPool2d(2), # output shape (32, 7, 7)
|
||||
)
|
||||
self.out = nn.Linear(32 * 7 * 7, 10) # fully connected layer, output 10 classes
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
|
||||
x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
|
||||
output = self.out(x)
|
||||
return output
|
||||
|
||||
@ -81,7 +81,7 @@ cnn = CNN()
|
||||
print(cnn) # net architecture
|
||||
|
||||
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) # optimize all cnn parameters
|
||||
loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted
|
||||
loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted
|
||||
|
||||
# training and testing
|
||||
for epoch in range(EPOCH):
|
||||
|
||||
Reference in New Issue
Block a user