diff --git a/edsnlp/core/pipeline.py b/edsnlp/core/pipeline.py index 38f2d716d..bd12c0fa5 100644 --- a/edsnlp/core/pipeline.py +++ b/edsnlp/core/pipeline.py @@ -15,6 +15,8 @@ Dict, Iterable, List, + Literal, + Mapping, Optional, Sequence, Set, @@ -257,7 +259,7 @@ def add_pipe( if hasattr(pipe, "name"): if name is not None and name != pipe.name: raise ValueError( - "The provided name does not match the name of the component." + f"The provided name {name!r} does not match the name of the component {pipe.name!r}." ) else: name = pipe.name @@ -828,6 +830,9 @@ def deserialize_tensors(path: Path): import safetensors.torch torch_components = dict(self.torch_components()) + if len(torch_components) == 0 and not path.exists(): + return + for file_name in path.iterdir(): pipe_names = file_name.stem.split("+") if any(pipe_name in torch_components for pipe_name in pipe_names): @@ -957,6 +962,37 @@ def __exit__(ctx_self, type, value, traceback): self._disabled = disable return context() + def package( + self, + name: Optional[str] = None, + root_dir: Union[str, Path] = ".", + artifacts_name: str = "artifacts", + check_dependencies: bool = False, + project_type: Optional[Literal["poetry", "setuptools"]] = None, + version: str = "0.1.0", + metadata: Optional[Dict[str, Any]] = {}, + distributions: Optional[Sequence[Literal["wheel", "sdist"]]] = ["wheel"], + config_settings: Optional[Mapping[str, Union[str, Sequence[str]]]] = None, + isolation: bool = True, + skip_build_dependency_check: bool = False, + ): + from edsnlp.utils.package import package + + return package( + pipeline=self, + name=name, + root_dir=root_dir, + artifacts_name=artifacts_name, + check_dependencies=check_dependencies, + project_type=project_type, + version=version, + metadata=metadata, + distributions=distributions, + config_settings=config_settings, + isolation=isolation, + skip_build_dependency_check=skip_build_dependency_check, + ) + def blank( lang: str, diff --git a/edsnlp/pipelines/misc/measurements/measurements.py b/edsnlp/pipelines/misc/measurements/measurements.py index 859843972..50e2a99e2 100644 --- a/edsnlp/pipelines/misc/measurements/measurements.py +++ b/edsnlp/pipelines/misc/measurements/measurements.py @@ -529,10 +529,8 @@ def __init__( self.after_snippet_limit = after_snippet_limit # MEASURES - for measure_config in measurements: - name = measure_config["name"] - unit = measure_config["unit"] - self.measure_names[self.unit_registry.parse_unit(unit)[0]] = name + for m in measurements: + self.measure_names[self.unit_registry.parse_unit(m["unit"])[0]] = m["name"] if span_setter is None: span_setter = { diff --git a/edsnlp/utils/package.py b/edsnlp/utils/package.py new file mode 100644 index 000000000..8d84002d4 --- /dev/null +++ b/edsnlp/utils/package.py @@ -0,0 +1,461 @@ +import io +import os +import re +import shutil +import subprocess +import sys +from contextlib import contextmanager +from pathlib import Path +from types import FunctionType +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +import build +import dill +import toml +from build.__main__ import build_package, build_package_via_sdist +from confit import Cli +from dill._dill import save_function as dill_save_function +from dill._dill import save_type as dill_save_type +from importlib_metadata import PackageNotFoundError +from importlib_metadata import version as get_version +from loguru import logger +from typing_extensions import Literal + +import edsnlp + +py_version = f"{sys.version_info.major}.{sys.version_info.minor}" + + +def get_package(obj_type: Type): + # Retrieve the __package__ attribute of the module of a type, if possible. + # And returns the package version as well + try: + module_name = obj_type.__module__ + if module_name == "__main__": + raise Exception(f"Could not find package of type {obj_type}") + module = __import__(module_name, fromlist=["__package__"]) + package = module.__package__ + try: + version = get_version(package) + except (PackageNotFoundError, ValueError): + return None + return package, version + except (ImportError, AttributeError): + raise Exception(f"Cound not find package of type {obj_type}") + + +def save_type(pickler, obj, *args, **kwargs): + package_name = get_package(obj) + if package_name is not None: + pickler.packages.add(package_name) + dill_save_type(pickler, obj, *args, **kwargs) + + +def save_function(pickler, obj, *args, **kwargs): + package_name = get_package(obj) + if package_name is not None: + pickler.packages.add(package_name) + return dill_save_function(pickler, obj, *args, **kwargs) + + +class PackagingPickler(dill.Pickler): + dispatch = dill.Pickler.dispatch.copy() + + dispatch[FunctionType] = save_function + dispatch[type] = save_type + + def __init__(self, *args, **kwargs): + self.file = io.BytesIO() + super().__init__(self.file, *args, **kwargs) + self.packages = set() + + +def get_deep_dependencies(obj): + pickler = PackagingPickler() + pickler.dump(obj) + return sorted(pickler.packages) + + +app = Cli(pretty_exceptions_show_locals=False, pretty_exceptions_enable=False) + + +def snake_case(s): + # From https://www.w3resource.com/python-exercises/string/python-data-type-string-exercise-97.php # noqa E501 + return "_".join( + re.sub( + "([A-Z][a-z]+)", r" \1", re.sub("([A-Z]+)", r" \1", s.replace("-", " ")) + ).split() + ).lower() + + +class ModuleName(str): + def __new__(cls, *args, **kwargs): + raise NotImplementedError("ModuleName is only meant for typing.") + + @classmethod + def __get_validators__(self): + yield self.validate + + @classmethod + def validate(cls, value, config=None): + if not isinstance(value, str): + raise TypeError("string required") + + if not re.match( + r"^([A-Z0-9]|[A-Z0-9][A-Z0-9._-]*[A-Z0-9])$", value, flags=re.IGNORECASE + ): + raise ValueError("invalid identifier") + return value + + +if TYPE_CHECKING: + ModuleName = str # noqa F811 + +POETRY_SNIPPET = """\ +from poetry.core.masonry.builders.sdist import SdistBuilder +from poetry.factory import Factory +from poetry.core.masonry.utils.module import ModuleOrPackageNotFound +import sys +# Initialize the Poetry object for the current project +poetry = Factory().create_poetry("__root_dir__") + +# Initialize the builder +try: + builder = SdistBuilder(poetry, None, None) +except ModuleOrPackageNotFound: + if not poetry.package.packages: + print([]) + sys.exit(0) + +print([ + {k: v for k, v in { + "include": include._include, + "from": include.source, + "formats": include.formats, + }.items()} + for include in builder._module.includes +]) + +# Get the list of files to include +files = builder.find_files_to_add() + +# Print the list of files +for file in files: + print(file.path) +""" + +INIT_PY = """ +# ----------------------------------------- +# This section was autogenerated by edsnlp +# ----------------------------------------- + +import edsnlp +from pathlib import Path + +__version__ = {__version__} + +def load(device: "torch.device" = "cpu") -> edsnlp.Pipeline: + artifacts_path = Path(__file__).parent / "{artifacts_dir}" + model = edsnlp.load(artifacts_path).to(device) + return model +""" + + +# def parse_authors_as_dicts(authors): +# authors = [authors] if isinstance(authors, str) else authors +# return [ +# dict(zip(("name", "email"), re.match(r"(.*) <(.*)>", author).groups())) +# if isinstance(author, str) +# else author +# for author in authors +# ] + + +def parse_authors_as_strings(authors): + authors = [authors] if isinstance(authors, str) else authors + return [ + author if isinstance(author, str) else f"{author['name']} <{author['email']}>" + for author in authors + ] + + +class PoetryPackager: + def __init__( + self, + pyproject: Optional[Dict[str, Any]], + pipeline: Union[Path, "edsnlp.Pipeline"], + version: str, + name: Optional[ModuleName], + root_dir: Path = ".", + build_name: Path = "build", + out_dir: Path = "dist", + artifacts_name: ModuleName = "artifacts", + dependencies: Optional[Sequence[Tuple[str, str]]] = None, + metadata: Optional[Dict[str, Any]] = {}, + ): + self.poetry_bin_path = ( + subprocess.run(["which", "poetry"], stdout=subprocess.PIPE) + .stdout.decode() + .strip() + ) + self.version = version + self.name = name + self.pyproject = pyproject + self.root_dir = root_dir.resolve() + self.dependencies = dependencies + self.pipeline = pipeline + self.artifacts_name = artifacts_name + self.out_dir = self.root_dir / out_dir + + with self.ensure_pyproject(metadata): + + python_executable = ( + Path(self.poetry_bin_path).read_text().split("\n")[0][2:] + ) + result = subprocess.run( + [ + *python_executable.split(), + "-c", + POETRY_SNIPPET.replace("__root_dir__", str(self.root_dir)), + ], + stdout=subprocess.PIPE, + cwd=self.root_dir, + ) + if result.returncode != 0: + raise Exception() + out = result.stdout.decode().strip().split("\n") + + self.poetry_packages = eval(out[0]) + self.build_dir = root_dir / build_name / self.name + self.file_paths = [self.root_dir / file_path for file_path in out[1:]] + + logger.info(f"root_dir: {self.root_dir}") + logger.info(f"build_dir: {self.build_dir}") + logger.info(f"artifacts_name: {self.artifacts_name}") + logger.info(f"name: {self.name}") + + @contextmanager + def ensure_pyproject(self, metadata): + """Generates a Poetry based pyproject.toml""" + metadata = dict(metadata) + new_pyproject = self.pyproject is None + if "authors" in metadata: + metadata["authors"] = parse_authors_as_strings(metadata["authors"]) + try: + if new_pyproject: + self.pyproject = { + "build-system": { + "requires": ["poetry-core>=1.0.0"], + "build-backend": "poetry.core.masonry.api", + }, + "tool": { + "poetry": { + **metadata, + "name": self.name, + "version": self.version, + "dependencies": { + "python": f">={py_version},<4.0", + **{ + dep_name: f"^{dep_version}" + for dep_name, dep_version in self.dependencies + }, + }, + }, + }, + } + (self.root_dir / "pyproject.toml").write_text( + toml.dumps(self.pyproject) + ) + else: + self.name = ( + self.pyproject["tool"]["poetry"]["name"] + if self.name is None + else self.name + ) + for key, value in metadata.items(): + pyproject_value = self.pyproject["tool"]["poetry"].get(key) + if pyproject_value != metadata[key]: + raise ValueError( + f"Field {key} in pyproject.toml doesn't match the one " + f"passed as argument, you should remove it from the " + f"metadata parameter. Avoid using metadata if you already " + f"have a pyproject.toml file.\n" + f"pyproject.toml:\n {pyproject_value}\n" + f"metadata:\n {value}" + ) + yield + except Exception: + if new_pyproject: + os.remove(self.root_dir / "pyproject.toml") + raise + + def list_files_to_add(self): + # Extract python from the shebang in the poetry executable + return self.file_paths + + def build( + self, + distributions: Sequence[str] = (), + config_settings: Optional[build.ConfigSettingsType] = None, + isolation: bool = True, + skip_dependency_check: bool = False, + ): + logger.info(f"Building package {self.name}") + + if distributions: + build_call = build_package + else: + build_call = build_package_via_sdist + distributions = ["wheel"] + build_call( + srcdir=self.build_dir, + outdir=self.out_dir, + distributions=distributions, + config_settings=config_settings, + isolation=isolation, + skip_dependency_check=skip_dependency_check, + ) + + def update_pyproject(self): + # Replacing project name + old_name = self.pyproject["tool"]["poetry"]["name"] + self.pyproject["tool"]["poetry"]["name"] = self.name + logger.info( + f"Replaced project name {old_name!r} with {self.name!r} in poetry based " + f"project" + ) + + old_version = self.pyproject["tool"]["poetry"]["version"] + self.pyproject["tool"]["poetry"]["version"] = self.version + logger.info( + f"Replaced project version {old_version!r} with {self.version!r} in poetry " + f"based project" + ) + + # Adding artifacts to include in pyproject.toml + snake_name = snake_case(self.name.lower()) + included = self.pyproject["tool"]["poetry"].setdefault("include", []) + included.append(f"{snake_name}/{self.artifacts_name}/**") + + packages = list(self.poetry_packages) + packages.append({"include": snake_name}) + self.pyproject["tool"]["poetry"]["packages"] = packages + + def make_src_dir(self): + snake_name = snake_case(self.name.lower()) + package_dir = self.build_dir / snake_name + shutil.rmtree(package_dir, ignore_errors=True) + os.makedirs(package_dir, exist_ok=True) + build_artifacts_dir = package_dir / self.artifacts_name + for file_path in self.list_files_to_add(): + new_file_path = self.build_dir / Path(file_path).relative_to(self.root_dir) + if isinstance(self.pipeline, Path) and self.pipeline in file_path.parents: + raise Exception( + f"Pipeline ({self.artifacts_name}) is already " + "included in the package's data, you should " + "remove it from the pyproject.toml metadata." + ) + os.makedirs(new_file_path.parent, exist_ok=True) + logger.info(f"COPY {file_path} TO {new_file_path}") + shutil.copy(file_path, new_file_path) + + self.update_pyproject() + + # Write pyproject.toml + (self.build_dir / "pyproject.toml").write_text(toml.dumps(self.pyproject)) + + if isinstance(self.pipeline, Path): + # self.pipeline = edsnlp.load(self.pipeline) + shutil.copytree( + self.pipeline, + build_artifacts_dir, + ) + else: + self.pipeline.to_disk(build_artifacts_dir) + os.makedirs(package_dir, exist_ok=True) + with open(package_dir / "__init__.py", mode="a") as f: + f.write( + INIT_PY.format( + __version__=repr(self.version), + artifacts_dir=os.path.relpath(build_artifacts_dir, package_dir), + ) + ) + + +@app.command(name="package") +def package( + pipeline: Union[Path, "edsnlp.Pipeline"], + name: Optional[ModuleName] = None, + root_dir: Path = ".", + artifacts_name: ModuleName = "artifacts", + check_dependencies: bool = False, + project_type: Optional[Literal["poetry", "setuptools"]] = None, + version: str = "0.1.0", + metadata: Optional[Dict[str, Any]] = {}, + distributions: Optional[Sequence[Literal["wheel", "sdist"]]] = ["wheel"], + config_settings: Optional[Mapping[str, Union[str, Sequence[str]]]] = None, + isolation: bool = True, + skip_build_dependency_check: bool = False, +): + # root_dir = Path(".").resolve() + pyproject_path = root_dir / "pyproject.toml" + + if not pyproject_path.exists(): + check_dependencies = True + if name is None: + raise ValueError( + f"No pyproject.toml could be found in the root directory {root_dir}, " + f"you need to create one, or fill the name parameter." + ) + + dependencies = None + if check_dependencies: + if isinstance(pipeline, Path): + pipeline = edsnlp.load(pipeline) + dependencies = get_deep_dependencies(pipeline) + for dep in dependencies: + print("DEPENDENCY", dep[0].ljust(30), dep[1]) + + root_dir = root_dir.resolve() + + pyproject = None + if pyproject_path.exists(): + pyproject = toml.loads((root_dir / "pyproject.toml").read_text()) + + if "tool" in pyproject and "poetry" in pyproject["tool"]: + project_type = "poetry" + + if project_type == "poetry": + packager = PoetryPackager( + pyproject=pyproject, + pipeline=pipeline, + name=name, + version=version, + root_dir=root_dir, + artifacts_name=artifacts_name, + dependencies=dependencies, + metadata=metadata, + ) + else: + raise Exception( + "Could not infer project type, only poetry based projects are " + "supported for now" + ) + + packager.make_src_dir() + packager.build( + distributions=distributions, + config_settings=config_settings, + isolation=isolation, + skip_dependency_check=skip_build_dependency_check, + ) diff --git a/tests/utils/test_package.py b/tests/utils/test_package.py new file mode 100644 index 000000000..058abd5f3 --- /dev/null +++ b/tests/utils/test_package.py @@ -0,0 +1,185 @@ +import importlib +import subprocess +import sys + +import pytest +import torch + +import edsnlp +from edsnlp.utils.package import package + + +def test_blank_package(nlp, tmp_path): + # Missing metadata makes poetry fail due to missing author / description + if not isinstance(nlp, edsnlp.Pipeline): + pytest.skip("Only running for edsnlp.Pipeline") + + with pytest.raises(Exception): + package( + pipeline=nlp, + root_dir=tmp_path, + name="test-model", + metadata={}, + project_type="poetry", + ) + + nlp.package( + root_dir=tmp_path, + name="test-model", + metadata={ + "description": "A test model", + "authors": "Test Author ", + }, + project_type="poetry", + distributions=["wheel"], + ) + assert (tmp_path / "dist").is_dir() + assert (tmp_path / "build").is_dir() + assert (tmp_path / "dist" / "test_model-0.1.0-py3-none-any.whl").is_file() + assert not (tmp_path / "dist" / "test_model-0.1.0.tar.gz").is_file() + assert (tmp_path / "build" / "test-model").is_dir() + + +@pytest.mark.parametrize("package_name", ["my-test-model", None]) +def test_package_with_files(nlp, tmp_path, package_name): + if not isinstance(nlp, edsnlp.Pipeline): + pytest.skip("Only running for edsnlp.Pipeline") + + nlp.to_disk(tmp_path / "model") + + ((tmp_path / "test_model").mkdir(parents=True)) + (tmp_path / "test_model" / "__init__.py").write_text( + """\ +print("Hello World!") +""" + ) + (tmp_path / "pyproject.toml").write_text( + """\ +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "test-model" +version = "0.0.0" +description = "A test model" +authors = ["Test Author "] + +[tool.poetry.dependencies] +python = "^3.7" +torch = "^{}" +""".format( + torch.__version__.split("+")[0] + ) + ) + + with pytest.raises(ValueError): + package( + pipeline=nlp, + root_dir=tmp_path, + version="0.1.0", + name=package_name, + metadata={ + "description": "Wrong description", + "authors": "Test Author ", + }, + ) + + package( + name=package_name, + pipeline=tmp_path / "model", + root_dir=tmp_path, + check_dependencies=True, + version="0.1.0", + distributions=None, + metadata={ + "description": "A test model", + "authors": "Test Author ", + }, + ) + + module_name = "test_model" if package_name is None else "my_test_model" + + assert (tmp_path / "dist").is_dir() + assert (tmp_path / "dist" / f"{module_name}-0.1.0.tar.gz").is_file() + assert (tmp_path / "dist" / f"{module_name}-0.1.0-py3-none-any.whl").is_file() + assert (tmp_path / "pyproject.toml").is_file() + + # pip install the whl file + print( + subprocess.check_output( + [ + sys.executable, + "-m", + "pip", + "install", + str(tmp_path / "dist" / f"{module_name}-0.1.0-py3-none-any.whl"), + "--force-reinstall", + ], + stderr=subprocess.STDOUT, + ) + ) + + module = importlib.import_module(module_name) + + assert module.__version__ == "0.1.0" + + with open(module.__file__) as f: + assert f.read() == ( + ( + """\ +print("Hello World!") +""" + if package_name is None + else "" + ) + + """ +# ----------------------------------------- +# This section was autogenerated by edsnlp +# ----------------------------------------- + +import edsnlp +from pathlib import Path + +__version__ = '0.1.0' + +def load(device: "torch.device" = "cpu") -> edsnlp.Pipeline: + artifacts_path = Path(__file__).parent / "artifacts" + model = edsnlp.load(artifacts_path).to(device) + return model +""" + ) + module.load() + + +@pytest.fixture(scope="session", autouse=True) +def clean_after(): + yield + + print( + subprocess.check_output( + [ + sys.executable, + "-m", + "pip", + "uninstall", + "-y", + "test-model", + ], + stderr=subprocess.STDOUT, + ) + ) + + print( + subprocess.check_output( + [ + sys.executable, + "-m", + "pip", + "uninstall", + "-y", + "my-test-model", + ], + stderr=subprocess.STDOUT, + ) + )