diff --git a/src/sumcumprod_jax/sumcumprod_jax.py b/src/sumcumprod_jax/sumcumprod_jax.py index 800217b..51bc72f 100644 --- a/src/sumcumprod_jax/sumcumprod_jax.py +++ b/src/sumcumprod_jax/sumcumprod_jax.py @@ -163,13 +163,13 @@ def _sumcumprod_masked_lowering(ctx, input1, input2, *, platform="cpu"): return custom_call( op_name, # Output types - out_types=[dtype], + result_types=[dtype], # The inputs have to be int32 because for int64 for some reason it does not work: operands=[mlir.ir_constant(np.int32(size)), mlir.ir_constant(np.int32(int_size_of_last_dim)), input1, input2], # Layout specification: operand_layouts=[(), (), layout, layout], result_layouts=[layout] - ) + ).results elif platform == "gpu": if gpu_ops is None: @@ -183,7 +183,7 @@ def _sumcumprod_masked_lowering(ctx, input1, input2, *, platform="cpu"): return custom_call( op_name, # Output types - out_types=[dtype], + result_types=[dtype], # The inputs: operands=[input1, input2], # Layout specification: @@ -191,7 +191,7 @@ def _sumcumprod_masked_lowering(ctx, input1, input2, *, platform="cpu"): result_layouts=[layout], # GPU specific additional data backend_config=opaque - ) + ).results raise ValueError( "Unsupported platform; this must be either 'cpu' or 'gpu'"