From 895ad9a26c7e4bdcc710812868694c4fe7f704aa Mon Sep 17 00:00:00 2001 From: Kawa Yg Date: Wed, 22 Nov 2017 08:13:22 +0800 Subject: [PATCH] mnist dataset download setting --- tutorial-contents/401_CNN.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tutorial-contents/401_CNN.py b/tutorial-contents/401_CNN.py index 5848a1d..0907061 100644 --- a/tutorial-contents/401_CNN.py +++ b/tutorial-contents/401_CNN.py @@ -7,6 +7,11 @@ torch: 0.1.11 torchvision matplotlib """ +# library +# standard library +import os + +# third-party library import torch import torch.nn as nn from torch.autograd import Variable @@ -20,16 +25,20 @@ torch.manual_seed(1) # reproducible EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch BATCH_SIZE = 50 LR = 0.001 # learning rate -DOWNLOAD_MNIST = True # set to False if you have downloaded +DOWNLOAD_MNIST = False # Mnist digits dataset +if not(os.path.exists('./mnist/')) or not os.listdir('./mnist/'): + # not mnist dir or mnist is empyt dir + DOWNLOAD_MNIST = True + train_data = torchvision.datasets.MNIST( root='./mnist/', train=True, # this is training data transform=torchvision.transforms.ToTensor(), # Converts a PIL.Image or numpy.ndarray to # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0] - download=DOWNLOAD_MNIST, # download it if you don't have it + download=DOWNLOAD_MNIST, ) # plot one example