Skip to content

Commit

Permalink
fix: more options for nlp.package() and better support for poetry inc…
Browse files Browse the repository at this point in the history
…lude fields
  • Loading branch information
percevalw committed Dec 1, 2023
1 parent 3c2e910 commit 61b6839
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 21 deletions.
6 changes: 5 additions & 1 deletion edsnlp/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,10 +885,12 @@ def package(
self,
name: Optional[str] = None,
root_dir: Union[str, Path] = ".",
build_dir: Union[str, Path] = "build",
dist_dir: Union[str, Path] = "dist",
artifacts_name: str = "artifacts",
check_dependencies: bool = False,
project_type: Optional[Literal["poetry", "setuptools"]] = None,
version: str = "0.1.0",
version: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = {},
distributions: Optional[Sequence[Literal["wheel", "sdist"]]] = ["wheel"],
config_settings: Optional[Mapping[str, Union[str, Sequence[str]]]] = None,
Expand All @@ -901,6 +903,8 @@ def package(
pipeline=self,
name=name,
root_dir=root_dir,
build_dir=build_dir,
dist_dir=dist_dir,
artifacts_name=artifacts_name,
check_dependencies=check_dependencies,
project_type=project_type,
Expand Down
4 changes: 3 additions & 1 deletion edsnlp/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,11 @@ class BatchSizeArg:
Examples
--------
```python
```{ .python .no-check }
def fn(batch_size: BatchSizeArg):
return batch_size
print(fn("10 samples"))
# Out: (10, "samples")
Expand All @@ -97,6 +98,7 @@ def fn(batch_size: BatchSizeArg):
print(fn(10))
# Out: (10, "samples")
```
"""

@classmethod
Expand Down
47 changes: 28 additions & 19 deletions edsnlp/utils/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,10 @@ def validate(cls, value, config=None):
print([
{k: v for k, v in {
"include": include._include,
"from": include.source,
"formats": include.formats,
}.items()}
"include": getattr(include, '_include'),
"from": getattr(include, 'source', None),
"formats": getattr(include, 'formats', None),
}.items() if v}
for include in builder._module.includes
])
Expand Down Expand Up @@ -209,11 +209,11 @@ def __init__(
self,
pyproject: Optional[Dict[str, Any]],
pipeline: Union[Path, "edsnlp.Pipeline"],
version: str,
version: Optional[str],
name: Optional[ModuleName],
root_dir: Path = ".",
build_name: Path = "build",
out_dir: Path = "dist",
build_dir: Path = "build",
dist_dir: Path = "dist",
artifacts_name: ModuleName = "artifacts",
dependencies: Optional[Sequence[Tuple[str, str]]] = None,
metadata: Optional[Dict[str, Any]] = {},
Expand All @@ -230,7 +230,9 @@ def __init__(
self.dependencies = dependencies
self.pipeline = pipeline
self.artifacts_name = artifacts_name
self.out_dir = self.root_dir / out_dir
self.dist_dir = (
dist_dir if Path(dist_dir).is_absolute() else self.root_dir / dist_dir
)

with self.ensure_pyproject(metadata):
python_executable = (
Expand All @@ -250,7 +252,9 @@ def __init__(
out = result.stdout.decode().strip().split("\n")

self.poetry_packages = eval(out[0])
self.build_dir = root_dir / build_name / self.name
self.build_dir = (
build_dir if Path(build_dir).is_absolute() else root_dir / build_dir
) / self.name
self.file_paths = [self.root_dir / file_path for file_path in out[1:]]

logger.info(f"root_dir: {self.root_dir}")
Expand All @@ -276,7 +280,7 @@ def ensure_pyproject(self, metadata):
"poetry": {
**metadata,
"name": self.name,
"version": self.version,
"version": self.version or "0.1.0",
"dependencies": {
"python": f">={py_version},<4.0",
**{
Expand Down Expand Up @@ -333,7 +337,7 @@ def build(
distributions = ["wheel"]
build_call(
srcdir=self.build_dir,
outdir=self.out_dir,
outdir=self.dist_dir,
distributions=distributions,
config_settings=config_settings,
isolation=isolation,
Expand All @@ -349,12 +353,13 @@ def update_pyproject(self):
f"project"
)

old_version = self.pyproject["tool"]["poetry"]["version"]
self.pyproject["tool"]["poetry"]["version"] = self.version
logger.info(
f"Replaced project version {old_version!r} with {self.version!r} in poetry "
f"based project"
)
if self.version is not None:
old_version = self.pyproject["tool"]["poetry"]["version"]
self.pyproject["tool"]["poetry"]["version"] = self.version
logger.info(
f"Replaced project version {old_version!r} with {self.version!r} in "
f"poetry based project"
)

# Adding artifacts to include in pyproject.toml
snake_name = snake_case(self.name.lower())
Expand All @@ -380,7 +385,7 @@ def make_src_dir(self):
"remove it from the pyproject.toml metadata."
)
os.makedirs(new_file_path.parent, exist_ok=True)
logger.info(f"COPY {file_path} TO {new_file_path}")
logger.info(f"COPY {file_path}" f"TO {new_file_path}")
shutil.copy(file_path, new_file_path)

self.update_pyproject()
Expand Down Expand Up @@ -411,10 +416,12 @@ def package(
pipeline: Union[Path, "edsnlp.Pipeline"],
name: Optional[ModuleName] = None,
root_dir: Path = ".",
build_dir: Path = "build",
dist_dir: Path = "dist",
artifacts_name: ModuleName = "artifacts",
check_dependencies: bool = False,
project_type: Optional[Literal["poetry", "setuptools"]] = None,
version: str = "0.1.0",
version: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = {},
distributions: Optional[Sequence[Literal["wheel", "sdist"]]] = ["wheel"],
config_settings: Optional[Mapping[str, Union[str, Sequence[str]]]] = None,
Expand Down Expand Up @@ -456,6 +463,8 @@ def package(
name=name,
version=version,
root_dir=root_dir,
build_dir=build_dir,
dist_dir=dist_dir,
artifacts_name=artifacts_name,
dependencies=dependencies,
metadata=metadata,
Expand Down

0 comments on commit 61b6839

Please sign in to comment.