fsdp.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from dataclasses import dataclass

from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType


@dataclass
class fsdp_config:
    mixed_precision: bool = True    # wrap the model with an FSDP MixedPrecision policy
    use_fp16: bool = False          # use fp16 for mixed precision instead of the default bf16
    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD  # shard params, grads, and optimizer states across all ranks
    checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT  # saves one file per rank and allows resuming with a different world size; alternatively use FULL_STATE_DICT to consolidate a single checkpoint on rank 0
    fsdp_activation_checkpointing: bool = True  # recompute activations during backward to reduce peak memory
    fsdp_cpu_offload: bool = False  # offload sharded parameters to CPU, trading speed for GPU memory
    pure_bf16: bool = False         # keep parameters, gradients, and optimizer states entirely in bf16
    optimizer: str = "AdamW"
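

# For reference, a minimal sketch of how these fields might be consumed when
# wrapping a model with FSDP. The wrap_model helper below is hypothetical
# (this file only defines the config) and assumes the default process group
# has already been initialized via torch.distributed.init_process_group.
import torch
from torch.distributed.fsdp import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)


def wrap_model(model: torch.nn.Module, cfg: fsdp_config) -> FSDP:
    # Build a MixedPrecision policy from the config flags; pure_bf16 instead
    # casts the whole model to bf16 up front, so no policy is needed for it.
    mp_policy = None
    if cfg.mixed_precision and not cfg.pure_bf16:
        dtype = torch.float16 if cfg.use_fp16 else torch.bfloat16
        mp_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
    return FSDP(
        model,
        sharding_strategy=cfg.sharding_strategy,
        mixed_precision=mp_policy,
        cpu_offload=CPUOffload(offload_params=True) if cfg.fsdp_cpu_offload else None,
    )

# checkpoint_type is typically applied at save time, e.g.
#   with FSDP.state_dict_type(model, cfg.checkpoint_type):
#       state = model.state_dict()
# so that SHARDED_STATE_DICT writes one shard per rank.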