Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/sg 1442 sliding window inference for yolonas #1979

Merged
merged 26 commits into from
May 22, 2024
Merged
Changes from 1 commit
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
2ca48bf
wip
shaydeci Apr 18, 2024
c3666e9
wip
shaydeci Apr 30, 2024
4b783ea
wip2
shaydeci May 1, 2024
667ed3b
Merge remote-tracking branch 'origin/master' into feature/SG-1442_sli…
shaydeci May 5, 2024
1d8cc8c
working version, hard coded nms params
shaydeci May 5, 2024
fae6d8d
moved post prediction callback to utils
shaydeci May 5, 2024
45aea2a
moved back to wrapper
shaydeci May 7, 2024
684af84
Merge remote-tracking branch 'origin/master' into feature/SG-1442_sli…
shaydeci May 9, 2024
837ffd3
added abstract class, small refactoring for pipeline
shaydeci May 9, 2024
f77616c
rolled back customizable detector, solved pretrained weights setting …
shaydeci May 9, 2024
dce1b4a
temp cleanup
shaydeci May 9, 2024
6c64ddd
support for fuse model in predict
shaydeci May 9, 2024
2cdf4ff
example added for predict
shaydeci May 9, 2024
80d81e9
added support for forward wrappers in trainer
shaydeci May 9, 2024
bf809eb
added test for validation forward wrapper
shaydeci May 9, 2024
877e016
added option for None as post prediction callback in DetectionMetrics
shaydeci May 9, 2024
8192a15
wip adding set_model before using wrapper
shaydeci May 15, 2024
60cf723
Merge remote-tracking branch 'origin/master' into feature/SG-1442_sli…
shaydeci May 16, 2024
ebfefd1
commit changes before removal of validation during training support
shaydeci May 16, 2024
aa7d0cb
refined docs
shaydeci May 16, 2024
7f3a0d4
removed old test for forward wrapper, fixed defaults
shaydeci May 20, 2024
1056b23
fixed test and added clarifications
shaydeci May 20, 2024
2981c23
forward wrapper test removed
shaydeci May 20, 2024
cf169e9
Merge remote-tracking branch 'origin/master' into feature/SG-1442_sli…
shaydeci May 20, 2024
0bcb821
updated wrong threshold extraction and test result
shaydeci May 20, 2024
2d6331a
fixed docstring format
shaydeci May 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
added test for validation forward wrapper
shaydeci committed May 9, 2024
commit bf809eb481cdff519c23079eb082a8854491dfd2
4 changes: 2 additions & 2 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
Original file line number Diff line number Diff line change
@@ -2289,8 +2289,8 @@ def evaluate(
else:
self.phase_callback_handler.on_test_batch_start(context)

if self.self.validation_forward_wrapper is not None:
output = self.validation_forward_wrapper(inputs, self.net)
if self.validation_forward_wrapper is not None:
output = self.validation_forward_wrapper(inputs=inputs, model=self.net)
else:
output = self.net(inputs)

2 changes: 2 additions & 0 deletions tests/deci_core_unit_test_suite_runner.py
Original file line number Diff line number Diff line change
@@ -30,6 +30,7 @@
TestClassificationAdapter,
TestDetectionAdapter,
TestSegmentationAdapter,
TestForwardWrapper,
)
from tests.end_to_end_tests import TestTrainer
from tests.unit_tests.depth_estimation_dataset_test import DepthEstimationDatasetTest
@@ -192,6 +193,7 @@ def _add_modules_to_unit_tests_suite(self):
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(ClassBalancerTest))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(ClassBalancedSamplerTest))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestSegmentationModelExport))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestForwardWrapper))

def _add_modules_to_end_to_end_tests_suite(self):
"""
2 changes: 2 additions & 0 deletions tests/unit_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@
from tests.unit_tests.crash_tips_test import CrashTipTest
from tests.unit_tests.double_training_test import CallTrainTwiceTest
from tests.unit_tests.factories_test import FactoriesTest
from tests.unit_tests.forward_wrapper_test import TestForwardWrapper
from tests.unit_tests.optimizer_params_override_test import TrainOptimizerParamsOverride
from tests.unit_tests.resume_training_test import ResumeTrainingTest
from tests.unit_tests.strictload_enum_test import StrictLoadEnumTest
@@ -63,4 +64,5 @@
"TestClassificationAdapter",
"TestDetectionAdapter",
"TestSegmentationAdapter",
"TestForwardWrapper",
]
79 changes: 79 additions & 0 deletions tests/unit_tests/forward_wrapper_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import unittest

from super_gradients.training import Trainer
from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
from super_gradients.training.metrics import Accuracy
from super_gradients.training.models import LeNet
from super_gradients.training.utils.callbacks import PhaseContext, Callback
import torch


class OutputsCollectorCallback(Callback):
def __init__(self):
self.validation_outputs = []
self.train_outputs = []

def on_validation_batch_end(self, context: PhaseContext) -> None:
self.validation_outputs.append(context.preds)

def on_train_batch_end(self, context: PhaseContext) -> None:
self.train_outputs.append(context.preds)


class DummyForwardWrapper:
def __call__(self, inputs: torch.Tensor, model: torch.nn.Module):
return torch.ones_like(model(inputs))


def compare_tensor_lists(list1, list2):
if len(list1) != len(list2):
return False

# Move tensors to CPU
list1 = [t.cpu() for t in list1]
list2 = [t.cpu() for t in list2]

for tensor1, tensor2 in zip(list1, list2):
if not torch.all(torch.eq(tensor1, tensor2)):
return False
return True


class TestForwardWrapper(unittest.TestCase):
def test_train_with_validation_forward_wrapper(self):
# Define Model
net = LeNet()
trainer = Trainer("test_train_with_validation_forward_wrapper")
output_collector = OutputsCollectorCallback()
validation_forward_wrapper = DummyForwardWrapper()
train_params = {
"max_epochs": 1,
"initial_lr": 0.1,
"loss": "CrossEntropyLoss",
"optimizer": "SGD",
"criterion_params": {},
"optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
"train_metrics_list": [Accuracy()],
"valid_metrics_list": [Accuracy()],
"metric_to_watch": "Accuracy",
"greater_metric_to_watch_is_better": True,
"ema": False,
"phase_callbacks": [output_collector],
"warmup_mode": "LinearEpochLRWarmup",
"validation_forward_wrapper": validation_forward_wrapper,
"average_best_models": False,
}

expected_outputs = [torch.ones(4, 10)]
trainer.train(
model=net,
training_params=train_params,
train_loader=classification_test_dataloader(batch_size=4),
valid_loader=classification_test_dataloader(batch_size=4),
)
self.assertTrue(compare_tensor_lists(expected_outputs, output_collector.validation_outputs))
self.assertFalse(compare_tensor_lists(expected_outputs, output_collector.train_outputs))


if __name__ == "__main__":
unittest.main()