diff --git a/dialog_tag/DialogTag.py b/dialog_tag/DialogTag.py index d20e6ae..cc861c7 100644 --- a/dialog_tag/DialogTag.py +++ b/dialog_tag/DialogTag.py @@ -12,13 +12,13 @@ class DialogTag: - def __init__(self, model_name): + def __init__(self, model_name, model_path=None): self.__model_name = model_name self.__lib_path = f"{str(Path.home())}"+ model_location["MODEL"] - self.__model_path = os.path.join(self.__lib_path, self.__model_name) - self.__label_mapping_path = os.path.join(self.__lib_path, self.__model_name) + model_location["label_mapping"] + self.__model_path = model_path or os.path.join(self.__lib_path, self.__model_name) + self.__label_mapping_path = self.__model_path + model_location["label_mapping"] # print(self.__lib_path, self.__model_path, self.__label_mapping_path) path_exists = os.path.exists(self.__model_path) diff --git a/setup.py b/setup.py index 8afaa6b..f710733 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,6 @@ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ], - python_requires='>=3.7', + python_requires='>=3.6', keywords="Tensorflow BERT NLP deep learning Transformer Networks " )