diff --git a/vel/internals/model_config.py b/vel/internals/model_config.py index 7dea0c66..979fcdee 100644 --- a/vel/internals/model_config.py +++ b/vel/internals/model_config.py @@ -30,7 +30,8 @@ def find_project_directory(start_path) -> str: return ModelConfig.find_project_directory(up_path) @classmethod - def from_file(cls, filename: str, run_number: int, continue_training=False, seed: int=None, device: str='cuda', params=None): + def from_file(cls, filename: str, run_number: int, continue_training: bool = False, seed: int = None, + device: str = 'cuda', params=None): """ Create model config from file """ with open(filename, 'r') as fp: model_config_contents = Parser.parse(fp) @@ -45,14 +46,7 @@ def from_file(cls, filename: str, run_number: int, continue_training=False, seed **model_config_contents } - # Options that should exist for every config - try: - model_name = model_config_contents['name'] - except KeyError: - raise VelInitializationException("Model configuration must have a 'name' key") - return ModelConfig( - model_name=model_name, filename=filename, configuration=aggregate_dictionary, run_number=run_number, @@ -64,11 +58,10 @@ def from_file(cls, filename: str, run_number: int, continue_training=False, seed ) @classmethod - def from_memory(cls, model_name: str, model_data: dict, run_number: int, project_dir: str, - continue_training=False, seed: int=None, device: str='cuda', params=None): + def from_memory(cls, model_data: dict, run_number: int, project_dir: str, + continue_training=False, seed: int = None, device: str = 'cuda', params=None): """ Create model config from supplied data """ return ModelConfig( - model_name=model_name, filename="[memory]", configuration=model_data, run_number=run_number, @@ -79,9 +72,8 @@ def from_memory(cls, model_name: str, model_data: dict, run_number: int, project parameters=params ) - def __init__(self, model_name: str, filename: str, configuration: dict, run_number: int, project_dir: str, - continue_training=False, seed: int=None, device: str= 'cuda', parameters=None): - self._model_name = model_name + def __init__(self, filename: str, configuration: dict, run_number: int, project_dir: str, + continue_training=False, seed: int = None, device: str = 'cuda', parameters=None): self.filename = filename self.device = device self.continue_training = continue_training @@ -98,6 +90,7 @@ def __init__(self, model_name: str, filename: str, configuration: dict, run_numb del self.contents['commands'] self.provider = Provider(self._prepare_environment(), {'model_config': self}, parameters=parameters) + self._model_name = self.provider.get("name") def _prepare_environment(self) -> dict: """ Return full environment for dependency injection """ diff --git a/vel/internals/provider.py b/vel/internals/provider.py index 8fe26569..526181cc 100644 --- a/vel/internals/provider.py +++ b/vel/internals/provider.py @@ -116,3 +116,7 @@ def instantiate_by_name_with_default(self, object_name, default_value=None): return instance else: return self.instances[object_name] + + def get(self, name): + """ Get object from given provider """ + return self.instantiate_by_name(name)