From 01c631bf93437f9b9aefa77a8b6402efbf1693e1 Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Sun, 4 Aug 2024 00:43:47 +0300 Subject: [PATCH] add Nakdimon.h5 to package data Signed-off-by: Elazar Gershuni --- examples/usage.py | 5 +---- {models => nakdimon}/Nakdimon.h5 | Bin nakdimon/__init__.py | 8 ++++---- nakdimon/config.py | 4 ++++ nakdimon/external_apis.py | 3 ++- nakdimon/predict.py | 19 +++++++++++-------- nakdimon/pretrain.py | 8 ++++---- nakdimon/server.py | 5 +++-- nakdimon/train.py | 3 ++- pyproject.toml | 5 ++++- 10 files changed, 35 insertions(+), 25 deletions(-) rename {models => nakdimon}/Nakdimon.h5 (100%) create mode 100644 nakdimon/config.py diff --git a/examples/usage.py b/examples/usage.py index 470fc8af..f1f64d5f 100644 --- a/examples/usage.py +++ b/examples/usage.py @@ -1,9 +1,6 @@ # pip install git+https://github.com/nakdimon/nakdimon.git -# mkdir models -# wget https://github.com/elazarg/nakdimon/raw/master/models/Nakdimon.h5 -# mv Nakdimon.h5 models/Nakdimon.h5 from nakdimon import diacritize -result = diacritize("שלום עולם!", "models/Nakdimon.h5") +result = diacritize("שלום עולם!") print(result) diff --git a/models/Nakdimon.h5 b/nakdimon/Nakdimon.h5 similarity index 100% rename from models/Nakdimon.h5 rename to nakdimon/Nakdimon.h5 diff --git a/nakdimon/__init__.py b/nakdimon/__init__.py index fbb7d647..8cc439bb 100644 --- a/nakdimon/__init__.py +++ b/nakdimon/__init__.py @@ -2,7 +2,7 @@ import sys import os import logging - +from nakdimon.config import MAIN_MODEL def do_train(**kwargs) -> None: from nakdimon import train @@ -48,7 +48,7 @@ def main() -> None: parser_train = subparsers.add_parser('train', help='train Nakdimon') parser_train.add_argument('--wandb', action='store_true', help='use wandb.', default=False) - parser_train.add_argument('--model', help='path to output model (.h5 file)', default='models/Full.h5', dest='model_path') + parser_train.add_argument('--model', help='path to output model (.h5 file)', default=MAIN_MODEL, dest='model_path') parser_train.add_argument('--ablation', help='ablation test', default=None, dest='ablation_name') parser_train.set_defaults(func=do_train) @@ -60,7 +60,7 @@ def main() -> None: parser_test = subparsers.add_parser('run_test', help='diacritize a test set') parser_test.add_argument('--test_set', choices=available_tests, help='choose test set', default='tests/new') parser_test.add_argument('--system', choices=test_systems, help='diacritization system to use', default='Nakdimon') - parser_test.add_argument('--model', help='path to model (.h5 file)', default='models/Nakdimon.h5', dest='model_path') + parser_test.add_argument('--model', help='path to model (.h5 file)', default=MAIN_MODEL, dest='model_path') parser_test.add_argument('--skip-existing', action='store_true', help='skip existing files') parser_test.set_defaults(func=do_run_test) @@ -114,6 +114,6 @@ def diacritize_main(): sys.exit(0) -def diacritize(text: str, model_path: str = 'models/Nakdimon.h5') -> str: +def diacritize(text: str, model_path: str = MAIN_MODEL) -> str: import nakdimon.predict return nakdimon.predict.predict(text, model_path) diff --git a/nakdimon/config.py b/nakdimon/config.py new file mode 100644 index 00000000..af50e5eb --- /dev/null +++ b/nakdimon/config.py @@ -0,0 +1,4 @@ +from importlib.resources import files + +MODELS_DIR = 'models' +MAIN_MODEL = files('nakdimon').joinpath('Nakdimon.h5') diff --git a/nakdimon/external_apis.py b/nakdimon/external_apis.py index 9f377690..a71e2432 100644 --- a/nakdimon/external_apis.py +++ b/nakdimon/external_apis.py @@ -10,6 +10,7 @@ from nakdimon.hebrew import Niqqud from nakdimon import hebrew +from nakdimon.config import MAIN_MODEL class DottingError(RuntimeError): @@ -184,7 +185,7 @@ def run_nakdimon(text: str) -> str: 'Snopi': fetch_snopi, # Too slow 'Morfix': fetch_morfix, # terms-of-use issue 'Dicta': fetch_dicta, - 'Nakdimon': make_nakdimon_no_server('models/Nakdimon.h5'), + 'Nakdimon': make_nakdimon_no_server(MAIN_MODEL), } all_oov = set() diff --git a/nakdimon/predict.py b/nakdimon/predict.py index 34ffb560..715bc911 100644 --- a/nakdimon/predict.py +++ b/nakdimon/predict.py @@ -1,17 +1,21 @@ import logging +import pathlib from functools import lru_cache import tensorflow as tf from nakdimon import utils, dataset, hebrew - +from nakdimon.config import MAIN_MODEL if tf.config.set_visible_devices([], 'GPU'): logging.warning('No GPU available.') @lru_cache() -def load_cached_model(m): +def load_cached_model(m: pathlib.Path | str) -> tf.Module: + if isinstance(m, str): + return load_cached_model(pathlib.Path(m)) + assert isinstance(m, pathlib.Path) model = tf.keras.models.load_model(m, custom_objects={'loss': None}) return model @@ -31,13 +35,12 @@ def merge_unconditional(texts, tnss, nss, dss, sss): return res -def predict(text: str, model_or_model_path: tf.Module|str = 'models/Nakdimon.h5', maxlen=10000) -> str: - if isinstance(model_or_model_path, str): - model = load_cached_model(model_or_model_path) - elif isinstance(model_or_model_path, tf.Module): - model = model_or_model_path - else: +def predict(text: str, model_or_model_path: tf.Module | str = MAIN_MODEL, maxlen=10000) -> str: + if isinstance(model_or_model_path, (pathlib.Path, str)): + model_or_model_path = load_cached_model(model_or_model_path) + if not isinstance(model_or_model_path, tf.Module): raise TypeError(f'Expected str or tf.Module, got {type(model_or_model_path)}') + model = model_or_model_path data = dataset.Data.from_text(hebrew.iterate_dotted_text(text), maxlen) prediction = model.predict(data.normalized) [actual_niqqud, actual_dagesh, actual_sin] = [dataset.from_categorical(prediction[0]), dataset.from_categorical(prediction[1]), dataset.from_categorical(prediction[2])] diff --git a/nakdimon/pretrain.py b/nakdimon/pretrain.py index 1c9c2075..188d2e1f 100644 --- a/nakdimon/pretrain.py +++ b/nakdimon/pretrain.py @@ -9,9 +9,9 @@ from nakdimon import hebrew from nakdimon.train import TrainingParams from nakdimon import metrics +from nakdimon.config import MODELS_DIR - -pretrain_path = f'./models/wiki' +pretrain_path = f'{MODELS_DIR}/wiki' model_name = pretrain_path + 'pretrain.h5' @@ -138,7 +138,7 @@ def pretrain(): def train_ablation(params): from train import train model = train(params) - model.save(f'./models/ablations/{params.name}.h5') + model.save(f'./{MODELS_DIR}/ablations/{params.name}.h5') if __name__ == '__main__': @@ -152,7 +152,7 @@ def train_ablation(params): import ablations tf.config.set_visible_devices([], 'GPU') model_name = 'PretrainedModernOnly' - model = tf.keras.models.load_model(f'models/ablations/{model_name}.h5', + model = tf.keras.models.load_model(f'{MODELS_DIR}/ablations/{model_name}.h5', custom_objects={'loss': TrainingParams().loss}) print(model_name, *metrics.metricwise_mean(ablations.calculate_metrics(model)).values(), sep=', ') diff --git a/nakdimon/server.py b/nakdimon/server.py index 0bda1533..91f77c4a 100644 --- a/nakdimon/server.py +++ b/nakdimon/server.py @@ -4,6 +4,7 @@ import logging from nakdimon import predict +from nakdimon.config import MAIN_MODEL app = flask.Flask(__name__) @@ -26,9 +27,9 @@ def diacritize(): def main(): - logging.info("Loading models/Nakdimon.h5") + logging.info(f"Loading {MAIN_MODEL}") try: - predict.predict("שלום", 'models/Nakdimon.h5') + predict.predict("שלום") logging.info("Done loading.") except OSError: logging.warning("Could not load default model") diff --git a/nakdimon/train.py b/nakdimon/train.py index cacca677..332b9c32 100644 --- a/nakdimon/train.py +++ b/nakdimon/train.py @@ -12,6 +12,7 @@ from nakdimon.dataset import NIQQUD_SIZE, DAGESH_SIZE, SIN_SIZE, LETTERS_SIZE from nakdimon import schedulers from nakdimon import transformer +from nakdimon.config import MODELS_DIR # assert tf.config.list_physical_devices('GPU') @@ -270,7 +271,7 @@ def train(params: NakdimonParams, group, ablation=False, wandb_enabled=False): def train_ablation(params, group): model = train(params, group, ablation=True) - model.save(f'./models/ablations/{params.name}.h5') + model.save(f'./{MODELS_DIR}/ablations/{params.name}.h5') class Full(NakdimonParams): diff --git a/pyproject.toml b/pyproject.toml index 05b353d1..b8263f64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "nakdimon" -version = "0.1.1" +version = "0.1.2" authors = [ { name="Elazar Gershuni", email="elazarg@gmail.com" }, ] @@ -28,6 +28,9 @@ include = ["nakdimon"] exclude = [] # exclude packages matching these glob patterns (empty by default) namespaces = false +[tool.setuptools.package-data] +nakdimon = ["Nakdimon.h5"] + [project.urls] "Homepage" = "https://github.com/elazarg/nakdimon" "Bug Tracker" = "https://github.com/elazarg/nakdimon/issues"