Skip to content

Commit

Permalink
Add script to generate the stub file (#44)
Browse files Browse the repository at this point in the history
* Add script to generate the stub file

* Write specs in Python instead of parsing comments

* Move generate-script into src/ directory

* Move definition of types to top of script for visibility

* Use a as argument name when n_in == 1 if there are optional args
  • Loading branch information
gahjelle authored Nov 28, 2024
1 parent da0d66d commit 3271bd0
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 6 deletions.
85 changes: 85 additions & 0 deletions src/generate_spherely_vfunc_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import itertools
import string
from pathlib import Path

VFUNC_TYPE_SPECS = {
"_VFunc_Nin1_Nout1": {"n_in": 1},
"_VFunc_Nin2_Nout1": {"n_in": 2},
"_VFunc_Nin2optradius_Nout1": {"n_in": 2, "radius": "float"},
"_VFunc_Nin1optradius_Nout1": {"n_in": 1, "radius": "float"},
}

STUB_FILE_PATH = Path(__file__).parent / "spherely.pyi"
BEGIN_MARKER = "# /// Begin types"
END_MARKER = "# /// End types"


def update_stub_file(path, **type_specs):
stub_text = path.read_text(encoding="utf-8")
try:
start_idx = stub_text.index(BEGIN_MARKER)
end_idx = stub_text.index(END_MARKER)
except ValueError:
raise SystemExit(
f"Error: Markers '{BEGIN_MARKER}' and '{END_MARKER}' "
f"were not found in stub file '{path}'"
) from None

header = f"{BEGIN_MARKER}\n"
code = "\n\n".join(
_vfunctype_factory(name, **args) for name, args in type_specs.items()
)
updated_stub_text = stub_text[:start_idx] + header + code + stub_text[end_idx:]
path.write_text(updated_stub_text, encoding="utf-8")


def _vfunctype_factory(class_name, n_in, **optargs):
"""Create new VFunc types.
Based on the number of input arrays and optional arguments and their types."""
arg_names = (
["geography"]
if n_in == 1 and not optargs
else list(string.ascii_lowercase[:n_in])
)
class_code = [
f"class {class_name}(",
" Generic[_NameType, _ScalarReturnType, _ArrayReturnDType]",
"):",
" @property",
" def __name__(self) -> _NameType: ...",
"",
]
optarg_str = ", ".join(
f"{arg_name}: {arg_type} = ..." for arg_name, arg_type in optargs.items()
)

geog_types = ["Geography", "npt.ArrayLike"]
for arg_types in itertools.product(geog_types, repeat=n_in):
arg_str = ", ".join(
f"{arg_name}: {arg_type}"
for arg_name, arg_type in zip(arg_names, arg_types)
)
return_type = (
"_ScalarReturnType"
if all(t == geog_types[0] for t in arg_types)
else "npt.NDArray[_ArrayReturnDType]"
)
class_code.extend(
[
" @overload",
" def __call__(",
(
f" self, {arg_str}, {optarg_str}"
if optarg_str
else f" self, {arg_str}"
),
f" ) -> {return_type}: ...",
"",
]
)
return "\n".join(class_code)


if __name__ == "__main__":
update_stub_file(path=STUB_FILE_PATH, **VFUNC_TYPE_SPECS)
19 changes: 13 additions & 6 deletions src/spherely.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ _NameType = TypeVar("_NameType", bound=str)
_ScalarReturnType = TypeVar("_ScalarReturnType", bound=Any)
_ArrayReturnDType = TypeVar("_ArrayReturnDType", bound=Any)

# The following types are auto-generated. Please don't edit them by hand.
# Instead, update the generate_spherely_vfunc_types.py script and run it
# to update the types.
#
# /// Begin types
class _VFunc_Nin1_Nout1(Generic[_NameType, _ScalarReturnType, _ArrayReturnDType]):
@property
def __name__(self) -> _NameType: ...
Expand All @@ -91,15 +96,15 @@ class _VFunc_Nin2_Nout1(Generic[_NameType, _ScalarReturnType, _ArrayReturnDType]
def __call__(self, a: Geography, b: Geography) -> _ScalarReturnType: ...
@overload
def __call__(
self, a: npt.ArrayLike, b: npt.ArrayLike
self, a: Geography, b: npt.ArrayLike
) -> npt.NDArray[_ArrayReturnDType]: ...
@overload
def __call__(
self, a: Geography, b: npt.ArrayLike
self, a: npt.ArrayLike, b: Geography
) -> npt.NDArray[_ArrayReturnDType]: ...
@overload
def __call__(
self, a: npt.ArrayLike, b: Geography
self, a: npt.ArrayLike, b: npt.ArrayLike
) -> npt.NDArray[_ArrayReturnDType]: ...

class _VFunc_Nin2optradius_Nout1(
Expand All @@ -113,15 +118,15 @@ class _VFunc_Nin2optradius_Nout1(
) -> _ScalarReturnType: ...
@overload
def __call__(
self, a: npt.ArrayLike, b: npt.ArrayLike, radius: float = ...
self, a: Geography, b: npt.ArrayLike, radius: float = ...
) -> npt.NDArray[_ArrayReturnDType]: ...
@overload
def __call__(
self, a: Geography, b: npt.ArrayLike, radius: float = ...
self, a: npt.ArrayLike, b: Geography, radius: float = ...
) -> npt.NDArray[_ArrayReturnDType]: ...
@overload
def __call__(
self, a: npt.ArrayLike, b: Geography, radius: float = ...
self, a: npt.ArrayLike, b: npt.ArrayLike, radius: float = ...
) -> npt.NDArray[_ArrayReturnDType]: ...

class _VFunc_Nin1optradius_Nout1(
Expand All @@ -136,6 +141,8 @@ class _VFunc_Nin1optradius_Nout1(
self, a: npt.ArrayLike, radius: float = ...
) -> npt.NDArray[_ArrayReturnDType]: ...

# /// End types

# Geography properties

get_dimensions: _VFunc_Nin1_Nout1[Literal["get_dimensions"], Geography, Any]
Expand Down

0 comments on commit 3271bd0

Please sign in to comment.