Skip to content

Commit

Permalink
fixing sumcumprod_masked
Browse files Browse the repository at this point in the history
  • Loading branch information
frskplis committed Jan 8, 2024
1 parent 138df4a commit c58a234
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/sumcumprod_jax/sumcumprod_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,13 @@ def _sumcumprod_masked_lowering(ctx, input1, input2, *, platform="cpu"):
return custom_call(
op_name,
# Output types
out_types=[dtype],
result_types=[dtype],
# The inputs have to be int32 because for int64 for some reason it does not work:
operands=[mlir.ir_constant(np.int32(size)), mlir.ir_constant(np.int32(int_size_of_last_dim)), input1, input2],
# Layout specification:
operand_layouts=[(), (), layout, layout],
result_layouts=[layout]
)
).results

elif platform == "gpu":
if gpu_ops is None:
Expand All @@ -183,15 +183,15 @@ def _sumcumprod_masked_lowering(ctx, input1, input2, *, platform="cpu"):
return custom_call(
op_name,
# Output types
out_types=[dtype],
result_types=[dtype],
# The inputs:
operands=[input1, input2],
# Layout specification:
operand_layouts=[layout, layout],
result_layouts=[layout],
# GPU specific additional data
backend_config=opaque
)
).results

raise ValueError(
"Unsupported platform; this must be either 'cpu' or 'gpu'"
Expand Down

0 comments on commit c58a234

Please sign in to comment.