diff --git a/teaal/trans/partitioner.py b/teaal/trans/partitioner.py index bfb686e..cec2075 100644 --- a/teaal/trans/partitioner.py +++ b/teaal/trans/partitioner.py @@ -110,9 +110,12 @@ def unpartition(self, tensor: Tensor) -> Statement: swizzled_ranks, part_ir.get_all_parts(), False) for part in valid_parts: trans.append((part, tensor.get_ranks())) + + # If this is a partition, apply all partitioning + # If this is a flatten, just flatten tensor.update_ranks( part_ir.partition_ranks( - tensor.get_ranks(), {part}, True, False)) + tensor.get_ranks(), {part}, len(part) == 1, False)) new_ranks = tensor.get_ranks()