diff --git a/tutorial-contents/305_batch_train.py b/tutorial-contents/305_batch_train.py index 06350b5..f8e9c98 100644 --- a/tutorial-contents/305_batch_train.py +++ b/tutorial-contents/305_batch_train.py @@ -24,8 +24,15 @@ loader = Data.DataLoader( num_workers=2, # subprocesses for loading data ) -for epoch in range(3): # train entire dataset 3 times - for step, (batch_x, batch_y) in enumerate(loader): # for each training step - # train your data... - print('Epoch: ', epoch, '| Step: ', step, '| batch x: ', - batch_x.numpy(), '| batch y: ', batch_y.numpy()) + +def show_batch(): + for epoch in range(3): # train entire dataset 3 times + for step, (batch_x, batch_y) in enumerate(loader): # for each training step + # train your data... + print('Epoch: ', epoch, '| Step: ', step, '| batch x: ', + batch_x.numpy(), '| batch y: ', batch_y.numpy()) + + +if __name__ == '__main__': + show_batch() +