From 6fc6b2029ef7243367845c85cfab7d1e948729a8 Mon Sep 17 00:00:00 2001
From: Mario Geiger
Date: Wed, 15 Jan 2025 05:33:25 -0800
Subject: [PATCH 1/6] Refactor tensor_product to use jax.extend.core for Primitive and update type annotations

---
 .../cuequivariance_jax/primitives/tensor_product.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py
index bc52ea6..cb81c35 100644
--- a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py
+++ b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py
@@ -16,9 +16,10 @@
 from functools import partial
 
 import jax
+import jax.core
+import jax.extend
 import jax.lax
 import jax.numpy as jnp
-from jax import core
 from jax.interpreters import ad, batching, mlir, xla
 
 from cuequivariance import segmented_tensor_product as stp
@@ -160,7 +161,7 @@ def _partial(*remaining_inputs: jax.Array) -> jax.Array:
 
 ################################################################################
 
-tensor_product_p = core.Primitive("tensor_product")
+tensor_product_p = jax.extend.core.Primitive("tensor_product")
 tensor_product_p.multiple_results = True
 
@@ -249,12 +250,12 @@ def dispatch(
 
 
 def tensor_product_abstract_eval(
-    *inputs: core.ShapedArray,
+    *inputs: jax.core.ShapedArray,
     shapes: tuple[tuple[int, ...], ...],
     d: stp.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
-) -> tuple[core.ShapedArray, ...]:
+) -> tuple[jax.core.ShapedArray, ...]:
     # assert that all input/output are used
     assert exe.max_in_buffer + 1 == len(exe.in_buffers) == len(inputs)
     assert exe.max_out_buffer + 1 == len(exe.out_buffers)
@@ -269,7 +270,7 @@
 
     outputs = [None] * len(exe.out_buffers)
     for c in exe.computations:
-        out = core.ShapedArray(
+        out = jax.core.ShapedArray(
             shape=shapes[c.out_operand] + (d.operands[c.out_operand].size,),
             dtype=options["dtype_output"],
         )
@@ -291,6 +292,7 @@ def tensor_product_jvp(
     out_tangents = [ad.Zero(p.aval) for p in out_primals]
 
     jvp = exe.jvp([not isinstance(t, ad.Zero) for t in tangents])
+    del exe
 
     permutations: list[tuple[int, ...]] = d.symmetries()
     for multiplicator, exe in jvp.group_by_symmetries(permutations):

From 1f3fc337e965147c129029b85835d36c8cba74af Mon Sep 17 00:00:00 2001
From: Mario Geiger
Date: Wed, 15 Jan 2025 08:09:20 -0800
Subject: [PATCH 2/6] update ruff version

---
 .pre-commit-config.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 62bcfec..c8c64b1 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.8.2 # Use the latest stable version of Black
+    rev: v0.9.1
    hooks:
      - id: ruff
        args: ["--fix"]
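A note on PATCH 1/6: the patch moves the primitive construction from the deprecated `from jax import core` import to `jax.extend.core.Primitive`, while the abstract-eval rule keeps using `jax.core.ShapedArray`. A minimal, self-contained sketch of that registration pattern under recent JAX versions (the "double" primitive below is illustrative only, not part of the patch):

    import jax
    import jax.extend
    import jax.numpy as jnp

    # The Primitive class now comes from the supported extension namespace.
    double_p = jax.extend.core.Primitive("double")

    # Eager evaluation rule, used when bind() is called outside of tracing.
    double_p.def_impl(lambda x: 2 * x)

    # Shape/dtype inference rule; the patch likewise keeps jax.core.ShapedArray here.
    double_p.def_abstract_eval(lambda x: jax.core.ShapedArray(x.shape, x.dtype))

    print(double_p.bind(jnp.arange(3.0)))  # [0. 2. 4.]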
From fe057b06b47a1bf2c67a85b02dd3e89bf1b56def Mon Sep 17 00:00:00 2001
From: Mario Geiger
Date: Wed, 15 Jan 2025 08:09:34 -0800
Subject: [PATCH 3/6] ruff

---
 .../primitives/equivariant_tensor_product.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py
index 7a9ec81..d5d61f5 100644
--- a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py
+++ b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py
@@ -87,16 +87,16 @@ def equivariant_tensor_product(
 
     for i, (x, rep) in enumerate(zip(inputs, e.inputs)):
         if isinstance(x, cuex.RepArray):
-            assert (
-                x.rep(-1) == rep
-            ), f"Input {i} should have representation {rep}, got {x.rep(-1)}."
+            assert x.rep(-1) == rep, (
+                f"Input {i} should have representation {rep}, got {x.rep(-1)}."
+            )
         else:
-            assert (
-                x.ndim >= 1
-            ), f"Input {i} should have at least one dimension, got {x.ndim}."
-            assert (
-                x.shape[-1] == rep.dim
-            ), f"Input {i} should have dimension {rep.dim}, got {x.shape[-1]}."
+            assert x.ndim >= 1, (
+                f"Input {i} should have at least one dimension, got {x.ndim}."
+            )
+            assert x.shape[-1] == rep.dim, (
+                f"Input {i} should have dimension {rep.dim}, got {x.shape[-1]}."
+            )
             if not rep.is_scalar():
                 raise ValueError(
                     f"Input {i} should be a RepArray unless the input is scalar. Got {type(x)} for {rep}."
From 0df8fe036fd819938e53ff0bab87f6335ce91246 Mon Sep 17 00:00:00 2001
From: Mario Geiger
Date: Wed, 15 Jan 2025 08:32:09 -0800
Subject: [PATCH 4/6] convert Computation into a class

---
 .../tensor_product_execution.py | 64 ++++++++++++-------
 1 file changed, 40 insertions(+), 24 deletions(-)

diff --git a/cuequivariance/cuequivariance/tensor_product_execution.py b/cuequivariance/cuequivariance/tensor_product_execution.py
index 3564b4f..8d4651d 100644
--- a/cuequivariance/cuequivariance/tensor_product_execution.py
+++ b/cuequivariance/cuequivariance/tensor_product_execution.py
@@ -34,32 +34,42 @@ class OutBuffer(Buffer):
 T = TypeVar("T")
 
 
-class Computation(tuple):
-    def __new__(cls, elements):
-        elements = list(elements)
-        assert all(isinstance(b, Buffer) for b in elements), elements
-        assert sum(isinstance(b, OutBuffer) for b in elements) == 1, elements
-        return super().__new__(cls, elements)
+class Computation:
+    buffers: tuple[Buffer, ...]  # one buffer per operand
+
+    def __init__(self, buffers: Sequence[Buffer]):
+        if isinstance(buffers, Computation):
+            buffers = buffers.buffers
+        self.buffers = tuple(buffers)
+        assert all(isinstance(b, Buffer) for b in self.buffers), self.buffers
+        assert sum(isinstance(b, OutBuffer) for b in self.buffers) == 1, self.buffers
+
+    def __hash__(self) -> int:
+        return hash(self.buffers)
 
     @property
     def num_operands(self) -> int:
-        return len(self)
+        return len(self.buffers)
 
     @property
     def in_buffers(self) -> tuple[InBuffer, ...]:
-        return tuple(b for b in self if isinstance(b, InBuffer))
+        return tuple(b for b in self.buffers if isinstance(b, InBuffer))
 
     @property
     def out_buffer(self) -> OutBuffer:
-        return next(b for b in self if isinstance(b, OutBuffer))
+        return next(b for b in self.buffers if isinstance(b, OutBuffer))
 
     @property
     def in_operands(self) -> tuple[int, ...]:
-        return tuple(oid for oid, b in enumerate(self) if isinstance(b, InBuffer))
+        return tuple(
+            oid for oid, b in enumerate(self.buffers) if isinstance(b, InBuffer)
+        )
 
     @property
     def out_operand(self) -> int:
-        return next(oid for oid, b in enumerate(self) if isinstance(b, OutBuffer))
+        return next(
+            oid for oid, b in enumerate(self.buffers) if isinstance(b, OutBuffer)
+        )
 
     def map_operands(
         self,
@@ -68,12 +78,14 @@
     ) -> list[Optional[T]]:
         in_buffers = list(in_buffers)
         if out_buffers is None:
-            return [in_buffers[b] if isinstance(b, InBuffer) else None for b in self]
+            return [
+                in_buffers[b] if isinstance(b, InBuffer) else None for b in self.buffers
+            ]
         else:
             out_buffers = list(out_buffers)
             return [
                 in_buffers[b] if isinstance(b, InBuffer) else out_buffers[b]
-                for b in self
+                for b in self.buffers
             ]
 
     def map_inputs(
@@ -108,7 +120,8 @@ def __repr__(self):
         text += [
             " "
             + " ".join(
-                IVARS[b] if isinstance(b, InBuffer) else OVARS[b] for b in comp
+                IVARS[b] if isinstance(b, InBuffer) else OVARS[b]
+                for b in comp.buffers
             )
         ]
         return "\n".join(text)
@@ -121,7 +134,7 @@ def is_trivial(self) -> bool:
     def num_operands(self) -> int:
         assert not self.is_trivial
         for c in self.computations:
-            return len(c)
+            return c.num_operands
 
     @property
     def in_buffers(self) -> tuple[int, ...]:
@@ -182,7 +195,7 @@ def map_buffers(
                     if isinstance(b, InBuffer)
                     else OutBuffer(int(f_out(b)))
                 )
-                for b in comp
+                for b in comp.buffers
             )
             for comp in self.computations
         )
@@ -215,7 +228,7 @@
             if bid is None:
                 continue  # the tangent is zero
 
-            c = list(computation)
+            c = list(computation.buffers)
            c[oid] = InBuffer(bid)
            new_computations.append(Computation(c))
 
@@ -257,18 +270,18 @@
                continue  # cotangent is zero
 
            for oid in comp.in_operands:
-                if not is_undefined_primal[comp[oid]]:
+                if not is_undefined_primal[comp.buffers[oid]]:
                    continue  # nothing to transpose
 
-                c = [None] * len(comp)
+                c = [None] * comp.num_operands
                # undefined primal -> output
-                c[oid] = OutBuffer(primals_new_bid[comp[oid]])
+                c[oid] = OutBuffer(primals_new_bid[comp.buffers[oid]])
                # output -> cotangent input
                c[comp.out_operand] = InBuffer(cotangents_new_bid[comp.out_buffer])
                # rest of inputs
                for i in range(comp.num_operands):
                    if i != oid and i != comp.out_operand:
-                        c[i] = InBuffer(primals_new_bid[comp[i]])
+                        c[i] = InBuffer(primals_new_bid[comp.buffers[i]])
 
                new_computations.append(Computation(c))
@@ -289,8 +302,11 @@
        for c in self.computations:
            found_bucket = False
            for bucket in buckets:
-                rep = bucket[0]
-                if any(Computation(rep[p] for p in perm) == c for perm in permutations):
+                rep: Computation = bucket[0]
+                if any(
+                    Computation(rep.buffers[p] for p in perm) == c
+                    for perm in permutations
+                ):
                    bucket.append(c)
                    found_bucket = True
                    break
@@ -311,7 +327,7 @@ def group_by_identical_buffers(
 
        def partition(computation: Computation) -> list[list[int]]:
            bid_to_oid = defaultdict(list)
-            for oid, b in enumerate(computation):
+            for oid, b in enumerate(computation.buffers):
                b = (type(b), b)
                bid_to_oid[b].append(oid)
            return sorted(map(sorted, bid_to_oid.values()))

From f569e9f2ac26a29ad10ba9e6b92a78bee13cded6 Mon Sep 17 00:00:00 2001
From: Mario Geiger
Date: Wed, 22 Jan 2025 06:50:39 -0800
Subject: [PATCH 5/6] fix introduced bug and add corresponding test

---
 .../tensor_product_execution.py | 37 ++++++++++++-----
 .../tests/tensor_product_execution_test.py | 40 +++++++++++++++++++
 2 files changed, 68 insertions(+), 9 deletions(-)
 create mode 100644 cuequivariance/tests/tensor_product_execution_test.py

diff --git a/cuequivariance/cuequivariance/tensor_product_execution.py b/cuequivariance/cuequivariance/tensor_product_execution.py
index 8d4651d..7800e02 100644
--- a/cuequivariance/cuequivariance/tensor_product_execution.py
+++ b/cuequivariance/cuequivariance/tensor_product_execution.py
@@ -24,11 +24,19 @@ class Buffer(int):
 
 
 class InBuffer(Buffer):
-    pass
+    def __eq__(self, other: Any) -> bool:
+        return isinstance(other, InBuffer) and int(self) == int(other)
+
+    def __hash__(self) -> int:
+        return hash(("in", int(self)))
 
 
 class OutBuffer(Buffer):
-    pass
+    def __eq__(self, other: Any) -> bool:
+        return isinstance(other, OutBuffer) and int(self) == int(other)
+
+    def __hash__(self) -> int:
+        return hash(("out", int(self)))
 
 
 T = TypeVar("T")
@@ -44,9 +52,20 @@ def __init__(self, buffers: Sequence[Buffer]):
         assert all(isinstance(b, Buffer) for b in self.buffers), self.buffers
         assert sum(isinstance(b, OutBuffer) for b in self.buffers) == 1, self.buffers
 
+    def __eq__(self, other: Any) -> bool:
+        return isinstance(other, Computation) and self.buffers == other.buffers
+
     def __hash__(self) -> int:
         return hash(self.buffers)
 
+    def __repr__(self):
+        IVARS = "abcdefghijklmnopqrstuvwxyz"
+        OVARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+
+        return " ".join(
+            IVARS[b] if isinstance(b, InBuffer) else OVARS[b] for b in self.buffers
+        )
+
     @property
     def num_operands(self) -> int:
         return len(self.buffers)
@@ -103,6 +122,12 @@ class TensorProductExecution:
     def __init__(self, computations: tuple[Computation, ...]):
         self.computations = tuple(Computation(c) for c in computations)
 
+    def __eq__(self, other: Any) -> bool:
+        return (
+            isinstance(other, TensorProductExecution)
+            and self.computations == other.computations
+        )
+
     def __hash__(self) -> int:
         return hash(self.computations)
 
@@ -117,13 +142,7 @@ def __repr__(self):
             )
         ]
         for comp in self.computations:
-            text += [
-                " "
-                + " ".join(
-                    IVARS[b] if isinstance(b, InBuffer) else OVARS[b]
-                    for b in comp.buffers
-                )
-            ]
+            text += [f" {comp}"]
         return "\n".join(text)
 
     @property

diff --git a/cuequivariance/tests/tensor_product_execution_test.py b/cuequivariance/tests/tensor_product_execution_test.py
new file mode 100644
index 0000000..65301d3
--- /dev/null
+++ b/cuequivariance/tests/tensor_product_execution_test.py
@@ -0,0 +1,40 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import cuequivariance as cue
+from cuequivariance.tensor_product_execution import InBuffer, OutBuffer
+
+
+def test_group_by_symmetries():
+    # x^3
+    exe = cue.TensorProductExecution(
+        [(InBuffer(0), InBuffer(0), InBuffer(0), OutBuffer(0))]
+    )
+    mul, exe = next(
+        exe.jvp([True]).group_by_symmetries(
+            [
+                (0, 1, 2, 3),
+                (0, 2, 1, 3),
+                (1, 0, 2, 3),
+                (1, 2, 0, 3),
+                (2, 0, 1, 3),
+                (2, 1, 0, 3),
+            ]
+        )
+    )
+    # d/dx (x^3) = 3x^2
+    assert mul == 3
+    assert exe == cue.TensorProductExecution(
+        [(InBuffer(1), InBuffer(0), InBuffer(0), OutBuffer(0))]
+    )
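Why PATCH 5/6 follows: PATCH 4/6 turned Computation from a tuple subclass into a plain class, which silently replaced tuple value-equality with default identity equality, and the int-derived buffers still compared equal across types (InBuffer(0) == OutBuffer(0) as plain ints). group_by_symmetries matches freshly constructed Computation objects against existing ones, so the grouping stopped matching. A small sketch of the semantics the fix pins down (behavior after the patch; names as in tensor_product_execution.py):

    from cuequivariance.tensor_product_execution import Computation, InBuffer, OutBuffer

    # Buffer equality is now type-sensitive, not plain int equality.
    assert InBuffer(0) == InBuffer(0)
    assert InBuffer(0) != OutBuffer(0)

    # Computation compares and hashes by value, so a rebuilt permutation can match,
    # exactly as group_by_symmetries requires (here the x^3 case from the new test).
    a = Computation([InBuffer(0), InBuffer(0), InBuffer(0), OutBuffer(0)])
    assert a == Computation(a.buffers[p] for p in (1, 0, 2, 3))
    assert len({a, Computation(a.buffers)}) == 1  # consistent __eq__ and __hash__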
From 68d67a3a38a844d2e17702e6dc5a0c2882011b89 Mon Sep 17 00:00:00 2001
From: Mario Geiger
Date: Wed, 22 Jan 2025 07:58:28 -0800
Subject: [PATCH 6/6] tensor_product_p: shapes -> output_shapes

---
 .../primitives/tensor_product.py | 80 +++++++++++--------
 .../primitives/tensor_product_vanilla_impl.py | 10 ++-
 .../tests/primitives/tensor_product_test.py | 11 +++
 3 files changed, 62 insertions(+), 39 deletions(-)

diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py
index cb81c35..7dd2f37 100644
--- a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py
+++ b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py
@@ -147,7 +147,7 @@ def _partial(*remaining_inputs: jax.Array) -> jax.Array:
         jnp.reshape(input, (1,) * (len(output_shape) + 1 - input.ndim) + input.shape)
         for input in inputs
     ]
-    shapes = tuple(input.shape[:-1] for input in inputs) + (output_shape,)
+    output_shapes = tuple(None for _ in inputs) + (output_shape,)
     exe = TensorProductExecution(
         [
             Computation(
@@ -155,7 +155,9 @@ def _partial(*remaining_inputs: jax.Array) -> jax.Array:
             )
         ]
     )
-    (output,) = tensor_product_prim(*inputs, shapes=shapes, d=d, exe=exe, **options)
+    (output,) = tensor_product_prim(
+        *inputs, output_shapes=output_shapes, d=d, exe=exe, **options
+    )
     return output
 
@@ -184,7 +186,7 @@ def clean_inputs(
 
 def tensor_product_prim(
     *inputs: jax.Array,  # input buffers
-    shapes: tuple[tuple[int, ...], ...],  # shapes of the operands
+    output_shapes: tuple[tuple[int, ...] | None, ...],  # shapes of the operands
     d: stp.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
@@ -198,18 +200,18 @@ def tensor_product_prim(
 
     if options.pop("use_custom_primitive", True):
         return tensor_product_p.bind(
-            *unique_inputs, shapes=shapes, d=d, exe=exe, **options
+            *unique_inputs, output_shapes=output_shapes, d=d, exe=exe, **options
         )
     else:
         return tensor_product_vanilla_impl(
-            *unique_inputs, shapes=shapes, d=d, exe=exe, **options
+            *unique_inputs, output_shapes=output_shapes, d=d, exe=exe, **options
         )
 
 
 def tensor_product_impl(
     platform: str | None,
     *inputs: jax.Array,
-    shapes: tuple[tuple[int, ...], ...],
+    output_shapes: tuple[tuple[int, ...] | None, ...],
     d: stp.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
@@ -228,7 +230,7 @@ def dispatch(
             pass
 
         return tensor_product_vanilla_impl(
-            *inputs, shapes=shapes, d=d, exe=exe, **options
+            *inputs, output_shapes=output_shapes, d=d, exe=exe, **options
         )
 
     outputs = [0] * len(exe.out_buffers)
@@ -251,7 +253,7 @@ def tensor_product_abstract_eval(
     *inputs: jax.core.ShapedArray,
-    shapes: tuple[tuple[int, ...], ...],
+    output_shapes: tuple[tuple[int, ...] | None, ...],
     d: stp.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
@@ -262,16 +264,15 @@ def tensor_product_abstract_eval(
 
     for c in exe.computations:
         for oid, x in zip(c.in_operands, c.map_inputs(inputs)):
-            expected_shape = shapes[oid] + (d.operands[oid].size,)
-            if x.shape != expected_shape:
+            if x.shape[-1] != d.operands[oid].size:
                raise ValueError(
-                    f"cuex.tensor_product: expected input to have shape {expected_shape}, got {x.shape}"
+                    f"cuex.tensor_product: expected input to have size {d.operands[oid].size}, got {x.shape[-1]}"
                )
 
    outputs = [None] * len(exe.out_buffers)
    for c in exe.computations:
        out = jax.core.ShapedArray(
-            shape=shapes[c.out_operand] + (d.operands[c.out_operand].size,),
+            shape=output_shapes[c.out_operand] + (d.operands[c.out_operand].size,),
            dtype=options["dtype_output"],
        )
        assert outputs[c.out_buffer] is None or outputs[c.out_buffer] == out
@@ -283,12 +284,14 @@ def tensor_product_jvp(
     primals: tuple[jax.Array, ...],
     tangents: tuple[jax.Array | ad.Zero, ...],
     *,
-    shapes: tuple[tuple[int, ...], ...],
+    output_shapes: tuple[tuple[int, ...] | None, ...],
     d: stp.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
 ) -> tuple[tuple[jax.Array, ...], tuple[jax.Array | ad.Zero, ...]]:
-    out_primals = tensor_product_prim(*primals, shapes=shapes, d=d, exe=exe, **options)
+    out_primals = tensor_product_prim(
+        *primals, output_shapes=output_shapes, d=d, exe=exe, **options
+    )
     out_tangents = [ad.Zero(p.aval) for p in out_primals]
 
     jvp = exe.jvp([not isinstance(t, ad.Zero) for t in tangents])
@@ -300,7 +303,7 @@ def tensor_product_jvp(
         tmp = tensor_product_prim(
             *primals,
             *[t for t in tangents if not isinstance(t, ad.Zero)],
-            shapes=shapes,
+            output_shapes=output_shapes,
             d=multiplicator * d,
             exe=exe.map_buffers(None, lambda b: exe.out_buffers.index(b)),
             **options,
@@ -314,11 +317,27 @@ def tensor_product_transpose(
     cotangents: tuple[jax.Array | ad.Zero, ...],
     *inputs: jax.Array | ad.UndefinedPrimal,
-    shapes: tuple[tuple[int, ...], ...],
+    output_shapes: tuple[tuple[int, ...] | None, ...],
     d: stp.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
 ) -> tuple[jax.Array | ad.Zero | None, ...]:
+    # The cotangents replace the outputs as inputs
+    # The undefined primal inputs become outputs
+    del output_shapes
+    output_shapes = [None] * d.num_operands
+    for comp in exe.computations:
+        for oid, x in zip(comp.in_operands, comp.map_inputs(inputs)):
+            if ad.is_undefined_primal(x):
+                undefined_primal_shape = x.aval.shape[:-1]
+                # if the following assert fails, we need to change the internal API of the primitive
+                assert (
+                    output_shapes[oid] is None
+                    or output_shapes[oid] == undefined_primal_shape
+                )
+                output_shapes[oid] = undefined_primal_shape
+    output_shapes = tuple(output_shapes)
+
     tr = exe.transpose(
         [ad.is_undefined_primal(x) for x in inputs],
         [not isinstance(x, ad.Zero) for x in cotangents],
@@ -326,7 +345,7 @@ def tensor_product_transpose(
     tmp = tensor_product_prim(
         *[x for x in inputs if not ad.is_undefined_primal(x)],
         *[x for x in cotangents if not isinstance(x, ad.Zero)],
-        shapes=shapes,
+        output_shapes=output_shapes,
         d=d,
         exe=tr.map_buffers(None, lambda b: tr.out_buffers.index(b)),
         **options,
@@ -348,7 +367,7 @@ def tensor_product_batching(
     batched_inputs: tuple[jax.Array, ...],
     batch_axes: tuple[int | None, ...],
     *,
-    shapes: tuple[tuple[int, ...], ...],
+    output_shapes: tuple[tuple[int, ...] | None, ...],
     d: stp.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
@@ -359,33 +378,24 @@ def prepare(input: jax.Array, axis: int | None) -> jax.Array:
         else:
             return jnp.moveaxis(input, axis, 0)
 
+    assert len(batched_inputs) == len(batch_axes)
     batched_inputs = [
         prepare(input, axis) for input, axis in zip(batched_inputs, batch_axes)
     ]
     new_dim = max(input.shape[0] for input in batched_inputs)
 
-    new_shapes = [None] * d.num_operands
+    new_output_shapes = [None] * d.num_operands
     for comp in exe.computations:
-        # inputs
-        for oid, input in zip(comp.in_operands, comp.map_inputs(batched_inputs)):
-            expected = input.shape[:-1]
-            if new_shapes[oid] is None:
-                new_shapes[oid] = expected
-            assert new_shapes[oid] == expected
-
-        # output
         oid = comp.out_operand
-        expected = (new_dim,) + shapes[oid]
-        if new_shapes[oid] is None:
-            new_shapes[oid] = expected
-        assert new_shapes[oid] == expected
-
-    new_shapes = tuple(new_shapes)
-    assert all(s is not None for s in new_shapes)
+        expected = (new_dim,) + output_shapes[oid]
+        if new_output_shapes[oid] is None:
+            new_output_shapes[oid] = expected
+        assert new_output_shapes[oid] == expected
+    new_output_shapes = tuple(new_output_shapes)
 
     outputs = tensor_product_prim(
         *batched_inputs,
-        shapes=new_shapes,
+        output_shapes=new_output_shapes,
         d=d,
         exe=exe,
         **options,

diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_vanilla_impl.py b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_vanilla_impl.py
index f09a86d..7d94de4 100644
--- a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_vanilla_impl.py
+++ b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_vanilla_impl.py
@@ -29,7 +29,7 @@
 
 def tensor_product_vanilla_impl(
     *inputs: jax.Array,  # input buffers
-    shapes: tuple[tuple[int, ...], ...],  # shapes of the operands
+    output_shapes: tuple[tuple[int, ...] | None, ...],  # shapes of the operands
     d: stp.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
@@ -39,18 +39,20 @@ def tensor_product_vanilla_impl(
 
     outputs = [0] * len(exe.out_buffers)
     for c in exe.computations:
+        shape = output_shapes[c.out_operand]
+        assert shape is not None
         out = sum_cat_list_list(
             d.operands[c.out_operand],
             tp_list_list(
                 *c.map_inputs(inputs),
-                shape=shapes[c.out_operand],
+                shape=shape,
                 d=d.move_operand_last(c.out_operand),
                 **options,
             ),
-            shapes[c.out_operand],
+            shape,
             options["dtype_output"],
         )
-        assert out.shape == shapes[c.out_operand] + (d.operands[c.out_operand].size,)
+        assert out.shape == shape + (d.operands[c.out_operand].size,)
         outputs[c.out_buffer] += out
 
     return tuple(outputs)

diff --git a/cuequivariance_jax/tests/primitives/tensor_product_test.py b/cuequivariance_jax/tests/primitives/tensor_product_test.py
index 8cbbaa1..f1db5d7 100644
--- a/cuequivariance_jax/tests/primitives/tensor_product_test.py
+++ b/cuequivariance_jax/tests/primitives/tensor_product_test.py
@@ -178,3 +178,14 @@ def f(w, x):
         return jnp.sum(a) + jnp.sum(b)
 
     jax.jit(jax.grad(f, 0))(w, x)
+
+
+def test_multiple_operand_shape_bug():
+    # This was causing an issue in the past.
+    # Before, it was not possible to have an input
+    # with a different shape than the output of the same operand.
+    def h(x):
+        d = cue.descriptors.spherical_harmonics(cue.SO3(1), [2]).d
+        return cuex.tensor_product(d, x, x)
+
+    assert jax.jacobian(h)(jnp.array([1.0, 0.0, 0.0])).shape == (5, 3)
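What PATCH 6/6 changes in practice: the primitive now carries per-operand output shapes (None for pure-input operands) instead of one shape shared by inputs and outputs of an operand, so an input and the output of the same operand may have different batch shapes. That situation arises as soon as jax.jacobian differentiates through the primitive, because the transpose rule recovers output shapes from the undefined primals. A usage sketch mirroring test_multiple_operand_shape_bug above (descriptor and shapes taken from that test):

    import jax
    import jax.numpy as jnp
    import cuequivariance as cue
    import cuequivariance_jax as cuex

    # Spherical harmonics of degree 2: a 3-vector enters twice, a 5-dim output leaves.
    d = cue.descriptors.spherical_harmonics(cue.SO3(1), [2]).d

    def h(x):
        return cuex.tensor_product(d, x, x)  # the same buffer feeds both input operands

    x = jnp.array([1.0, 0.0, 0.0])
    assert h(x).shape == (5,)
    assert jax.jacobian(h)(x).shape == (5, 3)  # exercises the new transpose/output_shapes path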