[UPDATED FEATURE-993] Migrate Image Training Endpoint to Django #1053

Merged: 24 commits, Nov 20, 2023
Commits:
- c3d8938 adding image training endpoint (codingwithsurya, Oct 6, 2023)
- 91f56bf :art: Format Python code with psf/black (codingwithsurya, Oct 7, 2023)
- d7ce98d importing image for dl_model.py (codingwithsurya, Oct 7, 2023)
- 2edb68d Merge remote-tracking branch 'origin/feature-993' into feature-993 (codingwithsurya, Oct 7, 2023)
- cd8f850 fixing return typing and typos (codingwithsurya, Oct 7, 2023)
- a5d0c0b optimizing imports and adding cifar dataset (codingwithsurya, Oct 7, 2023)
- 9e03ea0 deleting local downloaded data (codingwithsurya, Oct 7, 2023)
- b3e82b9 adding return typing (codingwithsurya, Oct 13, 2023)
- f9f1edb :art: Format Python code with psf/black (codingwithsurya, Oct 13, 2023)
- ce4351f fixing comment (codingwithsurya, Oct 13, 2023)
- ac4cc8e pulling remote branch (codingwithsurya, Oct 13, 2023)
- 64645d6 adding dummy keyword arg (codingwithsurya, Oct 13, 2023)
- 5185cda :art: Format Python code with psf/black (codingwithsurya, Oct 13, 2023)
- 30f5dd2 added logging (codingwithsurya, Oct 13, 2023)
- ce571e4 added logging (codingwithsurya, Oct 13, 2023)
- 5cbbaf0 testing endpoint (codingwithsurya, Oct 13, 2023)
- 6a35554 adding maxpool2d and flatten (codingwithsurya, Oct 27, 2023)
- 2922549 updating criterion (codingwithsurya, Oct 27, 2023)
- dbc97d1 removing print statements (codingwithsurya, Nov 19, 2023)
- 5308884 fixing cifar typo (codingwithsurya, Nov 19, 2023)
- 08ff344 Merge branch 'nextjs' of https://github.com/DSGT-DLP/Deep-Learning-Pl… (karkir0003, Nov 20, 2023)
- 69d1477 fixing nits and feedback (codingwithsurya, Nov 20, 2023)
- 4f33824 Merge branch 'feature-993-updated' of https://github.com/DSGT-DLP/Dee… (karkir0003, Nov 20, 2023)
- bb26b81 :art: Format Python code with psf/black (karkir0003, Nov 20, 2023)
2 changes: 1 addition & 1 deletion dlp-cli
2 changes: 1 addition & 1 deletion training/training/core/criterion.py
@@ -46,7 +46,7 @@ class CELossHandler(CriterionHandler):
    def _compute_loss(self, output, labels):
        output = torch.reshape(
            output,
-            (output.shape[0], output.shape[2]),
+            (output.shape[0], output.shape[-1]),
        )
        labels = labels.squeeze_()
        return nn.CrossEntropyLoss(reduction="mean")(output, labels.long())
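For context on that one-line change: `nn.CrossEntropyLoss` expects a 2-D `(batch, num_classes)` input here, and indexing the class dimension as `shape[-1]` works whether the raw model output still carries an extra middle dimension or is already 2-D (as the new image models' outputs are), whereas `shape[2]` raises an `IndexError` on a 2-D tensor. A minimal illustration with made-up shapes:

```python
import torch
import torch.nn as nn

labels = torch.randint(0, 10, (4, 1))  # 4 samples, 10 classes (illustrative only)

for output in (torch.randn(4, 1, 10), torch.randn(4, 10)):  # 3-D vs. already-2-D model output
    # shape[-1] picks the class dimension in both cases; shape[2] would fail for the 2-D tensor
    reshaped = torch.reshape(output, (output.shape[0], output.shape[-1]))
    loss = nn.CrossEntropyLoss(reduction="mean")(reshaped, labels.squeeze().long())
    print(tuple(output.shape), "->", tuple(reshaped.shape), float(loss))
```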
107 changes: 106 additions & 1 deletion training/training/core/dataset.py
@@ -12,10 +12,21 @@
import torch
from torch.utils.data import Dataset
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from enum import Enum
import os
import shutil


class TrainTestDatasetCreator(ABC):
-    "Creator that creates train and test PyTorch datasets"
+    """
+    Creator that creates train and test PyTorch datasets from a given dataset.
+
+    This class serves as an abstract base class for creating training and testing
+    datasets compatible with PyTorch's dataset structure. Implementations should
+    define specific methods for dataset processing and loading.
+    """

    @abstractmethod
    def createTrainDataset(self) -> Dataset:
@@ -98,3 +109,97 @@ def getCategoryList(self) -> list[str]:
        if self._category_list is None:
            raise Exception("Category list not available")
        return self._category_list


class DefaultImageDatasets(Enum):
    MNIST = "MNIST"
    FASHION_MNIST = "FashionMNIST"
    KMNIST = "KMNIST"
    CIFAR10 = "CIFAR10"


class ImageDefaultDatasetCreator(TrainTestDatasetCreator):
    def __init__(
        self,
        dataset_name: str,
        train_transform: None,
        test_transform: None,
        batch_size: int = 32,
        shuffle: bool = True,
    ):
        if dataset_name not in DefaultImageDatasets.__members__:
            raise Exception(
                f"The {dataset_name} file does not currently exist in our inventory. Please submit a request to the contributors of the repository"
            )

        self.dataset_dir = "./training/image_data_uploads"
        self.train_transform = train_transform or transforms.Compose(
            [transforms.ToTensor()]
        )

        self.test_transform = test_transform or transforms.Compose(
            [transforms.ToTensor()]
        )
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Ensure the directory exists
        os.makedirs(self.dataset_dir, exist_ok=True)
        print(f"train transform: {train_transform}")
        print(f"test transform: {test_transform}")
        # Load the datasets

        self.train_set = datasets.__dict__[dataset_name](
            root=self.dataset_dir,
            train=True,
            download=True,
            transform=self.train_transform,
        )
        self.test_set = datasets.__dict__[dataset_name](
            root=self.dataset_dir,
            train=False,
            download=True,
            transform=self.test_transform,
        )

    @classmethod
    def fromDefault(
        cls,
        dataset_name: str,
        train_transform=None,
        test_transform=None,
        batch_size: int = 32,
        shuffle: bool = True,
    ) -> "ImageDefaultDatasetCreator":
        return cls(dataset_name, train_transform, test_transform, batch_size, shuffle)

    def delete_datasets_from_directory(self):
        if os.path.exists(self.dataset_dir):
            try:
                shutil.rmtree(self.dataset_dir)
                print(f"Successfully deleted {self.dataset_dir}")
            except Exception as e:
                print(f"Failed to delete {self.dataset_dir} with error: {e}")

    def createTrainDataset(self) -> DataLoader:
        train_loader = DataLoader(
            self.train_set,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            drop_last=True,
        )
        self.delete_datasets_from_directory()  # Delete datasets after loading
        return train_loader

    def createTestDataset(self) -> DataLoader:
        test_loader = DataLoader(
            self.test_set,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            drop_last=True,
        )
        self.delete_datasets_from_directory()  # Delete datasets after loading
        return test_loader

    def getCategoryList(self) -> list[str]:
        return self.train_set.classes if hasattr(self.train_set, "classes") else []
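A minimal usage sketch of the new creator, roughly what the image route below does internally (assumes network access for the first torchvision download; the printed batch shape is what MNIST yields with the default `ToTensor` transform and `batch_size=32`):

```python
from training.core.dataset import ImageDefaultDatasetCreator

creator = ImageDefaultDatasetCreator.fromDefault("MNIST")  # any name in DefaultImageDatasets
train_loader = creator.createTrainDataset()

images, labels = next(iter(train_loader))
print(images.shape)               # torch.Size([32, 1, 28, 28]) for MNIST
print(creator.getCategoryList())  # class names, e.g. ['0 - zero', '1 - one', ...]
```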
5 changes: 5 additions & 0 deletions training/training/core/dl_model.py
@@ -3,6 +3,7 @@
import torch.nn as nn

from training.routes.tabular.schemas import LayerParams
from training.routes.image.schemas import LayerParams


class DLModel(nn.Module):
@@ -13,6 +14,10 @@ class DLModel(nn.Module):
"SOFTMAX": nn.Softmax,
"SIGMOID": nn.Sigmoid,
"LOGSOFTMAX": nn.LogSoftmax,
"CONV2D": nn.Conv2d,
"DROPOUT": nn.Dropout,
"MAXPOOL2D": nn.MaxPool2d,
"FLATTEN": nn.Flatten,
}

def __init__(self, layer_list: list[nn.Module]):
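The four new entries let a `user_arch` describe a convolutional network rather than a purely linear one. As a rough illustration (a hand-written `nn.Sequential`, not the PR's `fromLayerParamsList` output; the layer sizes and the `RELU`/`LINEAR` entries are assumed, since only `CONV2D`, `DROPOUT`, `MAXPOOL2D`, and `FLATTEN` appear in this hunk):

```python
import torch
import torch.nn as nn

# Roughly what a CONV2D -> RELU -> MAXPOOL2D -> FLATTEN -> LINEAR user_arch would build
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),  # (N, 1, 28, 28) -> (N, 8, 28, 28)
    nn.ReLU(),
    nn.MaxPool2d(2),                            # -> (N, 8, 14, 14)
    nn.Flatten(),                               # -> (N, 8 * 14 * 14)
    nn.Linear(8 * 14 * 14, 10),                 # -> (N, 10) class scores, e.g. for MNIST
)

print(model(torch.randn(32, 1, 28, 28)).shape)  # torch.Size([32, 10])
```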
Empty file.
42 changes: 42 additions & 0 deletions training/training/routes/image/image.py
@@ -0,0 +1,42 @@
from typing import Literal, Optional
from django.http import HttpRequest
from ninja import Router, Schema
from training.core.criterion import getCriterionHandler
from training.core.dl_model import DLModel
from training.core.dataset import ImageDefaultDatasetCreator
from torch.utils.data import DataLoader
from training.core.optimizer import getOptimizer
from training.core.trainer import ClassificationTrainer
from training.routes.image.schemas import ImageParams
from training.core.authenticator import FirebaseAuth

router = Router()


@router.post("", auth=FirebaseAuth())
def imageTrain(request: HttpRequest, imageParams: ImageParams):
    if imageParams.default:
        dataCreator = ImageDefaultDatasetCreator.fromDefault(imageParams.default)
        print(vars(dataCreator))
        train_loader = dataCreator.createTrainDataset()
        test_loader = dataCreator.createTestDataset()
        model = DLModel.fromLayerParamsList(imageParams.user_arch)
        # print(f'model is: {model}')
        optimizer = getOptimizer(model, imageParams.optimizer_name, 0.05)
        criterionHandler = getCriterionHandler(imageParams.criterion)
        if imageParams.problem_type == "CLASSIFICATION":
            trainer = ClassificationTrainer(
                train_loader,
                test_loader,
                model,
                optimizer,
                criterionHandler,
                imageParams.epochs,
                dataCreator.getCategoryList(),
            )
            for epoch_result in trainer:
                print(epoch_result)
            print(trainer.labels_last_epoch, trainer.y_pred_last_epoch)
            print(trainer.generate_confusion_matrix())
            print(trainer.generate_AUC_ROC_CURVE())
            return trainer.generate_AUC_ROC_CURVE()
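For completeness, a sketch of what a client request to this route might look like. The `http://localhost:8000/api/image` URL, the bearer-token header, the `"CELOSS"`/`"SGD"` strings, the `"RELU"`/`"LINEAR"` layer names, and the assumption that each `parameters` list is passed positionally to the layer constructor are illustrative guesses, not taken from this PR; only the field names come from the `ImageParams` schema in the next file.

```python
import requests

firebase_id_token = "<Firebase ID token>"  # FirebaseAuth() presumably expects a valid token here

payload = {
    "name": "mnist-demo",
    "problem_type": "CLASSIFICATION",
    "default": "MNIST",
    "criterion": "CELOSS",      # assumed criterion name
    "optimizer_name": "SGD",    # assumed optimizer name
    "shuffle": True,
    "epochs": 2,
    "test_size": 0.2,
    "batch_size": 32,
    "user_arch": [
        {"value": "CONV2D", "parameters": [1, 8, 3]},
        {"value": "RELU", "parameters": []},
        {"value": "MAXPOOL2D", "parameters": [2]},
        {"value": "FLATTEN", "parameters": []},
        {"value": "LINEAR", "parameters": [8 * 13 * 13, 10]},  # 28 -> 26 after the conv, 13 after the pool
    ],
}

response = requests.post(
    "http://localhost:8000/api/image",  # assumed local dev URL and mount point
    json=payload,
    headers={"Authorization": f"Bearer {firebase_id_token}"},
)
print(response.status_code, response.json())
```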
22 changes: 22 additions & 0 deletions training/training/routes/image/schemas.py
@@ -0,0 +1,22 @@
from typing import Any, Literal, Optional
from ninja import Schema


class LayerParams(Schema):
    value: str
    parameters: list[Any]


class ImageParams(Schema):
    # target: str
    # features: list[str]
    name: str
    problem_type: Literal["CLASSIFICATION"]
    default: Optional[str]
    criterion: str
    optimizer_name: str
    shuffle: bool
    epochs: int
    test_size: float
    batch_size: int
    user_arch: list[LayerParams]
4 changes: 3 additions & 1 deletion training/training/urls.py
@@ -1,6 +1,5 @@
"""
URL configuration for training project.

The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/4.2/topics/http/urls/
Examples:
@@ -19,8 +18,10 @@
from django.http import HttpRequest
from django.urls import path
from ninja import NinjaAPI, Schema

from training.routes.datasets.default.columns import router as default_dataset_router
from training.routes.tabular.tabular import router as tabular_router
from training.routes.image.image import router as image_router

api = NinjaAPI()

@@ -32,6 +33,7 @@ def test(request: HttpRequest):

api.add_router("/datasets/default/", default_dataset_router)
api.add_router("/tabular", tabular_router)
api.add_router("/image", image_router)

urlpatterns = [
path("admin/", admin.site.urls),