Skip to content

Commit

Permalink
Allow to parametrize model name.
Browse files Browse the repository at this point in the history
  • Loading branch information
MillionIntegrals committed Apr 8, 2019
1 parent 599a8c4 commit 7d58032
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 14 deletions.
21 changes: 7 additions & 14 deletions vel/internals/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def find_project_directory(start_path) -> str:
return ModelConfig.find_project_directory(up_path)

@classmethod
def from_file(cls, filename: str, run_number: int, continue_training=False, seed: int=None, device: str='cuda', params=None):
def from_file(cls, filename: str, run_number: int, continue_training: bool = False, seed: int = None,
device: str = 'cuda', params=None):
""" Create model config from file """
with open(filename, 'r') as fp:
model_config_contents = Parser.parse(fp)
Expand All @@ -45,14 +46,7 @@ def from_file(cls, filename: str, run_number: int, continue_training=False, seed
**model_config_contents
}

# Options that should exist for every config
try:
model_name = model_config_contents['name']
except KeyError:
raise VelInitializationException("Model configuration must have a 'name' key")

return ModelConfig(
model_name=model_name,
filename=filename,
configuration=aggregate_dictionary,
run_number=run_number,
Expand All @@ -64,11 +58,10 @@ def from_file(cls, filename: str, run_number: int, continue_training=False, seed
)

@classmethod
def from_memory(cls, model_name: str, model_data: dict, run_number: int, project_dir: str,
continue_training=False, seed: int=None, device: str='cuda', params=None):
def from_memory(cls, model_data: dict, run_number: int, project_dir: str,
continue_training=False, seed: int = None, device: str = 'cuda', params=None):
""" Create model config from supplied data """
return ModelConfig(
model_name=model_name,
filename="[memory]",
configuration=model_data,
run_number=run_number,
Expand All @@ -79,9 +72,8 @@ def from_memory(cls, model_name: str, model_data: dict, run_number: int, project
parameters=params
)

def __init__(self, model_name: str, filename: str, configuration: dict, run_number: int, project_dir: str,
continue_training=False, seed: int=None, device: str= 'cuda', parameters=None):
self._model_name = model_name
def __init__(self, filename: str, configuration: dict, run_number: int, project_dir: str,
continue_training=False, seed: int = None, device: str = 'cuda', parameters=None):
self.filename = filename
self.device = device
self.continue_training = continue_training
Expand All @@ -98,6 +90,7 @@ def __init__(self, model_name: str, filename: str, configuration: dict, run_numb
del self.contents['commands']

self.provider = Provider(self._prepare_environment(), {'model_config': self}, parameters=parameters)
self._model_name = self.provider.get("name")

def _prepare_environment(self) -> dict:
""" Return full environment for dependency injection """
Expand Down
4 changes: 4 additions & 0 deletions vel/internals/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,7 @@ def instantiate_by_name_with_default(self, object_name, default_value=None):
return instance
else:
return self.instances[object_name]

def get(self, name):
""" Get object from given provider """
return self.instantiate_by_name(name)

0 comments on commit 7d58032

Please sign in to comment.