diff --git a/tutorial-contents/305_batch_train.py b/tutorial-contents/305_batch_train.py index 7e58e87..06350b5 100644 --- a/tutorial-contents/305_batch_train.py +++ b/tutorial-contents/305_batch_train.py @@ -16,7 +16,7 @@ BATCH_SIZE = 5 x = torch.linspace(1, 10, 10) # this is x data (torch tensor) y = torch.linspace(10, 1, 10) # this is y data (torch tensor) -torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y) +torch_dataset = Data.TensorDataset(x, y) loader = Data.DataLoader( dataset=torch_dataset, # torch TensorDataset format batch_size=BATCH_SIZE, # mini batch size