Skip to content

Commit

Permalink
define model to be used as env var
Browse files Browse the repository at this point in the history
  • Loading branch information
tkrieger committed Mar 24, 2021
1 parent a0b7368 commit 10ccb32
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
17 changes: 13 additions & 4 deletions agent_code/auto_bomber/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import os
import pickle
import shutil
from pathlib import Path
Expand All @@ -10,18 +11,26 @@
from agent_code.auto_bomber.transitions import Transitions


def get_model_dir():
try:
return os.environ["MODEL_DIR"]
except KeyError as e:
return model_path.MODEL_DIR


class LinearAutoBomberModel:
def __init__(self, train, feature_extractor):
self.train = train
self.weights = None
self.feature_extractor = feature_extractor

if model_path.MODEL_DIR and Path(model_path.MODEL_DIR).is_dir():
self.model_dir = Path(model_path.MODEL_DIR)
elif model_path.MODEL_DIR and not Path(model_path.MODEL_DIR).is_dir():
model_dir = get_model_dir()
if model_dir and Path(model_dir).is_dir():
self.model_dir = Path(model_dir)
elif model_dir and not Path(model_dir).is_dir():
raise FileNotFoundError("The specified model directory does not exist!\nIf you wish to train a NEW model"
"set parameter to None, otherwise specify a valid model directory.")
elif not self.train and not model_path.MODEL_DIR:
elif not self.train and not model_dir:
raise ValueError("No model directory has been specified.\n A model directory is required for inference.")
else:
root_dir = Path(model_path.MODELS_ROOT)
Expand Down
2 changes: 1 addition & 1 deletion agent_code/auto_bomber/model_path.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
MODELS_ROOT = "./models"
MODEL_DIR = "./models/4"
MODEL_DIR = None

0 comments on commit 10ccb32

Please sign in to comment.