From d3f38e9e6950f6196afbb45a707b7d947ab9bf05 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Tue, 11 Jun 2024 18:11:00 +0800 Subject: [PATCH] update __str__ --- brainunit/_base.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/brainunit/_base.py b/brainunit/_base.py index 0efd32a..cd48720 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -27,9 +27,12 @@ import jax.numpy as jnp import numpy as np from jax.tree_util import register_pytree_node_class +from jax.interpreters.partial_eval import DynamicJaxprTracer from ._misc import get_dtype + + __all__ = [ 'Quantity', 'Unit', @@ -981,25 +984,27 @@ def __init__( value = jnp.array(value, dtype=dtype) except ValueError: raise TypeError("All elements must be convertible to a jax array") - dtype = dtype or get_dtype(value) # array value if isinstance(value, Quantity): + dtype = dtype or get_dtype(value) self._dim = value.dim self._value = jnp.array(value.value, dtype=dtype) return elif isinstance(value, (np.ndarray, jax.Array)): + dtype = dtype or get_dtype(value) value = jnp.array(value, dtype=dtype) elif isinstance(value, (jnp.number, numbers.Number)): + dtype = dtype or get_dtype(value) value = jnp.array(value, dtype=dtype) elif isinstance(value, (jax.core.ShapedArray, jax.ShapeDtypeStruct)): value = value else: - raise TypeError(f"Invalid type for value: {type(value)}") + value = value # value self._value = (value if scale is None else (value * scale)) @@ -1330,9 +1335,13 @@ def isnan(self) -> jax.Array: # ----------------------- # def __repr__(self) -> str: + if isinstance(self.value, (jax.ShapeDtypeStruct, jax.core.ShapedArray, DynamicJaxprTracer)): + return f'{self.value} * {Quantity(1, dim=self.dim)}' return self.repr_in_best_unit(python_code=True) def __str__(self) -> str: + if isinstance(self.value, (jax.ShapeDtypeStruct, jax.core.ShapedArray, DynamicJaxprTracer)): + return f'{self.value} * {Quantity(1, dim=self.dim)}' return self.repr_in_best_unit() def __format__(self, format_spec: str) -> str: