feat: Add CPU offloading capabilities
Signed-off-by: Keith Valin <[email protected]>
kdvalin committed May 9, 2024
1 parent a0476ea commit 2db4c06
Showing 1 changed file with 14 additions and 5 deletions.
main_ds.py: 14 additions & 5 deletions
@@ -19,23 +19,23 @@
 )
 
 import deepspeed
-from deepspeed.ops.adam import FusedAdam
+from deepspeed.ops.adam import FusedAdam, DeepSpeedCPUAdam
 from multipack_sampler import find_packing_max_batch_len_and_grad_accum
 from token_dataset import setup_dataloader, setup_dataset
 from tokenizer_utils import setup_tokenizer
 from utils import save_hf_format_ds, set_random_seed, setup_logger, convert_loss_to_reduce_sum
 
 
-def get_ds_config(world_size, samples_per_gpu, grad_accum):
+def get_ds_config(world_size, samples_per_gpu, grad_accum, offload_param, offload_optimizer):
     ds_config = {
         "train_batch_size": samples_per_gpu * world_size * grad_accum,
         "gradient_accumulation_steps": grad_accum,
         "train_micro_batch_size_per_gpu": samples_per_gpu,
         "steps_per_print": 1,
         "zero_optimization": {
             "stage": 2,
-            "offload_param": {"device": "none"},
-            "offload_optimizer": {"device": "none"},
+            "offload_param": {"device": offload_param},
+            "offload_optimizer": {"device": offload_optimizer},
         },
         "bf16": {"enabled": True},
         "gradient_clipping": 1.0,
@@ -71,8 +71,12 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
 
     model = convert_loss_to_reduce_sum(model)
     model.gradient_checkpointing_enable()
+
+    if args.deepspeed_optimizer == "fused":
+        optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.95))
+    else:
+        optimizer = DeepSpeedCPUAdam(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.95))
 
-    optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.95))
     lr_scheduler = get_scheduler(
         name="cosine",
         optimizer=optimizer,
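
The pairing above matters: FusedAdam runs entirely on GPU, while DeepSpeedCPUAdam is DeepSpeed's optimizer for state held in host memory, so --offload_optimizer cpu should be combined with --deepspeed_optimizer cpu. A hypothetical guard, not part of this commit, that would make the constraint explicit:

    # Hypothetical guard (not in this commit): keep the optimizer choice
    # consistent with the ZeRO offload target before building ds_config.
    if args.offload_optimizer == "cpu" and args.deepspeed_optimizer != "cpu":
        raise ValueError(
            "--offload_optimizer cpu moves optimizer state to host memory; "
            "use --deepspeed_optimizer cpu (DeepSpeedCPUAdam) instead of FusedAdam."
        )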
@@ -87,6 +91,8 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
             world_size=torch.distributed.get_world_size(),
             samples_per_gpu=args.samples_per_gpu,
             grad_accum=grad_accum,
+            offload_optimizer=args.offload_optimizer,
+            offload_param=args.offload_param
         ),
         lr_scheduler=lr_scheduler,
         dist_init_required=True,
@@ -270,6 +276,9 @@ def main(args):
     )
     parser.add_argument("--is_granite", action="store_true")
     parser.add_argument("--max_batch_len", type=int, default=60000)
+    parser.add_argument("--offload_optimizer", type=str, default="none", choices=["none", "cpu"])
+    parser.add_argument("--offload_param", type=str, default="none", choices=["none", "cpu"])
+    parser.add_argument("--deepspeed_optimizer", type=str, default="fused", choices=["fused", "cpu"])
     args = parser.parse_args()
     set_random_seed(args.seed)
     main(args)
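
Taken together, a hypothetical launch enabling full CPU offload could look like this (the launcher and the remaining required flags depend on the rest of main_ds.py and are assumed here):

    deepspeed main_ds.py --offload_optimizer cpu --offload_param cpu --deepspeed_optimizer cpu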
