diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 5440bbe58..4d4c60bc0 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -238,8 +238,8 @@ def apply_tp( torch.ops._c10d_functional.reduce_scatter_tensor.default, # for low precision training, it's useful to always save # the result of max(abs(tensor)) - # torch.ops.aten.abs.default, - # torch.ops.aten.max.default, + torch.ops.aten.abs.default, + torch.ops.aten.max.default, }