Skip to content

Commit

Permalink
Try to use result_collection_constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
SF-N committed Jan 6, 2025
1 parent fac65de commit de34ef6
Showing 1 changed file with 7 additions and 13 deletions.
20 changes: 7 additions & 13 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,13 @@ def _transform_by_pattern(
domain_expr = domain.as_expr()

assert isinstance(tmp_expr.type, ts.TypeSpec)
tmp_names: str | tuple[str | tuple, ...] = type_info.apply_to_primitive_constituents(
lambda x: uids.sequential_id(),
tmp_expr.type,
tuple_constructor=lambda *elements: tuple(elements),
) # TODO: how should tuple_constructorb e handled?

tmp_dtypes: ts.ScalarType | tuple[ts.ScalarType | tuple, ...] = (
type_info.apply_to_primitive_constituents(
type_info.extract_dtype,
tmp_expr.type,
tuple_constructor=lambda *elements: tuple(elements),
)
) # TODO: how should tuple_constructorb e handled?
tmp_names: str | tuple[str | tuple, ...] = type_info.type_tree_map(
result_collection_constructor=lambda elements: tuple(elements)
)(lambda x: uids.sequential_id())(tmp_expr.type)

tmp_dtypes: ts.ScalarType | tuple[ts.ScalarType | tuple, ...] = type_info.type_tree_map(
result_collection_constructor=lambda elements: tuple(elements)
)(type_info.extract_dtype)(tmp_expr.type)

# allocate temporary for all tuple elements
def allocate_temporary(tmp_name: str, dtype: ts.ScalarType):
Expand Down

0 comments on commit de34ef6

Please sign in to comment.