From 4d70f4079f7496d802e9cba914bf4a9923edf4e8 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Thu, 25 Apr 2024 11:59:47 +0200 Subject: [PATCH] ENH: add `Path` to allowed `load()` arguments --- docs/conf.py | 1 + src/qrules/io/__init__.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 2ef91d0c..1935d1fb 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -73,6 +73,7 @@ def create_constraints_inventory() -> None: "NodeQuantumNumber": ("obj", "qrules.quantum_numbers.NodeQuantumNumber"), "NodeType": "typing.TypeVar", "ParticleWithSpin": ("obj", "qrules.particle.ParticleWithSpin"), + "Path": "pathlib.Path", "qrules.topology.EdgeType": "typing.TypeVar", "qrules.topology.NodeType": "typing.TypeVar", "SpinFormalism": ("obj", "qrules.transition.SpinFormalism"), diff --git a/src/qrules/io/__init__.py b/src/qrules/io/__init__.py index 915c45cb..27efe3a6 100644 --- a/src/qrules/io/__init__.py +++ b/src/qrules/io/__init__.py @@ -119,7 +119,7 @@ def asdot( return print_dot(instance) -def load(filename: str) -> object: +def load(filename: str | Path) -> object: with open(filename) as stream: file_extension = _get_file_extension(filename) if file_extension == "json": @@ -170,7 +170,7 @@ def write(instance: object, filename: str) -> None: raise NotImplementedError(msg) -def _get_file_extension(filename: str) -> str: +def _get_file_extension(filename: str | Path) -> str: path = Path(filename) extension = path.suffix.lower() if not extension: