Skip to content

Commit

Permalink
update __str__
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jun 11, 2024
1 parent 81cd1af commit d3f38e9
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class
from jax.interpreters.partial_eval import DynamicJaxprTracer

from ._misc import get_dtype



__all__ = [
'Quantity',
'Unit',
Expand Down Expand Up @@ -981,25 +984,27 @@ def __init__(
value = jnp.array(value, dtype=dtype)
except ValueError:
raise TypeError("All elements must be convertible to a jax array")
dtype = dtype or get_dtype(value)

# array value
if isinstance(value, Quantity):
dtype = dtype or get_dtype(value)
self._dim = value.dim
self._value = jnp.array(value.value, dtype=dtype)
return

elif isinstance(value, (np.ndarray, jax.Array)):
dtype = dtype or get_dtype(value)
value = jnp.array(value, dtype=dtype)

elif isinstance(value, (jnp.number, numbers.Number)):
dtype = dtype or get_dtype(value)
value = jnp.array(value, dtype=dtype)

elif isinstance(value, (jax.core.ShapedArray, jax.ShapeDtypeStruct)):
value = value

else:
raise TypeError(f"Invalid type for value: {type(value)}")
value = value

# value
self._value = (value if scale is None else (value * scale))
Expand Down Expand Up @@ -1330,9 +1335,13 @@ def isnan(self) -> jax.Array:
# ----------------------- #

def __repr__(self) -> str:
if isinstance(self.value, (jax.ShapeDtypeStruct, jax.core.ShapedArray, DynamicJaxprTracer)):
return f'{self.value} * {Quantity(1, dim=self.dim)}'
return self.repr_in_best_unit(python_code=True)

def __str__(self) -> str:
if isinstance(self.value, (jax.ShapeDtypeStruct, jax.core.ShapedArray, DynamicJaxprTracer)):
return f'{self.value} * {Quantity(1, dim=self.dim)}'
return self.repr_in_best_unit()

def __format__(self, format_spec: str) -> str:
Expand Down

0 comments on commit d3f38e9

Please sign in to comment.