move files
This commit is contained in:
31
tutorial-contents/305_batch_train.py
Normal file
31
tutorial-contents/305_batch_train.py
Normal file
@ -0,0 +1,31 @@
|
||||
"""
|
||||
Know more, visit 莫烦Python: https://morvanzhou.github.io/tutorials/
|
||||
My Youtube Channel: https://www.youtube.com/user/MorvanZhou
|
||||
|
||||
Dependencies:
|
||||
torch: 0.1.11
|
||||
"""
|
||||
import torch
|
||||
import torch.utils.data as Data
|
||||
|
||||
torch.manual_seed(1) # reproducible
|
||||
|
||||
BATCH_SIZE = 5
|
||||
# BATCH_SIZE = 8
|
||||
|
||||
x = torch.linspace(1, 10, 10) # this is x data (torch tensor)
|
||||
y = torch.linspace(10, 1, 10) # this is y data (torch tensor)
|
||||
|
||||
torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)
|
||||
loader = Data.DataLoader(
|
||||
dataset=torch_dataset, # torch TensorDataset format
|
||||
batch_size=BATCH_SIZE, # mini batch size
|
||||
shuffle=True, # random shuffle for training
|
||||
num_workers=2, # subprocesses for loading data
|
||||
)
|
||||
|
||||
for epoch in range(3): # train entire dataset 3 times
|
||||
for step, (batch_x, batch_y) in enumerate(loader): # for each training step
|
||||
# train your data...
|
||||
print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
|
||||
batch_x.numpy(), '| batch y: ', batch_y.numpy())
|
||||
Reference in New Issue
Block a user