Skip to content

Commit

Permalink
marius_predict bug (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
basavaraj29 authored Apr 28, 2022
1 parent 49669c0 commit 0f02da9
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 6 deletions.
5 changes: 5 additions & 0 deletions docs/export_and_inference/marius_predict.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ Contents of ``configs/fb15k237.yaml``. The test set here has been created during
pipeline:
sync: true
Since ``storage.model_dir`` is not specified in the above configuration, ``marius_predict`` will use the latest trained model present in ``storage.dataset.dataset_dir``.
When ``storage.model_dir`` is not specified, ``marius_train`` stores the model parameters in a ``model_x`` directory inside ``storage.dataset.dataset_dir``, where ``x``
increments from 0 to 10. At most 11 models are kept when ``model_dir`` is not specified; once ``model_10/`` exists, its contents are overwritten with the
latest parameters. ``marius_predict`` will use the latest model for inference and save its output files to that directory. If ``storage.model_dir`` is specified, the model
parameters will be loaded from the given directory and the generated files will be saved to that same directory.
Example output
****************************
Expand Down
6 changes: 3 additions & 3 deletions src/cpp/python_bindings/storage/io_wrap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ void init_io(py::module &m) {

m.def("load_model", [](string filename, bool train) {

shared_ptr<MariusConfig> marius_config = loadConfig(filename, "");
shared_ptr<MariusConfig> marius_config = loadConfig(filename, train);

std::vector<torch::Device> devices = devices_from_config(marius_config->storage);

Expand All @@ -24,7 +24,7 @@ void init_io(py::module &m) {

m.def("load_storage", [](string filename, bool train) {

shared_ptr<MariusConfig> marius_config = loadConfig(filename, "");
shared_ptr<MariusConfig> marius_config = loadConfig(filename, train);

std::vector<torch::Device> devices = devices_from_config(marius_config->storage);

Expand All @@ -40,7 +40,7 @@ void init_io(py::module &m) {

m.def("init_from_config", [](string filename, bool train, bool load_storage) {

shared_ptr<MariusConfig> marius_config = loadConfig(filename, "");
shared_ptr<MariusConfig> marius_config = loadConfig(filename, train);

std::vector<torch::Device> devices = devices_from_config(marius_config->storage);

Expand Down
20 changes: 18 additions & 2 deletions src/python/tools/configuration/marius_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from marius.tools.configuration.validation import *
from dataclasses import field
import os
import re

from pathlib import Path
import shutil
Expand Down Expand Up @@ -827,7 +828,18 @@ def initialize_model_dir(output_config):
node_mapping_filepath = Path(output_config.storage.dataset.dataset_dir) / Path("nodes") / Path("node_mapping.txt")
if node_mapping_filepath.exists():
shutil.copy(str(node_mapping_filepath), "{}/{}".format(output_config.storage.model_dir, "node_mapping.txt"))


def infer_model_dir(output_config):
    """Point ``storage.model_dir`` at the most recently trained model directory.

    An auto-generated ``model_dir`` has the form ``<dataset_dir>model_x/``,
    where ``x`` is the index of the *next* model slot to be written; the
    latest trained parameters therefore live in ``model_(x-1)/``. If
    ``model_dir`` does not have that shape it was user specified, and it is
    left untouched.
    """
    # re.escape guards against regex metacharacters in the dataset path
    # (e.g. '+', '(', '.'), which would otherwise corrupt — or outright
    # break — the pattern built from it.
    pattern = re.escape(output_config.storage.dataset.dataset_dir) + r"model_([0-9]+)/"
    match_result = re.fullmatch(pattern, output_config.storage.model_dir)
    if match_result is None:
        # User-specified model_dir: no inference needed.
        return

    # model_x is the next free slot, so the latest trained model is x - 1.
    last_model_id = int(match_result.group(1)) - 1
    if last_model_id >= 0:
        output_config.storage.model_dir = "{}model_{}/".format(
            output_config.storage.dataset.dataset_dir, last_model_id
        )

cs = ConfigStore.instance()
cs.store(name="base_config", node=MariusConfig)
Expand Down Expand Up @@ -868,7 +880,11 @@ def load_config(input_config_path, save=False):

OmegaConf.save(output_config,
output_config.storage.model_dir + PathConstants.saved_full_config_file_name)

else:
# this path is taken in test cases where random configs are passed to this function for parsing.
# could also be taken when marius_predict is run.
infer_model_dir(output_config)

# we can then perform validation, and optimization over the fully specified configuration file here before returning
validate_dataset_config(output_config)
validate_storage_config(output_config)
Expand Down
4 changes: 4 additions & 0 deletions src/python/tools/marius_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,10 @@ def run_predict(args):
config = m.config.loadConfig(args.config)
metrics = get_metrics(config, args)

model_dir_path = pathlib.Path(config.storage.model_dir)
if not model_dir_path.exists():
raise RuntimeError("Path {} with model params doesn't exist.".format(str(model_dir_path)))

model: m.nn.Model = m.storage.load_model(args.config, train=False)
graph_storage: m.storage.GraphModelStorage = m.storage.load_storage(args.config, train=False)

Expand Down
40 changes: 39 additions & 1 deletion test/python/predict/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ def setUp(self):
self.config_file = base_dir / Path(name) / Path(filename)

config = m.config.loadConfig(self.config_file.__str__(), True)
self.config_file = Path(config.storage.model_dir) / Path("full_config.yaml")
m.manager.marius_train(config)

@classmethod
Expand All @@ -155,6 +154,45 @@ def test_lp_metrics(self):
config = m.config.loadConfig(self.config_file.__str__(), save=False)
validate_metrics(config, ["MRR", "MEAN RANK", "HITS@1", "HITS@2", "HITS@3", "HITS@4", "HITS@5", "HITS@10", "HITS@20"], config.storage.dataset.num_test)

def test_predict_model_dir(self):
    """End-to-end check of model-directory inference in marius_predict.

    Covers three scenarios: prediction when only the auto-generated
    model_0/ exists, prediction after a second training run produces
    model_1/, and prediction from a saved full config whose explicit
    model_dir must win over any newer training output.
    """
    metric_names = ["mrr", "mr", "hits1", "hits2", "hits3", "hits4", "hits5", "hits10", "hits20"]
    expected_metrics = ["MRR", "MEAN RANK", "HITS@1", "HITS@2", "HITS@3", "HITS@4", "HITS@5", "HITS@10", "HITS@20"]

    # Pass 1: only model_0/ exists, so prediction must use that directory.
    arg_parser = set_args()
    cli_args = arg_parser.parse_args(["--config", self.config_file.__str__(), "--metrics"] + metric_names)
    run_predict(cli_args)

    cfg = m.config.loadConfig(self.config_file.__str__(), save=False)

    expected_dir = cfg.storage.dataset.dataset_dir + "model_0/"
    assert cfg.storage.model_dir == expected_dir, "Prediction should have used {} directory".format(expected_dir)
    validate_metrics(cfg, expected_metrics, cfg.storage.dataset.num_test)

    # Pass 2: train again so model_0/ and model_1/ both exist; prediction
    # must pick the newer model_1/ directory.
    cfg = m.config.loadConfig(self.config_file.__str__(), True)
    m.manager.marius_train(cfg)
    run_predict(cli_args)

    cfg = m.config.loadConfig(self.config_file.__str__(), save=False)

    expected_dir = cfg.storage.dataset.dataset_dir + "model_1/"
    assert cfg.storage.model_dir == expected_dir, "Prediction should have used {} directory".format(expected_dir)
    validate_metrics(cfg, expected_metrics, cfg.storage.dataset.num_test)

    # Pass 3: pin model_dir explicitly via the saved full config (here
    # model_1/). Even after another training run stores a newer model in
    # model_2/, marius_predict must keep using model_1/ because model_dir
    # is spelled out in that config.
    cfg = m.config.loadConfig(self.config_file.__str__(), True)
    saved_config_path = Path(cfg.storage.model_dir) / Path("full_config.yaml")
    m.manager.marius_train(cfg)
    cfg = m.config.loadConfig(self.config_file.__str__(), True)
    cli_args = arg_parser.parse_args(["--config", saved_config_path.__str__(), "--metrics"] + metric_names)
    run_predict(cli_args)

    cfg = m.config.loadConfig(saved_config_path.__str__(), save=False)

    assert cfg.storage.model_dir == expected_dir, "Prediction should have used {} directory".format(expected_dir)
    validate_metrics(cfg, expected_metrics, cfg.storage.dataset.num_test)



def test_lp_save_ranks(self):
    """Placeholder: saving per-example link-prediction ranks is not yet
    covered by a test. Intentionally a no-op."""
    pass

Expand Down

0 comments on commit 0f02da9

Please sign in to comment.