Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate Image Training Endpoint to Django #1011

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion training/training/core/criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class CELossHandler(CriterionHandler):
def _compute_loss(self, output, labels):
output = torch.reshape(
output,
(output.shape[0], output.shape[2]),
(output.shape[0], output.shape[-1]),
)
labels = labels.squeeze_()
return nn.CrossEntropyLoss(reduction="mean")(output, labels.long())
Expand Down
101 changes: 100 additions & 1 deletion training/training/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,15 @@
import torch
from torch.utils.data import Dataset
from torch.autograd import Variable
from torchvision import datasets, transforms
karkir0003 marked this conversation as resolved.
Show resolved Hide resolved
from torch.utils.data import DataLoader
codingwithsurya marked this conversation as resolved.
Show resolved Hide resolved
from enum import Enum
import os
import shutil


class TrainTestDatasetCreator(ABC):
"Creator that creates train and test PyTorch datasets"
"Creator that creates train and test PyTorch datasets from a given dataset"

@abstractmethod
def createTrainDataset(self) -> Dataset:
Expand Down Expand Up @@ -98,3 +103,97 @@ def getCategoryList(self) -> list[str]:
if self._category_list is None:
raise Exception("Category list not available")
return self._category_list


class DefaultImageDatasets(Enum):
MNIST = "MNIST"
FASHION_MNIST = "FashionMNIST"
KMNIST = "KMNIST"
codingwithsurya marked this conversation as resolved.
Show resolved Hide resolved
CIFAR10 = "CIFAR10"


class ImageDefaultDatasetCreator(TrainTestDatasetCreator):
def __init__(
self,
dataset_name: str,
train_transform: None,
test_transform: None,
batch_size: int = 32,
shuffle: bool = True,
):
if dataset_name not in DefaultImageDatasets.__members__:
raise Exception(
f"The {dataset_name} file does not currently exist in our inventory. Please submit a request to the contributors of the repository"
)

self.dataset_dir = "./training/image_data_uploads"
self.train_transform = train_transform or transforms.Compose(
[transforms.ToTensor()]
)

self.test_transform = test_transform or transforms.Compose(
[transforms.ToTensor()]
)
self.batch_size = batch_size
self.shuffle = shuffle

# Ensure the directory exists
os.makedirs(self.dataset_dir, exist_ok=True)
print(f'train transform: {train_transform}')
print(f'test transform: {test_transform}')
# Load the datasets

self.train_set = datasets.__dict__[dataset_name](
karkir0003 marked this conversation as resolved.
Show resolved Hide resolved
root=self.dataset_dir,
train=True,
download=True,
transform=self.train_transform,
)
self.test_set = datasets.__dict__[dataset_name](
root=self.dataset_dir,
train=False,
download=True,
transform=self.test_transform,
)

@classmethod
def fromDefault(
cls,
dataset_name: str,
train_transform=None,
codingwithsurya marked this conversation as resolved.
Show resolved Hide resolved
test_transform=None,
codingwithsurya marked this conversation as resolved.
Show resolved Hide resolved
batch_size: int = 32,
shuffle: bool = True,
) -> "ImageDefaultDatasetCreator":
return cls(dataset_name, train_transform, test_transform, batch_size, shuffle)

def delete_datasets_from_directory(self):
if os.path.exists(self.dataset_dir):
try:
shutil.rmtree(self.dataset_dir)
print(f"Successfully deleted {self.dataset_dir}")
except Exception as e:
print(f"Failed to delete {self.dataset_dir} with error: {e}")

def createTrainDataset(self) -> DataLoader:
train_loader = DataLoader(
self.train_set,
batch_size=self.batch_size,
shuffle=self.shuffle,
drop_last=True,
)
self.delete_datasets_from_directory() # Delete datasets after loading
return train_loader

def createTestDataset(self) -> DataLoader:
test_loader = DataLoader(
self.test_set,
batch_size=self.batch_size,
shuffle=self.shuffle,
drop_last=True,
)
self.delete_datasets_from_directory() # Delete datasets after loading
return test_loader

def getCategoryList(self) -> list[str]:
return self.train_set.classes if hasattr(self.train_set, "classes") else []
5 changes: 5 additions & 0 deletions training/training/core/dl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TYPE_CHECKING

from training.routes.tabular.schemas import LayerParams
from training.routes.image.schemas import LayerParams


class DLModel(nn.Module):
Expand All @@ -15,6 +16,10 @@ class DLModel(nn.Module):
"SOFTMAX": nn.Softmax,
"SIGMOID": nn.Sigmoid,
"LOGSOFTMAX": nn.LogSoftmax,
"CONV2D": nn.Conv2d,
"DROPOUT": nn.Dropout,
"MAXPOOL2D": nn.MaxPool2d,
"FLATTEN": nn.Flatten
}

def __init__(self, layer_list: list[nn.Module]):
Expand Down
5 changes: 0 additions & 5 deletions training/training/routes/datasets/default/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +0,0 @@
from training.routes.datasets.default.columns import router


def get_default_datasets_router():
return router
Empty file.
44 changes: 44 additions & 0 deletions training/training/routes/image/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Literal, Optional
from django.http import HttpRequest
karkir0003 marked this conversation as resolved.
Show resolved Hide resolved
from ninja import Router, Schema
karkir0003 marked this conversation as resolved.
Show resolved Hide resolved
from training.core.criterion import getCriterionHandler
from training.core.dl_model import DLModel
from training.core.dataset import ImageDefaultDatasetCreator
from torch.utils.data import DataLoader
karkir0003 marked this conversation as resolved.
Show resolved Hide resolved
from training.core.optimizer import getOptimizer
from training.core.trainer import ClassificationTrainer
from training.routes.image.schemas import ImageParams
from training.core.authenticator import FirebaseAuth

router = Router()


@router.post("", auth=FirebaseAuth())
def imageTrain(request: HttpRequest, imageParams: ImageParams):
if imageParams.default:
dataCreator = ImageDefaultDatasetCreator.fromDefault(
imageParams.default
)
print(vars(dataCreator))
train_loader = dataCreator.createTrainDataset()
test_loader = dataCreator.createTestDataset()
model = DLModel.fromLayerParamsList(imageParams.user_arch)
codingwithsurya marked this conversation as resolved.
Show resolved Hide resolved
codingwithsurya marked this conversation as resolved.
Show resolved Hide resolved
print(f'model is: {model}')
optimizer = getOptimizer(model, imageParams.optimizer_name, 0.05)
criterionHandler = getCriterionHandler(imageParams.criterion)
if imageParams.problem_type == "CLASSIFICATION":
trainer = ClassificationTrainer(
train_loader,
test_loader,
model,
optimizer,
criterionHandler,
imageParams.epochs,
dataCreator.getCategoryList(),
)
for epoch_result in trainer:
print(epoch_result)
print(trainer.labels_last_epoch, trainer.y_pred_last_epoch)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are all these print statements necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

saw this in the original tabular.py as well. if they're not necessary, i can remove them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@farisdurrani @karkir0003 any clarification on what to do here. Are all the print statements necessary in grand scheme of things?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did we resolve this @codingwithsurya

print(trainer.generate_confusion_matrix())
print(trainer.generate_AUC_ROC_CURVE())
return trainer.generate_AUC_ROC_CURVE()
22 changes: 22 additions & 0 deletions training/training/routes/image/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Any, Literal, Optional
from ninja import Schema
codingwithsurya marked this conversation as resolved.
Show resolved Hide resolved


class LayerParams(Schema):
value: str
parameters: list[Any]


class ImageParams(Schema):
# target: str
# features: list[str]
name: str
problem_type: Literal["CLASSIFICATION"]
default: Optional[str]
criterion: str
optimizer_name: str
shuffle: bool
epochs: int
test_size: float
batch_size: int
user_arch: list[LayerParams]
11 changes: 7 additions & 4 deletions training/training/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
from django.http import HttpRequest
from django.urls import path
from ninja import NinjaAPI, Schema
from training.routes.datasets.default import get_default_datasets_router
from training.routes.tabular import get_tabular_router

from training.routes.datasets.default.columns import router as default_dataset_router
from training.routes.tabular.tabular import router as tabular_router
from training.routes.image.image import router as image_router

api = NinjaAPI()

Expand All @@ -30,8 +32,9 @@ def test(request: HttpRequest):
return 200, {"result": "200 Backend surface test successful"}


api.add_router("/datasets/default/", get_default_datasets_router())
api.add_router("/tabular", get_tabular_router())
api.add_router("/datasets/default/", default_dataset_router)
api.add_router("/tabular", tabular_router)
api.add_router("/image", image_router)

urlpatterns = [
path("admin/", admin.site.urls),
Expand Down