diff --git a/main.py b/main.py index 7b1b6c8..c2351ea 100644 --- a/main.py +++ b/main.py @@ -9,9 +9,11 @@ from hydra.utils import instantiate from omegaconf import DictConfig, OmegaConf -from runner.classification.classification_framework import ClassificationFramework +from runner.frameworks.classification.classification_framework import ( + ClassificationFramework, +) from runner.frameworks.framework_interface import TrainingFramework -from runner.frameworks.sf_framework import SFFramework +from runner.frameworks.rl.sf_framework import SFFramework from runner.sweep import launch_sweep from runner.util import create_brain, delete_results diff --git a/runner/frameworks/classification/classification_framework.py b/runner/frameworks/classification/classification_framework.py index ceae05f..a701bf1 100644 --- a/runner/frameworks/classification/classification_framework.py +++ b/runner/frameworks/classification/classification_framework.py @@ -6,10 +6,10 @@ from retinal_rl.models.brain import Brain from retinal_rl.models.loss import ContextT from retinal_rl.models.objective import Objective -from runner.classification.analyze import analyze -from runner.classification.dataset import get_datasets -from runner.classification.initialize import initialize -from runner.classification.train import train +from runner.frameworks.classification.analyze import analyze +from runner.frameworks.classification.dataset import get_datasets +from runner.frameworks.classification.initialize import initialize +from runner.frameworks.classification.train import train from runner.frameworks.framework_interface import TrainingFramework diff --git a/runner/frameworks/classification/train.py b/runner/frameworks/classification/train.py index c6e48a1..dcf187c 100644 --- a/runner/frameworks/classification/train.py +++ b/runner/frameworks/classification/train.py @@ -15,7 +15,7 @@ from retinal_rl.classification.training import process_dataset, run_epoch from retinal_rl.models.brain import Brain from retinal_rl.models.objective import Objective -from runner.classification.analyze import analyze +from runner.frameworks.classification.analyze import analyze from runner.util import save_checkpoint # Initialize the logger diff --git a/runner/frameworks/rl/sf_framework.py b/runner/frameworks/rl/sf_framework.py index 139925c..601fc74 100644 --- a/runner/frameworks/rl/sf_framework.py +++ b/runner/frameworks/rl/sf_framework.py @@ -22,7 +22,6 @@ from sample_factory.utils.attr_dict import AttrDict from sample_factory.utils.typing import Config -from retinal_rl.framework_interface import TrainingFramework from retinal_rl.models.brain import Brain from retinal_rl.models.loss import ContextT from retinal_rl.models.objective import Objective @@ -33,6 +32,7 @@ ) from retinal_rl.rl.sample_factory.environment import register_retinal_env from retinal_rl.rl.sample_factory.models import SampleFactoryBrain +from runner.frameworks.framework_interface import TrainingFramework from runner.util import create_brain