Skip to content

Commit

Permalink
update assign_units and add kcal_per_h (#89)
Browse files Browse the repository at this point in the history
* updates

* fix test
  • Loading branch information
chaoming0625 authored Dec 28, 2024
1 parent 9063787 commit 1753e86
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 62 deletions.
194 changes: 137 additions & 57 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from contextlib import contextmanager
from copy import deepcopy
from functools import wraps, partial
from typing import Union, Optional, Sequence, Callable, Tuple, Any, List, Dict
from typing import Union, Optional, Sequence, Callable, Tuple, Any, List, Dict, cast

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -4470,7 +4470,8 @@ def new_f(*args, **kwds):
f"'{value}'"
)
raise DimensionMismatchError(
error_message, get_dim(newkeyset[k])
error_message,
get_dim(newkeyset[k])
)

result = f(*args, **kwds)
Expand Down Expand Up @@ -4782,78 +4783,157 @@ def new_f(*args, **kwds):
return do_check_units


class CallableAssignUnit(Callable):
without_result_units = Callable

def __call__(self, *args, **kwargs):
pass


class Missing():
pass


missing = Missing()


@set_module_as('brainunit')
def assign_units(**au):
def assign_units(f: Callable = missing, **au) -> CallableAssignUnit | Callable[[Callable], CallableAssignUnit]:
"""
Decorator to transform units of arguments passed to a function
"""
if f is missing:
return partial(assign_units, **au)

@wraps(f)
def new_f(*args, **kwds):
arg_names = f.__code__.co_varnames[0: f.__code__.co_argcount]
newkeyset = kwds.copy()
for n, v in zip(arg_names, args[0: f.__code__.co_argcount]):
newkeyset[n] = v
for n, v in tuple(newkeyset.items()):
if n in au and v is not None:
specific_unit = au[n]

def do_assign_units(f):
@wraps(f)
def new_f(*args, **kwds):
newkeyset = kwds.copy()
arg_names = f.__code__.co_varnames[0: f.__code__.co_argcount]
for n, v in zip(arg_names, args[0: f.__code__.co_argcount]):
if n in au and v is not None:
specific_unit = au[n]
# if the specific unit is a boolean, just check and return
if specific_unit == bool:
if isinstance(v, bool):
newkeyset[n] = v
else:
raise TypeError(
f"Function '{f.__name__}' expected a boolean value for argument '{n}' but got '{v}'")
if (
jax.tree.structure(specific_unit, is_leaf=_is_quantity)
!=
jax.tree.structure(v, is_leaf=_is_quantity)
):
raise TypeError(
f"For argument '{n}', we expect the input type "
f"with the structure like {specific_unit}, "
f"but we got {v}"
)

elif specific_unit == 1:
if isinstance(v, Quantity):
newkeyset[n] = v.to_decimal()
elif isinstance(v, (jax.Array, np.ndarray, int, float, complex)):
newkeyset[n] = v
else:
specific_unit = jax.typing.ArrayLike
raise TypeError(f"Function '{f.__name__}' expected a unitless Quantity object"
f"or {specific_unit} for argument '{n}' but got '{v}'")

elif isinstance(specific_unit, Unit):
if isinstance(v, Quantity):
v = v.to_decimal(specific_unit)
newkeyset[n] = v
else:
raise TypeError(
f"Function '{f.__name__}' expected a Quantity object for argument '{n}' but got '{v}'"
)
else:
raise TypeError(
f"Function '{f.__name__}' expected a target unit object or"
f" a Number, boolean object for checking, but got '{specific_unit}'"
)
else:
newkeyset[n] = v
v = jax.tree.map(
partial(_remove_unit, f.__name__, n),
specific_unit,
v,
is_leaf=_is_quantity
)
newkeyset[n] = v

result = f(**newkeyset)
if "result" in au:
if isinstance(au["result"], Callable) and au["result"] != bool:
expected_result = au["result"](*[get_unit(a) for a in args])
else:
expected_result = au["result"]
result = f(**newkeyset)
if "result" in au:
if isinstance(au["result"], Callable) and au["result"] != bool:
expected_result = au["result"](*[get_unit(a) for a in args])
else:
expected_result = au["result"]

expected_pytree = jax.tree.structure(
expected_result,
is_leaf=lambda x: isinstance(x, Quantity) or x is None
)
result_pytree = jax.tree.structure(result, is_leaf=lambda x: isinstance(x, Quantity) or x is None)
if (
expected_pytree
!=
result_pytree
):
raise TypeError(
f"Expected a return value of pytree {expected_pytree} with type {expected_result}, "
f"but got the pytree {result_pytree} and the value {result}"
)

result = jax.tree.map(
partial(_assign_unit, f),
result,
expected_result,
is_leaf=lambda x: isinstance(x, Quantity) or x is None
)
return result

def without_result_units(*args, **kwds):
arg_names = f.__code__.co_varnames[0: f.__code__.co_argcount]
newkeyset = kwds.copy()
for n, v in zip(arg_names, args[0: f.__code__.co_argcount]):
newkeyset[n] = v
for n, v in tuple(newkeyset.items()):
if n in au and v is not None:
specific_unit = au[n]

if (
jax.tree.structure(expected_result, is_leaf=_is_quantity)
jax.tree.structure(specific_unit, is_leaf=_is_quantity)
!=
jax.tree.structure(result, is_leaf=_is_quantity)
jax.tree.structure(v, is_leaf=_is_quantity)
):
raise TypeError(
f"Expected a return value of type {expected_result} but got {result}"
f"For argument '{n}', we expect the input type {specific_unit} but got {v}"
)

result = jax.tree.map(
partial(_assign_unit, f), result, expected_result,
v = jax.tree.map(
partial(_remove_unit, f.__name__, n),
specific_unit,
v,
is_leaf=_is_quantity
)
return result
newkeyset[n] = v

result = f(**newkeyset)
return result

new_f.without_result_units = without_result_units

return cast(CallableAssignUnit, new_f)

return new_f

return do_assign_units
def _remove_unit(fname, n, unit, v):
if unit is None:
return v

# if the specific unit is a boolean, just check and return
elif unit is bool:
if isinstance(v, bool):
return v
else:
raise TypeError(
f"Function '{fname}' expected a boolean "
f"value for argument '{n}' but got '{v}'"
)

elif isinstance(unit, Unit):
if isinstance(v, Quantity):
v = v.to_decimal(unit)
return v
else:
raise TypeError(
f"Function '{fname}' expected a Quantity "
f"object for argument '{n}' but got '{v}'"
)

elif unit == 1:
if isinstance(v, Quantity):
raise TypeError(
f"Function '{fname}' expected a Number object for argument '{n}' but got '{v}'"
)
return v

else:
raise TypeError(
f"Function '{fname}' expected a target unit object or"
f" a Number, boolean object for checking, but got '{unit}'"
)


def _check_unit(f, val, unit):
Expand Down
4 changes: 3 additions & 1 deletion brainunit/_unit_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"hectare", "acre", "gallon", "gallon_US", "gallon_imp", "fluid_ounce", "fluid_ounce_US", "fluid_ounce_imp",
"bbl", "barrel", "speed_unit", "kmh", "mph", "mach", "speed_of_sound", "knot", "degree_Fahrenheit", "eV",
"electron_volt", "calorie", "calorie_th", "calorie_IT", "erg", "Btu", "Btu_IT", "Btu_th", "ton_TNT", "hp",
"horsepower", "dyn", "dyne", "lbf", "pound_force", "kgf", "kilogram_force", "IMF"
"horsepower", "dyn", "dyne", "lbf", "pound_force", "kgf", "kilogram_force", "IMF", 'kcal_per_h'
]

# ----- Mass -----
Expand Down Expand Up @@ -136,6 +136,8 @@
# ----- Power -----
hp = horsepower = Unit.create(watt.dim, name="horsepower", dispname="hp", scale=watt.scale + 2,
factor=7.4569987158227022)
kcal_per_h = Unit.create(watt.dim, name="kcal per hour", dispname="kcal/h", scale=watt.scale,
factor=1.162222)

# ----- Force -----
dyn = dyne = Unit.create(newton.dim, name="dyne", dispname="dyn", scale=newton.scale - 5, factor=1.)
Expand Down
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ We are building the `brain dynamics programming ecosystem <https://ecosystem-for
mathematical_functions/customize_functions.ipynb
mathematical_functions/array_creation.ipynb
mathematical_functions/numpy_functions.ipynb
mathematical_functions/elinstein_operations.ipynb
mathematical_functions/einstein_operations.ipynb
mathematical_functions/linalg_functions.ipynb
mathematical_functions/fft_functions.ipynb
mathematical_functions/lax_functions.ipynb
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Elinstein Operations\n",
"# Einstein Operations\n",
"\n",
"[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chaobrain/brainunit/blob/master/docs/mathematical_functions/elinstein_operation.ipynb)\n",
"[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/chaobrain/brainunit/blob/master/docs/mathematical_functions/elinstein_operation.ipynb)"
"[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chaobrain/brainunit/blob/master/docs/mathematical_functions/einstein_operation.ipynb)\n",
"[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/chaobrain/brainunit/blob/master/docs/mathematical_functions/einstein_operation.ipynb)"
]
},
{
Expand Down

0 comments on commit 1753e86

Please sign in to comment.