Support SPMD fsdp compute dtype (#13)
* feat: support compute dtype in spmd fsdp

* fix: import functions from python fsdp

* fix: unused import

* fix: remove comment
lausannel authored Sep 20, 2024
1 parent a32543b commit 063384c
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions torch_xla/experimental/spmd_fully_sharded_data_parallel.py
@@ -1,17 +1,19 @@
-from typing import (Any, Callable, Dict, Optional, Union)
 import warnings
+from typing import (Any, Callable, Dict, Optional, Union)
 
+import numpy as np
 import torch
 import torch.nn as nn
 from torch._prims_common import TensorLike, TensorSequenceType
 
-import numpy as np
-
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.spmd as spmd
-from torch_xla.distributed.fsdp.wrap import recursive_wrap
 from torch_xla.distributed.fsdp._init_utils import _materialize_module
+from torch_xla.distributed.fsdp.wrap import recursive_wrap
+from torch_xla.distributed.fsdp.xla_fully_sharded_data_parallel import _cast_floats_tensors
+
+FLOAT_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
 
 
 def _prepare_spmd_partition_spec(param):
@@ -40,13 +42,18 @@ class SpmdFullyShardedDataParallel(nn.Module):
         The callable should have the signature (output, mesh) -> None.
         If None, the default implementation will shard the first tensor in the output.
         If the output is a tuple, only the first tensor will be sharded.
+    compute_dtype (torch.dtype, Optional):
+        dtype for full parameters for computation. This defaults to
+        ``torch.float32`` but can be set to ``torch.float16`` or
+        ``torch.bfloat16``. The sharded parameters will always be in FP32.
   """
 
   def __init__(
       self,
       module: nn.Module,
       mesh: Optional[spmd.Mesh] = None,
       shard_output: Optional[Callable] = None,
+      compute_dtype: Optional[torch.dtype] = None,
       auto_wrap_policy: Optional[Callable] = None,
       auto_wrapper_callable: Optional[Callable] = None,
   ):
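
For context, here is a minimal usage sketch of the new compute_dtype argument. This is a sketch only: the SPMD setup, mesh layout, toy nn.Linear model, and the FSDPv2 alias are assumptions for illustration and not part of this commit; only compute_dtype=torch.bfloat16 exercises the new code path.

import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as spmd
import torch_xla.runtime as xr
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)

xr.use_spmd()  # enable SPMD execution mode

# Assumed mesh: shard over all devices along the "fsdp" axis.
num_devices = xr.global_runtime_device_count()
mesh = spmd.Mesh(np.array(range(num_devices)), (num_devices, 1), ("fsdp", "model"))

model = nn.Linear(1024, 1024).to(xm.xla_device())

# Sharded parameters stay in FP32; the forward computation runs in bfloat16.
model = FSDPv2(model, mesh=mesh, compute_dtype=torch.bfloat16)

x = torch.randn(8, 1024).to(xm.xla_device())
y = model(x)  # float inputs are cast to bfloat16 before the wrapped forward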
@@ -96,6 +103,11 @@ def __init__(
       )
       self._auto_wrap(auto_wrap_kwargs, fsdp_kwargs)
 
+    if compute_dtype is not None and compute_dtype not in FLOAT_DTYPES:
+      raise ValueError(
+          f"compute_dtype must be one of {FLOAT_DTYPES}, not {compute_dtype}")
+    self.compute_dtype = compute_dtype or torch.float32
+
     _materialize_module(
         module,
         None, [],
@@ -150,6 +162,9 @@ def module(self) -> nn.Module:
     return self._orig_module
 
   def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
+    if self.compute_dtype != torch.float32:
+      # Cast the input float tensors to the specified compute_dtype
+      args, kwargs = _cast_floats_tensors(self.compute_dtype, *args, **kwargs)
     output = self.module(*args, **kwargs)
     # Need to shard the output of the forward to instruct the compiler
     # to enforce the FSDP algorithm.
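
The cast in forward delegates to _cast_floats_tensors, imported above from the Python FSDP implementation. Conceptually it behaves roughly like the simplified stand-in below, which is an illustrative sketch that handles only flat args/kwargs, not the actual torch_xla helper:

import torch

def _cast_floats_like(dtype, *args, **kwargs):
  # Illustrative stand-in: cast floating-point tensors to `dtype`,
  # leaving integer tensors and non-tensor arguments untouched.
  def maybe_cast(x):
    if torch.is_tensor(x) and torch.is_floating_point(x):
      return x.to(dtype)
    return x

  new_args = tuple(maybe_cast(a) for a in args)
  new_kwargs = {k: maybe_cast(v) for k, v in kwargs.items()}
  return new_args, new_kwargs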