Skip to content

Commit

Permalink
Merge pull request #22 from quantumlib/new_constructor
Browse files Browse the repository at this point in the history
Optimize object creation
  • Loading branch information
NoureldinYosri authored Nov 20, 2024
2 parents 7a5c062 + 1d6b103 commit 6793bc8
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 25 deletions.
4 changes: 2 additions & 2 deletions test/cython/test_with_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,8 @@ def test_array() -> None:

u = np.array(val([val(2, units=m), val(3, units=m)]))
assert isinstance(u, np.ndarray)
assert isinstance(u[0], float)
assert np.array_equal([2, 3], u)
assert isinstance(u[0], Value)
assert np.array_equal(val(1, units=m) * [2, 3], u)

u = np.array(val([2, 3]))
assert isinstance(u, np.ndarray)
Expand Down
15 changes: 15 additions & 0 deletions test/test_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,18 @@ def test_division_type() -> None:
assert t == 1.5 * tunits.ns

assert isinstance(1 / t, tunits.Value)


def test_dimension_with_wrong_unit_raises() -> None:

with pytest.raises(ValueError, match='not a valid unit for dimension'):
_ = tunits.Time(1)

with pytest.raises(ValueError, match='not a valid unit for dimension'):
_ = tunits.TimeArray([1])

with pytest.raises(ValueError, match='not a valid unit for dimension'):
_ = tunits.Time(1, 'm')

with pytest.raises(ValueError, match='not a valid unit for dimension'):
_ = tunits.TimeArray([1], 'm')
4 changes: 2 additions & 2 deletions tunits/core/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ class Value(WithUnit):
@overload
def __truediv__(self, other: ValueArray) -> ValueArray: ...
@overload
def __rtruediv__(self: ValueType, other: int | float | complex) -> ValueType: ...
def __rtruediv__(self, other: int | float | complex) -> Value: ...
@overload
def __rtruediv__(self, other: list[Any] | tuple[Any] | NDArray[Any]) -> ValueArray: ...
@overload
Expand Down Expand Up @@ -329,7 +329,7 @@ class ValueArray(WithUnit):
def __truediv__(self: ArrayType, other: _NUMERICAL_TYPE_OR_ARRAY) -> ArrayType: ...
@overload
def __truediv__(self: ArrayType, other: WithUnit) -> ValueArray: ...
def __rtruediv__(self: ArrayType, other: _NUMERICAL_TYPE_OR_ARRAY) -> ArrayType: ...
def __rtruediv__(self: ArrayType, other: _NUMERICAL_TYPE_OR_ARRAY) -> ValueArray: ...
def __floordiv__(self, other: Any) -> NDArray[Any]: ...
@overload
def __mul__(self: ArrayType, other: _NUMERICAL_TYPE_OR_ARRAY) -> ArrayType: ...
Expand Down
15 changes: 11 additions & 4 deletions tunits/core/cython/dimension.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,17 @@ class _Acceleration(Dimension):
return AccelerationArray


class ValueWithDimension(Dimension, Value): ...


class ArrayWithDimension(Dimension, ValueArray): ...
class ValueWithDimension(Dimension, Value):
def __init__(self, val, unit=None, validate:bool=True):
super().__init__(val, unit=unit)
if validate and not type(self).is_valid(self):
raise ValueError(f'{self.unit} is not a valid unit for dimension {type(self)}')

class ArrayWithDimension(Dimension, ValueArray):
def __init__(self, val, unit=None, validate:bool=True):
super().__init__(val, unit=unit)
if validate and not type(self).is_valid(self):
raise ValueError(f'{self.unit} is not a valid unit for dimension {type(self)}')


class Acceleration(_Acceleration, ValueWithDimension): ...
Expand Down
10 changes: 8 additions & 2 deletions tunits/core/cython/with_unit.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,18 @@ cpdef raw_WithUnit(value,
val = value
target_type = value_class
elif isinstance(value, (list, tuple, np.ndarray)):
val = np.array(value)
val = np.asarray(value)
target_type = array_class
val, unit = _canonize_data_and_unit(val, None)
if unit is not None:
conv = conversion_times(conv, unit.conv)
display_units *= unit.display_units
base_units *= unit.base_units
else:
raise NotTUnitsLikeError("Unrecognized value type: {}".format(type(value)))

cdef WithUnit result = target_type(val)
cdef WithUnit result = target_type.__new__(target_type)
result.value = val
result.conv = conv
result.base_units = base_units
result.display_units = display_units
Expand Down
34 changes: 19 additions & 15 deletions tunits/core/cython/with_unit_value_array.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,24 @@ import numpy as np

T = TypeVar('ValueArray', bound='ValueArray')

def _canonize_data_and_unit(data, unit=None):
first_item = next(data.flat, None)
if isinstance(first_item, WithUnit):
shared_unit = first_item.unit
scalar = first_item[shared_unit]
inferred_dtype = np.array([scalar]).dtype

it = np.nditer([data, None],
op_dtypes=[data.dtype, inferred_dtype],
flags=['refs_ok'],
op_flags=[['readonly'], ['writeonly', 'allocate']])
for inp, out in it:
out[()] = inp[()][shared_unit]

data = it.operands[1]
unit = shared_unit if unit is None else unit * shared_unit
return data, unit

class ValueArray(WithUnit):

def __init__(WithUnit self, data, unit=None):
Expand All @@ -32,21 +50,7 @@ class ValueArray(WithUnit):

# If the items have units, we're supposed to extract a shared unit.
data = np.asarray(data)
first_item = next(data.flat, None)
if isinstance(first_item, WithUnit):
shared_unit = first_item.unit
scalar = first_item[shared_unit]
inferred_dtype = np.array([scalar]).dtype

it = np.nditer([data, None],
op_dtypes=[data.dtype, inferred_dtype],
flags=['refs_ok'],
op_flags=[['readonly'], ['writeonly', 'allocate']])
for inp, out in it:
out[()] = inp[()][shared_unit]

data = it.operands[1]
unit = shared_unit if unit is None else unit * shared_unit
data, unit = _canonize_data_and_unit(data, unit)

super().__init__(data, unit)

Expand Down

0 comments on commit 6793bc8

Please sign in to comment.