Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate Image Training Endpoint to Django #1011

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions training/training/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
import torch
from torch.utils.data import Dataset
from torch.autograd import Variable
from torchvision import datasets, transforms
karkir0003 marked this conversation as resolved.
Show resolved Hide resolved
from torch.util.data import DataLoader, random_split
codingwithsurya marked this conversation as resolved.
Show resolved Hide resolved
from enum import Enum


class TrainTestDatasetCreator(ABC):
Expand Down Expand Up @@ -98,3 +101,39 @@ def getCategoryList(self) -> list[str]:
if self._category_list is None:
raise Exception("Category list not available")
return self._category_list

class DefaultImageDatasets(Enum):
MNIST = "MNIST"
FASHION_MNIST = "FashionMNIST"
KMNIST = "KMNIST"
codingwithsurya marked this conversation as resolved.
Show resolved Hide resolved
class ImageDefaultDatasetCreator(TrainTestDatasetCreator):

def __init__(
self,
dataset_name:str,
train_tarnsform:None,
test_transform:None,
batch_size: int = 32,
shuffle: bool = True,
):
if dataset_name not in DefaultImageDatasets.__members__:
raise Exception(f"The {dataset_name} file does not currently exist in our inventory. Please submit a request to the contributors of the repository")
self.train_transform = train_tarnsform or transforms.Compose([transforms.toTensor()])
self.test_transform = test_transform or transforms.Compose([transforms.toTensor()])
self.batch_size = batch_size
self.shuffle = shuffle

self.train_set = datasets.__dict__[dataset_name](root='./backend/image_data_uploads', train=True, download=True, transform=self.train_transform)
self.test_set = datasets.__dict__[dataset_name](root='./backend/image_data_uploads', train=False, download=True, transform=self.test_transform)


@classmethod
def fromDefault(cls, dataset_name: str, train_transform=None, test_transform=None, batch_size: int = 32, shuffle: bool = True):
return cls(dataset_name, train_transform, test_transform, batch_size, shuffle)
def createTrainDataset(self) -> DataLoader:
return DataLoader(self.train_set, batch_size = self.batch_size, shuffle = self.shuffle, drop_last=True)

def createTestDataset(self):
codingwithsurya marked this conversation as resolved.
Show resolved Hide resolved
return DataLoader(self.test_set, batch_size = self.batch_size, shuffle = self.shuffle, drop_last=True)
def getCategoryList(self) -> list[str]:
return self.train_set.classes if hasattr(self.train_set, "classes") else []
5 changes: 5 additions & 0 deletions training/training/routes/image/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from training.routes.image.image import router


def get_image_router():
codingwithsurya marked this conversation as resolved.
Show resolved Hide resolved
return router
53 changes: 53 additions & 0 deletions training/training/routes/image/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Literal, Optional
from django.http import HttpRequest
karkir0003 marked this conversation as resolved.
Show resolved Hide resolved
from ninja import Router, Schema
karkir0003 marked this conversation as resolved.
Show resolved Hide resolved
from training.core.criterion import getCriterionHandler
from training.core.dl_model import DLModel
from training.core.datasets import ImageDefaultDatasetCreator
karkir0003 marked this conversation as resolved.
Show resolved Hide resolved
from torch.utils.data import DataLoader
karkir0003 marked this conversation as resolved.
Show resolved Hide resolved
from training.core.optimizer import getOptimizer
from training.core.trainer import ClassificationTrainer
from training.routes.image.schemas import ImageParams
from training.core.authenticator import FirebaseAuth

router = Router()

@router.post("", auth=FirebaseAuth())
def imageTrain(request: HttpRequest, imageParams: ImageParams):
if imageParams.default:
dataCreator = ImageDefaultDatasetCreator.fromDefault(
imageParams.default, imageParams.test_size, imageParams.shuffle
)
train_loader = DataLoader(
dataCreator.createTrainDataset(),
batch_size=imageParams.batch_size,
shuffle=False,
codingwithsurya marked this conversation as resolved.
Show resolved Hide resolved
drop_last=True,
)

test_loader = DataLoader(
dataCreator.createTestDataset(),
batch_size=imageParams.batch_size,
shuffle=False,
drop_last=True,
)

model = DLModel.fromLayerParamsList(imageParams.user_arch)
codingwithsurya marked this conversation as resolved.
Show resolved Hide resolved
codingwithsurya marked this conversation as resolved.
Show resolved Hide resolved
optimizer = getOptimizer(model, imageParams.optimizer_name, 0.05)
criterionHandler = getCriterionHandler(imageParams.criterion)
if imageParams.problem_type == "CLASSIFICATION":
trainer = ClassificationTrainer(
train_loader,
test_loader,
model,
optimizer,
criterionHandler,
imageParams.epochs,
dataCreator.getCategoryList(),
)
for epoch_result in trainer:
print(epoch_result)
print(trainer.labels_last_epoch, trainer.y_pred_last_epoch)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are all these print statements necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

saw this in the original tabular.py as well. if they're not necessary, i can remove them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@farisdurrani @karkir0003 any clarification on what to do here. Are all the print statements necessary in grand scheme of things?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did we resolve this @codingwithsurya

print(trainer.generate_confusion_matrix())
print(trainer.generate_AUC_ROC_CURVE())
return trainer.generate_AUC_ROC_CURVE()
22 changes: 22 additions & 0 deletions training/training/routes/image/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Any, Literal, Optional
from ninja import Schema
codingwithsurya marked this conversation as resolved.
Show resolved Hide resolved


class LayerParams(Schema):
value: str
parameters: list[Any]


class ImageParams(Schema):
target: str
features: list[str]
name: str
problem_type: Literal["CLASSIFICATION"]
default: Optional[str]
criterion: str
optimizer_name: str
shuffle: bool
epochs: int
test_size: float
batch_size: int
user_arch: list[LayerParams]
2 changes: 2 additions & 0 deletions training/training/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ninja import NinjaAPI, Schema
from training.routes.datasets.default import get_default_datasets_router
from training.routes.tabular import get_tabular_router
from training.routes.image import get_image_router

api = NinjaAPI()

Expand All @@ -32,6 +33,7 @@ def test(request: HttpRequest):

api.add_router("/datasets/default/", get_default_datasets_router())
api.add_router("/tabular", get_tabular_router())
api.add_router("/image", get_image_router())

urlpatterns = [
path("admin/", admin.site.urls),
Expand Down