[WIP] call backend JAX bindings #74

Open · wants to merge 11 commits into base: main
22 changes: 20 additions & 2 deletions cuequivariance/cuequivariance/tensor_product_execution.py
@@ -24,11 +24,13 @@ class Buffer(int):


class InBuffer(Buffer):
-    pass
+    def __repr__(self):
+        return f"InBuffer({int(self)})"


class OutBuffer(Buffer):
-    pass
+    def __repr__(self):
+        return f"OutBuffer({int(self)})"


T = TypeVar("T")
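Note on the __repr__ change above: Buffer subclasses int, so instances previously printed as bare integers, which hides whether a buffer is an input or an output when inspecting computation lists. A minimal standalone sketch of the new behavior (a reconstruction for illustration, not part of the diff):

class Buffer(int): ...

class InBuffer(Buffer):
    def __repr__(self):
        return f"InBuffer({int(self)})"

class OutBuffer(Buffer):
    def __repr__(self):
        return f"OutBuffer({int(self)})"

# list repr now shows buffer roles instead of plain ints
print([InBuffer(0), OutBuffer(2)])  # [InBuffer(0), OutBuffer(2)]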
@@ -165,6 +167,22 @@ def num_inputs_per_operand(self) -> tuple[int, ...]:
    def num_outputs_per_operand(self) -> tuple[int, ...]:
        return tuple(len(s) for s in self.out_buffers_per_operand)

+    def get_in_buffer_operands(self, buffer: int) -> set[int]:
+        return {
+            ope
+            for c in self.computations
+            for ope, b in enumerate(c)
+            if isinstance(b, InBuffer) and b == buffer
+        }
+
+    def get_out_buffer_operands(self, buffer: int) -> set[int]:
+        return {
+            ope
+            for c in self.computations
+            for ope, b in enumerate(c)
+            if isinstance(b, OutBuffer) and b == buffer
+        }

    def map_buffers(
        self,
        f_in: Optional[Callable[[int], int]],
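The two new helpers answer "at which operand positions is this buffer read (or written) across all computations?". A hypothetical sketch of the behavior, assuming each entry of computations holds one InBuffer/OutBuffer per operand (toy data, not from the PR):

class Buffer(int): ...
class InBuffer(Buffer): ...
class OutBuffer(Buffer): ...

computations = [
    (InBuffer(0), InBuffer(1), OutBuffer(0)),
    (InBuffer(1), InBuffer(1), OutBuffer(0)),
]

def get_in_buffer_operands(buffer):
    # operand positions at which `buffer` appears as an input
    return {
        ope
        for c in computations
        for ope, b in enumerate(c)
        if isinstance(b, InBuffer) and b == buffer
    }

print(get_in_buffer_operands(1))  # {0, 1}: buffer 1 feeds operands 0 and 1

The isinstance check matters because Buffer compares equal as a plain int: without it, OutBuffer(1) would also match buffer 1.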
@@ -25,10 +25,10 @@ def equivariant_tensor_product(
    *inputs: cuex.RepArray | jax.Array,
    dtype_output: jnp.dtype | None = None,
    dtype_math: jnp.dtype | None = None,
-    precision: jax.lax.Precision = jax.lax.Precision.HIGHEST,
    algorithm: str = "sliced",
    use_custom_primitive: bool = True,
-    use_custom_kernels: bool = False,
+    use_custom_kernels: bool | None = False,
+    name: str | None = None,
    **options,
) -> cuex.RepArray:
    """Compute the equivariant tensor product of the input arrays.

@@ -37,10 +37,9 @@
        *inputs (RepArray or jax.Array): The input arrays.
        dtype_output (jnp.dtype, optional): The data type for the output array. Defaults to None.
        dtype_math (jnp.dtype, optional): The data type for computational operations. Defaults to None.
-        precision (jax.lax.Precision, optional): The precision for the computation. Defaults to ``jax.lax.Precision.HIGHEST``.
        algorithm (str, optional): One of "sliced", "stacked", "compact_stacked", "indexed_compact", "indexed_vmap", "indexed_for_loop". Defaults to "sliced". See :class:`cuex.tensor_product <cuequivariance_jax.tensor_product>` for more information.
        use_custom_primitive (bool, optional): Whether to use custom JVP rules. Defaults to True.
-        use_custom_kernels (bool, optional): Whether to use custom kernels. Defaults to True.
+        name (str, optional): The name of the operation. Defaults to None.

    Returns:
        RepArray: The result of the equivariant tensor product.
@@ -74,10 +73,10 @@ def equivariant_tensor_product(
        *inputs,
        dtype_output=dtype_output,
        dtype_math=dtype_math,
-        precision=precision,
        algorithm=algorithm,
        use_custom_primitive=use_custom_primitive,
        use_custom_kernels=use_custom_kernels,
+        name=name,
        **options,
    )

    if len(inputs) != e.num_inputs:

@@ -109,10 +108,10 @@ def equivariant_tensor_product(
        *inputs,
        dtype_output=dtype_output,
        dtype_math=dtype_math,
-        precision=precision,
        algorithm=algorithm,
        use_custom_primitive=use_custom_primitive,
        use_custom_kernels=use_custom_kernels,
+        name=name,
        **options,
    )

    return cuex.RepArray(e.output, x)
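For context, a hypothetical call under the revised signature. The descriptor and random-input helpers follow the cuequivariance docs as I recall them and should be treated as assumptions; only the dropped precision argument and the new name argument reflect this diff:

import jax
import cuequivariance as cue
import cuequivariance_jax as cuex

# assumed API: spherical-harmonics descriptor and cuex.randn helper
e = cue.descriptors.spherical_harmonics(cue.SO3(1), [1, 2, 3])
x = cuex.randn(jax.random.key(0), e.inputs[0])

# precision= is removed by this diff; an operation name is passed instead
y = cuex.equivariant_tensor_product(e, x, name="sph_example")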
@@ -28,10 +28,10 @@ def symmetric_tensor_product(
    *inputs: jax.Array,
    dtype_output: jnp.dtype | None = None,
    dtype_math: jnp.dtype | None = None,
-    precision: jax.lax.Precision = jax.lax.Precision.HIGHEST,
    algorithm: str = "sliced",
    use_custom_primitive: bool = True,
-    use_custom_kernels: bool = False,
+    use_custom_kernels: bool | None = False,
+    name: str | None = None,
    **options,
) -> jax.Array:
    """
    Compute the sum of the STPs evaluated on the input (all input operands are the same).
@@ -41,10 +41,9 @@
        *inputs (jax.Array): The input arrays. The last input is repeated to match the number of input operands of each STP.
        dtype_output (jnp.dtype, optional): The data type for the output array.
        dtype_math (jnp.dtype, optional): The data type for mathematical operations.
-        precision (jax.lax.Precision, optional): The precision for the computation. Defaults to jax.lax.Precision.HIGHEST.
        algorithm (str, optional): One of "sliced", "stacked", "compact_stacked", "indexed_compact", "indexed_vmap", "indexed_for_loop". Defaults to "sliced".
        use_custom_primitive (bool, optional): Whether to use custom JVP rules. Defaults to True.
-        use_custom_kernels (bool, optional): Whether to use custom kernels. Defaults to True.
+        name (str, optional): The name of the operation.

    Returns:
        jax.Array: The result of the tensor product computation.
@@ -54,6 +53,9 @@
"""
assert any(d.num_operands >= 2 for d in ds)

if name is None:
name = "symmetric_tensor_product"

# currying
if len(inputs) == 0:

@@ -63,10 +65,10 @@ def fn(*inputs) -> jax.Array:
            *inputs,
            dtype_output=dtype_output,
            dtype_math=dtype_math,
-            precision=precision,
            algorithm=algorithm,
            use_custom_primitive=use_custom_primitive,
            use_custom_kernels=use_custom_kernels,
+            name=name,
            **options,
        )

    return fn
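The currying branch above means a call without input arrays returns a partially applied function rather than computing anything. A hypothetical sketch, where ds, x0, and x1 are placeholders:

# curried form: no inputs yet, so a function is returned
fn = cuex.symmetric_tensor_product(ds, name="my_stp")
out = fn(x0, x1)  # same as cuex.symmetric_tensor_product(ds, x0, x1, name="my_stp")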
@@ -132,10 +134,10 @@ def fn(*inputs) -> jax.Array:
            *d_inputs,
            dtype_output=dtype_output,
            dtype_math=dtype_math,
-            precision=precision,
            algorithm=algorithm,
            use_custom_primitive=use_custom_primitive,
            use_custom_kernels=use_custom_kernels,
+            name=name + f"_{n_in - n_un}",
            **options,
        )

        return output
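This last hunk is why the earlier "if name is None" default was added: each per-degree call appends a suffix to name, which would raise on None. A toy illustration of the failure mode and the guard (n_in/n_un meanings assumed from context as operand counts):

name = None
# name + "_1"  -> TypeError: unsupported operand type(s) for +
name = name or "symmetric_tensor_product"
print(name + "_1")  # symmetric_tensor_product_1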