Skip to content

Commit

Permalink
Use npt.NDArray
Browse files Browse the repository at this point in the history
  • Loading branch information
mhostetter committed Sep 30, 2023
1 parent 1499ed9 commit aa06dcb
Show file tree
Hide file tree
Showing 13 changed files with 78 additions and 65 deletions.
7 changes: 4 additions & 3 deletions src/galois/_codes/_bch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any, Type, overload

import numpy as np
import numpy.typing as npt
from typing_extensions import Literal

from .._fields import GF2, Field, FieldArray
Expand Down Expand Up @@ -469,7 +470,7 @@ def encode(self, message: ArrayLike, output: Literal["codeword", "parity"] = "co
bch.detect(c)
""",
)
def detect(self, codeword: ArrayLike) -> bool | np.ndarray:
def detect(self, codeword: ArrayLike) -> bool | npt.NDArray:
# pylint: disable=useless-super-delegation
return super().detect(codeword)

Expand All @@ -488,7 +489,7 @@ def decode(
codeword: ArrayLike,
output: Literal["message", "codeword"] = "message",
errors: Literal[True] = True,
) -> tuple[FieldArray, int | np.ndarray]:
) -> tuple[FieldArray, int | npt.NDArray]:
...

@extend_docstring(
Expand Down Expand Up @@ -652,7 +653,7 @@ def decode(
def decode(self, codeword: Any, output: Any = "message", errors: Any = False) -> Any:
return super().decode(codeword, output=output, errors=errors)

def _decode_codeword(self, codeword: FieldArray) -> tuple[FieldArray, np.ndarray]:
def _decode_codeword(self, codeword: FieldArray) -> tuple[FieldArray, npt.NDArray]:
func = berlekamp_decode_jit(self.field, self.extension_field)
dec_codeword, N_errors = func(codeword, self.n, int(self.alpha), self.c, self.roots)
dec_codeword = dec_codeword.view(self.field)
Expand Down
15 changes: 8 additions & 7 deletions src/galois/_codes/_bm_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numba
import numpy as np
import numpy.typing as npt
from numba import int64

from .._domains._function import Function
Expand All @@ -19,10 +20,10 @@
MULTIPLY: Callable[[int, int], int]
RECIPROCAL: Callable[[int], int]
POWER: Callable[[int, int], int]
CONVOLVE: Callable[[np.ndarray, np.ndarray], np.ndarray]
POLY_ROOTS: Callable[[np.ndarray, np.ndarray, int], np.ndarray]
POLY_EVALUATE: Callable[[np.ndarray, np.ndarray], np.ndarray]
BERLEKAMP_MASSEY: Callable[[np.ndarray], np.ndarray]
CONVOLVE: Callable[[npt.NDArray, npt.NDArray], npt.NDArray]
POLY_ROOTS: Callable[[npt.NDArray, npt.NDArray, int], npt.NDArray]
POLY_EVALUATE: Callable[[npt.NDArray, npt.NDArray], npt.NDArray]
BERLEKAMP_MASSEY: Callable[[npt.NDArray], npt.NDArray]


class berlekamp_decode_jit(Function):
Expand Down Expand Up @@ -51,7 +52,7 @@ def key_1(self) -> Hashable:

def __call__(
self, codeword: FieldArray, design_n: int, alpha: int, c: int, roots: FieldArray
) -> tuple[FieldArray, np.ndarray]:
) -> tuple[FieldArray, npt.NDArray]:
if self.extension_field.ufunc_mode != "python-calculate":
output = self.jit(codeword.astype(np.int64), design_n, alpha, c, roots.astype(np.int64))
else:
Expand Down Expand Up @@ -82,8 +83,8 @@ def set_globals(self) -> None:

@staticmethod
def implementation(
codewords: np.ndarray, design_n: int, alpha: int, c: int, roots: np.ndarray
) -> np.ndarray: # pragma: no cover
codewords: npt.NDArray, design_n: int, alpha: int, c: int, roots: npt.NDArray
) -> npt.NDArray: # pragma: no cover
dtype = codewords.dtype
N = codewords.shape[0] # The number of codewords
n = codewords.shape[1] # The codeword size (could be less than the design n for shortened codes)
Expand Down
3 changes: 2 additions & 1 deletion src/galois/_codes/_cyclic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any, overload

import numpy as np
import numpy.typing as npt
from typing_extensions import Literal

from .._fields import FieldArray
Expand Down Expand Up @@ -99,7 +100,7 @@ def decode(
codeword: ArrayLike,
output: Literal["message", "codeword"] = "message",
errors: Literal[True] = True,
) -> tuple[FieldArray, int | np.ndarray]:
) -> tuple[FieldArray, int | npt.NDArray]:
...

@extend_docstring(
Expand Down
9 changes: 5 additions & 4 deletions src/galois/_codes/_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any, Type, cast, overload

import numpy as np
import numpy.typing as npt
from typing_extensions import Literal

from .._fields import FieldArray
Expand Down Expand Up @@ -90,7 +91,7 @@ def encode(self, message: ArrayLike, output: Literal["codeword", "parity"] = "co
parity = self._convert_codeword_to_parity(codeword)
return parity

def detect(self, codeword: ArrayLike) -> bool | np.ndarray:
def detect(self, codeword: ArrayLike) -> bool | npt.NDArray:
r"""
Detects if errors are present in the codeword $\mathbf{c}$.
Expand Down Expand Up @@ -130,7 +131,7 @@ def decode(
codeword: ArrayLike,
output: Literal["message", "codeword"] = "message",
errors: Literal[True] = True,
) -> tuple[FieldArray, int | np.ndarray]:
) -> tuple[FieldArray, int | npt.NDArray]:
...

def decode(self, codeword: Any, output: Any = "message", errors: Any = False) -> Any:
Expand Down Expand Up @@ -276,7 +277,7 @@ def _encode_message(self, message: FieldArray) -> FieldArray:

return codeword

def _detect_errors(self, codeword: FieldArray) -> np.ndarray:
def _detect_errors(self, codeword: FieldArray) -> npt.NDArray:
"""
Returns a boolean array (N,) indicating if errors are present in the codeword.
"""
Expand All @@ -290,7 +291,7 @@ def _detect_errors(self, codeword: FieldArray) -> np.ndarray:

return detected

def _decode_codeword(self, codeword: FieldArray) -> tuple[FieldArray, np.ndarray]:
def _decode_codeword(self, codeword: FieldArray) -> tuple[FieldArray, npt.NDArray]:
"""
Decodes errors in the received codeword. Returns the corrected codeword (N, ns) and array of number of
corrected errors (N,).
Expand Down
7 changes: 4 additions & 3 deletions src/galois/_codes/_reed_solomon.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any, Type, cast, overload

import numpy as np
import numpy.typing as npt
from typing_extensions import Literal

from .._fields import Field, FieldArray
Expand Down Expand Up @@ -429,7 +430,7 @@ def encode(self, message: ArrayLike, output: Literal["codeword", "parity"] = "co
rs.detect(c)
""",
)
def detect(self, codeword: ArrayLike) -> bool | np.ndarray:
def detect(self, codeword: ArrayLike) -> bool | npt.NDArray:
# pylint: disable=useless-super-delegation
return super().detect(codeword)

Expand All @@ -448,7 +449,7 @@ def decode(
codeword: ArrayLike,
output: Literal["message", "codeword"] = "message",
errors: Literal[True] = True,
) -> tuple[FieldArray, int | np.ndarray]:
) -> tuple[FieldArray, int | npt.NDArray]:
...

@extend_docstring(
Expand Down Expand Up @@ -608,7 +609,7 @@ def decode(
def decode(self, codeword: Any, output: Any = "message", errors: Any = False) -> Any:
return super().decode(codeword, output=output, errors=errors)

def _decode_codeword(self, codeword: FieldArray) -> tuple[FieldArray, np.ndarray]:
def _decode_codeword(self, codeword: FieldArray) -> tuple[FieldArray, npt.NDArray]:
func = berlekamp_decode_jit(self.field, self.field)
dec_codeword, N_errors = func(codeword, self.n, int(self.alpha), self.c, self.roots)
dec_codeword = dec_codeword.view(self.field)
Expand Down
9 changes: 5 additions & 4 deletions src/galois/_domains/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any, Generator, cast, no_type_check

import numpy as np
import numpy.typing as npt
from typing_extensions import Literal, Self

from .._helper import export, verify_isinstance, verify_literal
Expand Down Expand Up @@ -84,7 +85,7 @@ def _verify_array_like_types_and_values(cls, x: ElementLike | ArrayLike) -> Elem

@classmethod
@abc.abstractmethod
def _verify_element_types_and_convert(cls, array: np.ndarray, object_=False) -> np.ndarray:
def _verify_element_types_and_convert(cls, array: npt.NDArray, object_=False) -> npt.NDArray:
"""
Iterate across each element and verify it's a valid type. Also, convert strings to integers along the way.
"""
Expand All @@ -98,7 +99,7 @@ def _verify_scalar_value(cls, scalar: int):

@classmethod
@abc.abstractmethod
def _verify_array_values(cls, array: np.ndarray):
def _verify_array_values(cls, array: npt.NDArray):
"""
Verify all the elements of the integer array are within the valid range [0, order).
"""
Expand All @@ -116,7 +117,7 @@ def _convert_to_element(cls, element: ElementLike) -> int:

@classmethod
@abc.abstractmethod
def _convert_iterable_to_elements(cls, iterable: IterableLike) -> np.ndarray:
def _convert_iterable_to_elements(cls, iterable: IterableLike) -> npt.NDArray:
"""
Convert an iterable (recursive) to a NumPy integer array. Convert any strings to integers along the way.
"""
Expand All @@ -126,7 +127,7 @@ def _convert_iterable_to_elements(cls, iterable: IterableLike) -> np.ndarray:
###############################################################################

@classmethod
def _view(cls, array: np.ndarray) -> Self:
def _view(cls, array: npt.NDArray) -> Self:
"""
View the input array to the Array subclass `A` using the `_view_without_verification()` context manager.
This disables bounds checking on the array elements. Instead of `x.view(A)` use `A._view(x)`.
Expand Down
21 changes: 11 additions & 10 deletions src/galois/_domains/_calculate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numba
import numpy as np
import numpy.typing as npt

from .._prime import factors
from . import _lookup
Expand All @@ -23,10 +24,10 @@
ORDER: int
IRREDUCIBLE_POLY: int

INT_TO_VECTOR: Callable[[int, int, int], np.ndarray]
VECTOR_TO_INT: Callable[[np.ndarray, int, int], int]
EGCD: Callable[[int, int], np.ndarray]
CRT: Callable[[np.ndarray, np.ndarray], int]
INT_TO_VECTOR: Callable[[int, int, int], npt.NDArray]
VECTOR_TO_INT: Callable[[npt.NDArray, int, int], int]
EGCD: Callable[[int, int], npt.NDArray]
CRT: Callable[[npt.NDArray, npt.NDArray], int]
DTYPE = np.int64

MULTIPLY: Callable[[int, int], int]
Expand All @@ -36,12 +37,12 @@
POSITIVE_POWER: Callable[[int, int], int]
BRUTE_FORCE_LOG: Callable[[int, int], int]

FACTORS: np.ndarray
MULTIPLICITIES: np.ndarray
FACTORS: npt.NDArray
MULTIPLICITIES: npt.NDArray


@numba.jit(["int64[:](int64, int64, int64)"], nopython=True, cache=True)
def int_to_vector(a: int, characteristic: int, degree: int) -> np.ndarray:
def int_to_vector(a: int, characteristic: int, degree: int) -> npt.NDArray:
"""
Converts the integer representation to vector/polynomial representation.
"""
Expand All @@ -55,7 +56,7 @@ def int_to_vector(a: int, characteristic: int, degree: int) -> np.ndarray:


@numba.jit(["int64(int64[:], int64, int64)"], nopython=True, cache=True)
def vector_to_int(a_vec: np.ndarray, characteristic: int, degree: int) -> int:
def vector_to_int(a_vec: npt.NDArray, characteristic: int, degree: int) -> int:
"""
Converts the vector/polynomial representation to the integer representation.
"""
Expand All @@ -69,7 +70,7 @@ def vector_to_int(a_vec: np.ndarray, characteristic: int, degree: int) -> int:


@numba.jit(["int64[:](int64, int64)"], nopython=True, cache=True)
def egcd(a: int, b: int) -> np.ndarray: # pragma: no cover
def egcd(a: int, b: int) -> npt.NDArray: # pragma: no cover
"""
Computes the Extended Euclidean Algorithm. Returns (d, s, t).
Expand All @@ -96,7 +97,7 @@ def egcd(a: int, b: int) -> np.ndarray: # pragma: no cover


@numba.jit(["int64(int64[:], int64[:])"], nopython=True, cache=True)
def crt(remainders: np.ndarray, moduli: np.ndarray) -> int: # pragma: no cover
def crt(remainders: npt.NDArray, moduli: npt.NDArray) -> int: # pragma: no cover
"""
Computes the simultaneous solution to the system of congruences xi == ai (mod mi).
"""
Expand Down
7 changes: 4 additions & 3 deletions src/galois/_domains/_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import TYPE_CHECKING

import numpy as np
import numpy.typing as npt

from . import _ufunc

Expand All @@ -16,9 +17,9 @@

ORDER: int

LOG: np.ndarray
EXP: np.ndarray
ZECH_LOG: np.ndarray
LOG: npt.NDArray
EXP: npt.NDArray
ZECH_LOG: npt.NDArray
ZECH_E: int


Expand Down
17 changes: 9 additions & 8 deletions src/galois/_fields/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Generator, cast

import numpy as np
import numpy.typing as npt
from typing_extensions import Literal, Self

from .._domains import Array, _linalg
Expand Down Expand Up @@ -156,7 +157,7 @@ def _verify_array_like_types_and_values(cls, x: ElementLike | ArrayLike) -> Elem
return x

@classmethod
def _verify_element_types_and_convert(cls, array: np.ndarray, object_=False) -> np.ndarray:
def _verify_element_types_and_convert(cls, array: npt.NDArray, object_=False) -> npt.NDArray:
if array.size == 0:
return array
if object_:
Expand All @@ -169,7 +170,7 @@ def _verify_scalar_value(cls, scalar: int):
raise ValueError(f"{cls.name} scalars must be in `0 <= x < {cls.order}`, not {scalar}.")

@classmethod
def _verify_array_values(cls, array: np.ndarray):
def _verify_array_values(cls, array: npt.NDArray):
if np.any(array < 0) or np.any(array >= cls.order):
idxs = np.logical_or(array < 0, array >= cls.order)
values = array if array.ndim == 0 else array[idxs]
Expand All @@ -193,7 +194,7 @@ def _convert_to_element(cls, element: ElementLike) -> int:
return element

@classmethod
def _convert_iterable_to_elements(cls, iterable: IterableLike) -> np.ndarray:
def _convert_iterable_to_elements(cls, iterable: IterableLike) -> npt.NDArray:
if cls.dtypes == [np.object_]:
array = np.array(iterable, dtype=object)
array = cls._verify_element_types_and_convert(array, object_=True)
Expand Down Expand Up @@ -961,7 +962,7 @@ def primitive_roots_of_unity(cls, n: int) -> Self:
# Instance methods
###############################################################################

def additive_order(self) -> int | np.ndarray:
def additive_order(self) -> int | npt.NDArray:
r"""
Computes the additive order of each element in $x$.
Expand Down Expand Up @@ -995,7 +996,7 @@ def additive_order(self) -> int | np.ndarray:

return order

def multiplicative_order(self) -> int | np.ndarray:
def multiplicative_order(self) -> int | npt.NDArray:
r"""
Computes the multiplicative order $\textrm{ord}(x)$ of each element in $x$.
Expand Down Expand Up @@ -1057,7 +1058,7 @@ def multiplicative_order(self) -> int | np.ndarray:

return order

def is_square(self) -> bool | np.ndarray:
def is_square(self) -> bool | npt.NDArray:
r"""
Determines if the elements of $x$ are squares in the finite field.
Expand Down Expand Up @@ -1656,7 +1657,7 @@ def minimal_poly(self) -> Poly:
f"or 2-D to return the minimal polynomial of a square matrix, not have shape {self.shape}."
)

def log(self, base: ElementLike | ArrayLike | None = None) -> int | np.ndarray:
def log(self, base: ElementLike | ArrayLike | None = None) -> int | npt.NDArray:
r"""
Computes the discrete logarithm of the array $x$ base $\beta$.
Expand Down Expand Up @@ -1881,7 +1882,7 @@ def _print_power(cls, element: Self) -> str:
return s


def _poly_det(A: np.ndarray) -> Poly:
def _poly_det(A: npt.NDArray) -> Poly:
"""
Computes the determinant of a matrix of `Poly` objects.
"""
Expand Down
Loading

0 comments on commit aa06dcb

Please sign in to comment.