Skip to content

Commit

Permalink
fixing nits and feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
codingwithsurya committed Nov 20, 2023
1 parent 5308884 commit 69d1477
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 24 deletions.
8 changes: 7 additions & 1 deletion training/training/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@


class TrainTestDatasetCreator(ABC):
"Creator that creates train and test PyTorch datasets from a given dataset"
"""
Creator that creates train and test PyTorch datasets from a given dataset.
This class serves as an abstract base class for creating training and testing
datasets compatible with PyTorch's dataset structure. Implementations should
define specific methods for dataset processing and loading.
"""

@abstractmethod
def createTrainDataset(self) -> Dataset:
Expand Down
4 changes: 0 additions & 4 deletions training/training/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,6 @@ def _train_step(self, inputs: torch.Tensor, labels: torch.Tensor):
self.optimizer.zero_grad() # zero out gradient for each batch
self.model.forward(inputs) # make prediction on input
self._outputs: torch.Tensor = self.model(inputs) # make prediction on input
print('MODEL FORWARD PASS DONE!!!!')
print(f'output dim: {self._outputs.shape}')
print(f'label dim: {labels.shape}')
print(f'loss function used: {self.criterionHandler}')
loss = self.criterionHandler.compute_loss(self._outputs, labels)
loss.backward() # backpropagation
self.optimizer.step() # adjust optimizer weights
Expand Down
16 changes: 1 addition & 15 deletions training/training/routes/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,8 @@ def imageTrain(request: HttpRequest, imageParams: ImageParams):
print(vars(dataCreator))
train_loader = dataCreator.createTrainDataset()
test_loader = dataCreator.createTestDataset()
# train_loader = DataLoader(
# dataCreator.createTrainDataset(),
# batch_size=imageParams.batch_size,
# shuffle=False,
# drop_last=True,
# )

# test_loader = DataLoader(
# dataCreator.createTestDataset(),
# batch_size=imageParams.batch_size,
# shuffle=False,
# drop_last=True,
# )

model = DLModel.fromLayerParamsList(imageParams.user_arch)
print(f'model is: {model}')
# print(f'model is: {model}')
optimizer = getOptimizer(model, imageParams.optimizer_name, 0.05)
criterionHandler = getCriterionHandler(imageParams.criterion)
if imageParams.problem_type == "CLASSIFICATION":
Expand Down
4 changes: 0 additions & 4 deletions training/training/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@
from training.routes.datasets.default.columns import router as default_dataset_router
from training.routes.tabular.tabular import router as tabular_router
from training.routes.image.image import router as image_router
# from training.routes.datasets.default import get_default_datasets_router
# from training.routes.tabular import get_tabular_router
# from training.routes.image import image_router

api = NinjaAPI()


Expand Down

0 comments on commit 69d1477

Please sign in to comment.