w/o model-parallel usability numbers reproduce #284

Open
g-karthik opened this issue Jul 6, 2020 · 31 comments
@g-karthik

I've been using DeepSpeed successfully for my large-model training jobs. But this blog post says ZeRO-1 and ZeRO-2 enable training of up to 6B and 13B parameter models, respectively, without model parallelism.

Where exactly is the code that validates this claim and enables others to reproduce it? I want to see what model was used, what the batch size was, what the max sequence length was, etc.

@samyam
Contributor

samyam commented Aug 24, 2020

@g-karthik, sorry for the late response. In case this is still relevant to you, these numbers are for a GPT-2-style language model, though you should be able to run any model of similar size using ZeRO-1 and ZeRO-2. The exact model configuration, batch size, GPU count, etc. are as follows:

[Table image: model configuration, batch size, and GPU count for each model size]

The sequence length is 1K. You can find this table in the appendix of our arXiv paper (https://arxiv.org/pdf/1910.02054.pdf), and more details in the paper itself.

To reproduce the results, you can use our Megatron-LM tutorial (https://www.deepspeed.ai/tutorials/megatron/). Once you have Megatron-LM running with DeepSpeed, set the model-parallelism degree to 1 to disable model parallelism, use the model configurations and GPU counts shown in the table above, and enable ZeRO-2 in the DeepSpeed config file.

Please note that to run a 13B parameter model with ZeRO-2, you would need around 128 GPUs with 32 GB of memory each to share the optimizer and gradient states. But you should be able to fit a 5-8B parameter model with as few as 16 GPUs.
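
For concreteness, here is a minimal sketch of that last step, i.e. a ZeRO-2 DeepSpeed config written out as JSON. This is not the exact config behind the paper numbers; the train_batch_size below is a placeholder and the real values should come from the table above. The resulting file is what you would pass to the Megatron-LM script via --deepspeed_config, together with --model-parallel-size 1:

import json

# Minimal ZeRO-2 config sketch (placeholder values, not the paper's exact settings).
ds_zero2_config = {
    "train_batch_size": 512,            # placeholder; take the real value from the table above
    "fp16": {"enabled": True},          # mixed-precision training
    "zero_optimization": {"stage": 2}   # partition optimizer states and gradients
}

with open("ds_zero2_config.json", "w") as f:
    json.dump(ds_zero2_config, f, indent=2)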

@g-karthik
Author

g-karthik commented Aug 27, 2020

Hi @samyam thanks for the detailed info!

I did try to use the tutorial you linked earlier, and it seemed to work in some cases and failed in others.

Back then, I faced this error in those other cases (pulled from my logs): FileExistsError: [Errno 17] File exists: '/path/to/DeepSpeedExamples/Megatron-LM/data/webtext/openwebtext.lazy'

Note that I did not create a symbolic link to the data as mentioned in the tutorial you linked - I put the actual data (openwebtext.json) there.

Do you know why this is happening?

UPDATE:

I tried running it again just now and got the same FileExistsError (it appears multiple times in the logs), but after the error it shows:

FileExistsError: [Errno 17] File exists: '/path/to/DeepSpeedExamples/Megatron-LM/data/webtext/openwebtext.lazy'

> padded vocab (size: 50257) with 175 dummy tokens (new size: 50432)
> found end-of-document token: 50256
building GPT2 model ...
 > number of parameters on model parallel rank 1: 329328640
 > number of parameters on model parallel rank 0: 329328640

Traceback (most recent call last):
  File "pretrain_gpt2.py", line 707, in <module>
    main()
  File "pretrain_gpt2.py", line 652, in main
    args.eod_token = get_train_val_test_data(args)
  File "pretrain_gpt2.py", line 619, in get_train_val_test_data
    group=mpu.get_model_parallel_group())
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/distributed/distributed_c10d.py", line 810, in broadcast
    work = group.broadcast([tensor], opts)
RuntimeError: Socket Timeout

(The same traceback is printed by each of the failing ranks.)

@tjruwase
Contributor

Hi @g-karthik.

I am surprised by the FileExistsError. The expected path for webtext data is data/webtext/data.json. Did you modify the code to match your data path of data/webtext/openwebtext.lazy?

Also, can you share the model-parallel (MP) and data-parallel (DP) degrees of your run? Are you able to run with MP=DP=1, which should fit your .3B parameter model?

@g-karthik
Author

Hi @tjruwase, the .lazy file wasn't created by me; it was created automatically when I ran the code.

And yeah, I'd modified the expected path to the webtext data in that line you pointed to so that it contains the path to my openwebtext.json file. I also changed the text_key from "text" to "content" in that same class, since my JSON stores the text under that key. Other than that, I didn't make any code changes.

Yeah, the MP degree is 2 and the DP degree is 8 when I get this FileExistsError followed by the Socket Timeout. See the command below:

GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

python -m torch.distributed.launch $DISTRIBUTED_ARGS \
       pretrain_gpt2.py \
       --model-parallel-size 2 \
       --num-layers 48 \
       --hidden-size 1024 \
       --num-attention-heads 32 \
       --batch-size 8 \
       --seq-length 1024 \
       --max-position-embeddings 1024 \
       --train-iters 320000 \
       --save checkpoints/gpt2_xl_mp2 \
       --load checkpoints/gpt2_xl_mp2 \
       --resume-dataloader \
       --train-data webtext \
       --lazy-loader \
       --tokenizer-type GPT2BPETokenizer \
       --split 949,50,1 \
       --distributed-backend nccl \
       --lr 0.00015 \
       --no-load-optim \
       --lr-decay-style cosine \
       --weight-decay 1e-2 \
       --clip-grad 1.0 \
       --warmup .01 \
       --checkpoint-activations \
       --fp16 \
       --deepspeed \
       --deepspeed_config scripts/ds_zero2_config.json

I haven't tried the MP=DP=1 case with this script yet. I just kicked off the above once again and got the FileExistsError again, but the processes continue to run. I'm waiting to see if I get the Socket Timeout again (I suspect I will).

@tjruwase
Contributor

@g-karthik thanks for sharing those details. Another question: why are you using python to launch rather than the deepspeed launcher?

@g-karthik
Author

@tjruwase oh, that's just the contents of the pretrain_gpt2_model_parallel.sh shell script, with my minor modifications to the parameters.

I do use the deepspeed launcher with my integration of DeepSpeed into my own codebase outside of the DeepSpeedExamples repo. As a side note, though, I prefer to use python with --deepspeed and --deepspeed_mpi so I can scale up my jobs with mpirun on Kubernetes.

@g-karthik
Author

@tjruwase just an update on the MP=1, DP=1 case you asked me to run: I set model-parallel-size=1 and GPUS_PER_NODE=1. It seems to be running fine; it hasn't thrown a FileExistsError or a Socket Timeout error.

validation loss at iteration 1000 | LM loss: 6.297972E+00 | LM PPL: 5.434685E+02

@tjruwase
Contributor

@g-karthik thanks for sharing that update. So it seems we can assume the problem is related to distributed training. Can you enable NCCL debugging information for the distributed case? I believe you can set the NCCL_DEBUG environment variable like:
NCCL_DEBUG=INFO

Also, I suggest using the deepspeed launcher, because it handles the NCCL configuration issues for distributed training. It would be great to find out whether it fixes the Socket Timeout error.

Finally, for the distributed case I am curious whether all or just some of the processes report the FileExistsError. I see that this filename is created here. I have not looked closely at this code before, but I suspect the file is meant to be created only once per node.

@g-karthik
Author

g-karthik commented Dec 1, 2020

@tjruwase @samyam

Hope you're doing well! I'm having some trouble reproducing the 1K sequence length setting with Hugging Face's GPT-2 class.

Specifically, I'm using 8 V100 GPUs (32 GB each) and the following deepspeed configuration:

{
  "train_batch_size": 128,
  "gradient_accumulation_steps": 8,
  "gradient_clipping": 1.0,
  "optimizer": {
    "type": "adam",
    "params": {
      "lr": 6.25e-5
    }
  },
  "fp16": {
    "enabled": true
  },
  "zero_optimization": {
    "stage": 2,
    "cpu_offload": true,
    "contiguous_gradients": true,
    "overlap_comm": false,
    "allgather_bucket_size": 500000000
  },

  "activation_checkpointing": {
    "partition_activations": true,
    "contiguous_memory_optimization": true,
    "cpu_checkpointing": true
  }

}

I'm kicking these jobs off with python using --deepspeed_mpi and mpi4py instead of the regular deepspeed launcher. I'm consistently getting CUDA OOM at a sequence length of 1K, even with only 1.5B parameters. I could certainly scale up my cluster and increase the effective train batch size to eliminate these errors (since per-GPU memory consumption with ZeRO falls as the DP degree grows), but it seems like that shouldn't be needed for 1.5B. I also understand there are implementation differences between the Megatron GPT-2 class and the Hugging Face GPT-2 class, but I've trained these models with sequence lengths slightly shorter than 1K before. Any ideas on what I am missing?
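
For reference, here is a back-of-the-envelope sketch (following the partitioning scheme described in the ZeRO paper) of per-GPU model-state memory under ZeRO-2: fp16 parameters are replicated, while fp16 gradients and fp32 optimizer states are partitioned over the DP group, and moved to host memory when cpu_offload is on. Activations, buffers, and fragmentation are not counted, so this is only a lower bound on real usage:

def zero2_model_state_gib(params_in_billions, dp_degree, cpu_offload=False):
    """Rough per-GPU model-state memory (GiB) under ZeRO-2 with mixed-precision Adam."""
    psi = params_in_billions * 1e9
    fp16_params = 2 * psi                    # replicated on every GPU under stage 2
    fp16_grads = 2 * psi / dp_degree         # partitioned across the DP group
    optimizer_states = 12 * psi / dp_degree  # fp32 params + Adam momentum/variance, partitioned
    if cpu_offload:
        # ZeRO-Offload moves gradients and optimizer states to host memory.
        fp16_grads = 0
        optimizer_states = 0
    return (fp16_params + fp16_grads + optimizer_states) / 2**30

print(zero2_model_state_gib(1.5, dp_degree=8))                    # ~5.2 GiB
print(zero2_model_state_gib(1.5, dp_degree=8, cpu_offload=True))  # ~2.8 GiB

Both numbers are far below 32 GB, which suggests the 1K-sequence-length OOM is dominated by activation memory rather than model states.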

@g-karthik
Author

@tjruwase @jeffra @samyam @ShadenSmith also note that with the above config, I'm able to train with a sequence length of 512 just fine.

I vaguely remember seeing somewhere in the DeepSpeed docs (or the paper) that you use 16 attention heads for 1.5B, whereas Hugging Face's version uses 25 attention heads. Assuming my memory is correct, I wonder whether your 16-attention-head version was really a 1.5B model, or whether you compensated by increasing another dimension such as the hidden size.
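
As a rough sanity check on the sizes involved, here is the standard GPT-2 parameter-count estimate (about 12 * n_layer * n_embd^2 for the transformer blocks plus the embedding tables); note that the head count only determines how n_embd is split across heads, so by itself it does not change the total parameter count:

def gpt2_param_count(n_layer, n_embd, vocab_size=50257, n_positions=1024):
    """Approximate GPT-2 parameter count; the number of attention heads does not enter."""
    per_block = 12 * n_embd ** 2 + 13 * n_embd   # attention + MLP weights/biases, layer norms
    embeddings = (vocab_size + n_positions) * n_embd
    return n_layer * per_block + embeddings

print(gpt2_param_count(48, 1600) / 1e9)  # HF gpt2-xl (25 heads): ~1.56B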

@tjruwase
Contributor

tjruwase commented Dec 2, 2020

@g-karthik, I think you should be able to train a 1.5B model easily on a single 32GB GPU with cpu-offload. Can you please share a link to the GPT-2 script? Are you passing an optimizer into deepspeed.initialize()? I am also interested in repro'ing the OOM.

@g-karthik
Author

@tjruwase Yeah I'm using this implementation of the GPT2LMHeadModel by Hugging Face: https://github.com/huggingface/transformers/blob/v3.1.0/src/transformers/modeling_gpt2.py#L652

No, I'm not passing an optimizer into deepspeed.initialize().

Although technically my deepspeed_config.json above has "activation_checkpointing" set up, it would go unused if using the above GPT2LMHeadModel as-is. I've now configured deepspeed.checkpointing and added explicit deepspeed.checkpointing.checkpoint() calls within the underlying base Transformer decoder, but I haven't tested it out yet. However, would you say that a 1K sequence length should fit easily for a 1.5B configuration even without DeepSpeed's activation checkpointing set up?

@g-karthik
Author

g-karthik commented Mar 2, 2021

@tjruwase @jeffra @samyam

I am unable to reproduce the claim in Section 10.4 of the ZeRO paper: "Fig. 4 shows that ZeRO-100B can train models with up to 13B parameters without MP on 128 GPUs, achieving throughput over 40 TFlops per GPU on average."

I am using Hugging Face's GPT2LMHeadModel with an 8 billion parameter configuration and when I run the new Flops Profiler provided with DeepSpeed, I hit about 2.7 TFLOPS (which I presume is per GPU) as printed by the profiler. I just use 8 GPUs total.

I am also unable to fit a 1024 sequence length in my micro-batches, due to apparent deficiencies in DeepSpeed's activation checkpointing; see the relevant discussions in #598 (comment) and #598 (comment). It would be great if these could be addressed in a subsequent PR.

Could you please help identify how I could hit 40 TFLOPS per GPU with the Hugging Face implementation, just as you seem to have hit that with the Megatron implementation with MP=1?

@g-karthik
Author

@ShadenSmith @tjruwase @jeffra have you had a chance to look at the above?

@g-karthik
Author

@ShadenSmith @tjruwase @jeffra @samyam Hey guys, I'd really appreciate a response on the specific points I raised above, because it seems that DeepSpeed CANNOT achieve the high scaling efficiency claimed in the paper on a standard V100 cluster with high RDMA bandwidth unless these are addressed.

I would greatly appreciate a quick response on this.

@tjruwase
Contributor

@g-karthik, I really apologize for the radio silence on this issue. Thanks for your patience. I am working on this now, trying to recover the configuration settings for the results. I will update you asap.

@tjruwase
Contributor

tjruwase commented Mar 26, 2021

I refreshed myself a bit on the numbers, so perhaps I can now be of help. It sounds like you want to repro the 8B number of 46 TFLOPS/GPU. The configuration for that result on 128 GPUs is as follows:

GPT-2 config: hidden=3072, layers=72, attention-heads=24, batch-size=8

DS config:

{
  "zero_optimization": {
    "stage": 2,
    "contiguous_gradients": true,
    "reduce_scatter": false,
    "reduce_bucket_size": 1000000000,
    "allgather_bucket_size": 200000000
  },
  "activation_checkpointing": {
    "partition_activations": false,
    "number_checkpoints": 1,
    "contiguous_memory_optimization": true,
    "cpu_checkpointing": false,
    "profile": true,
    "synchronize_checkpoint_boundary": true
  }
}

Please let us know how it goes.

@g-karthik
Author

@tjruwase thank you so much for your response!

As I mention above, I'm not using Megatron-style model parallelism, and the numbers I'm seeing are for Hugging Face-style GPT-2 models. What surprises me is the vast difference between your numbers and mine, when all you're doing is setting MP=1 (i.e., degree of model parallelism = 1).

Do you have any thoughts on the specific points and references I report above in #284 (comment)?

Given that y'all are supporting Hugging Face-style models via a direct integration into the HF Trainer, I think it's critical that those models are able to achieve 40+ TFLOPS/GPU when one tries to scale the number of parameters up.

This is the config for a GPT-2 XL model on Hugging Face: you can instantiate it by doing config = GPT2Config.from_pretrained("gpt2-xl").

I created an equivalent config JSON file for an 8B parameter version of GPT-2, by setting:

  "n_embd": 3584,
  "n_head": 32,
  "n_layer": 50,

with the remaining config keys exactly the same as those for the XL model I linked above.

Then, I passed the path to that new JSON as follows: config = GPT2Config.from_pretrained("/path/to/my/config.json").

And I tried to train that version of a GPT2LMHeadModel by doing model = GPT2LMHeadModel(config).

This is where I see terrible TFLOPS/GPU. Can you please try reproducing this?
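
For reference, a condensed sketch of the setup just described (here the config is modified in code instead of via a separate JSON file, which should be equivalent; only n_embd, n_head and n_layer differ from the stock gpt2-xl config):

from transformers import GPT2Config, GPT2LMHeadModel

# Start from the stock gpt2-xl config and override the three fields quoted above.
config = GPT2Config.from_pretrained("gpt2-xl")
config.n_embd = 3584   # must stay divisible by n_head (3584 / 32 = 112)
config.n_head = 32
config.n_layer = 50

model = GPT2LMHeadModel(config)
print(sum(p.numel() for p in model.parameters()) / 1e9, "billion parameters")  # ~7.9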

@jeffra
Collaborator

jeffra commented Mar 26, 2021

Hi @g-karthik, so sorry about our late replies and the trouble you're having here :( This is a tricky one. One issue is that we haven't done much performance analysis on the Hugging Face implementation of GPT-2. In the medium term we would really love to dive into this deeper, as there are lots of folks using this implementation who could really benefit from various DeepSpeed features.

In terms of the performance issues you're seeing, I suspect there could be non-DeepSpeed-related differences contributing to the low perf. I see that you were trying to use the Megatron version at one point earlier in this thread; did you ever get that running, or was data the main issue? If data is an issue, I would be happy to send you the small test dataset we use for performance tuning. If so, please email me or DM me on Twitter (see email/twitter in my GitHub profile).

@g-karthik
Author

@jeffra @tjruwase I don't think the recent observations I report above have anything to do with non-DeepSpeed issues or even data issues. In fact, as you can see in #284 (comment), I link to detailed observations and discussions regarding apparent deficiencies in DeepSpeed's activation checkpointing when attempting to use it with no model parallelism.

Yes, I originally tried to use the Megatron version for reproducibility, only because y'all used it to report your no-model-parallel (MP=1) numbers. Seeing as I was getting socket timeouts and file exists errors when using a world size > 1, I eventually just gave up on the version of the code in the DeepSpeedExamples repo. I use Hugging Face implementations now because those implementations are standard and highly used, so I think it is very important that any claims of high TFLOPS/GPU with DeepSpeed stages enabled are validated with these implementations.

@g-karthik
Author

@tjruwase @jeffra I think it should be fairly straightforward to find out what kind of performance (TFLOPS/GPU and all-reduce times) you're seeing with larger sized Hugging Face models. You could even just use T5-11B with the HF Trainer and DeepSpeed, with logging for wall-clock breakdown and the flops profiler enabled.

ccing @stas00 since I believe he has done some testing with T5-11B.

@stas00
Collaborator

stas00 commented Mar 30, 2021

I'd be totally happy to do the performance analysis for HF - if you could just give me some guidelines on how to do it.

And yes, I did run experiments with T5-11B under DS/HF.

But let's finish the ZeRO-3 integration first, and then count me in!

@tjruwase
Contributor

> Yes, I originally tried to use the Megatron version for reproducibility, only because y'all used it to report your no-model-parallel (MP=1) numbers. Seeing as I was getting socket timeouts and file exists errors when using a world size > 1, I eventually just gave up on the version of the code in the DeepSpeedExamples repo. I use Hugging Face implementations now because those implementations are standard and highly used, so I think it is very important that any claims of high TFLOPS/GPU with DeepSpeed stages enabled are validated with these implementations.

@g-karthik, I just wanted to chime in with my agreement on the importance of DeepSpeed maintaining high TFLOPS/GPU on HF models in general. I also think that both GPT-2 and T5-11B are exciting starting points once the ZeRO integration is completed, thanks to the amazing work of @stas00.

@g-karthik
Author

g-karthik commented Mar 31, 2021

@stas00 thanks for your help! When running T5-11B, just enable the flops profiler and wall-clock breakdown in your DS config JSON, as described here: https://www.deepspeed.ai/features/#performance-analysis-and-debugging. You can also check out #284 (comment) to see how I tested this for GPT-2 with a different GPT2Config.
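
For concreteness, this is roughly what that amounts to (key names follow the DeepSpeed performance-analysis docs linked above; the profile_step value is just an illustrative choice). Merge these keys into the existing DS config dict/JSON:

# Profiling knobs to merge into the existing DeepSpeed config (illustrative values).
profiling_keys = {
    "wall_clock_breakdown": True,   # per-step timing of forward/backward/optimizer step
    "flops_profiler": {
        "enabled": True,
        "profile_step": 5,          # which training step to profile
        "detailed": True
    }
}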

ZeRO-3 increases communication volume by 1.5x, so if ZeRO-2 itself is not giving good TFLOPS/GPU for HF models, then ZeRO-3 likely won't either.

@tjruwase @jeffra can you please take a look at #598 (comment) and #598 (comment) and let me know what I am missing?

It is my understanding that partition_activations, cpu_checkpointing and contiguous_memory_optimization are being supported for regular non-model-parallel modules with PyTorch Lightning - I am confused why these require model-parallelism (MP > 1).

@tjruwase
Contributor

> @tjruwase @jeffra can you please take a look at #598 (comment) and #598 (comment) and let me know what I am missing?
>
> It is my understanding that partition_activations, cpu_checkpointing and contiguous_memory_optimization are being supported for regular non-model-parallel modules with PyTorch Lightning - I am confused why these require model-parallelism (MP > 1).

@g-karthik, regarding this issue, quite a number of improvements have gone into activation checkpointing support. Can you please check which of your concerns remain?

@g-karthik
Author

@tjruwase I just installed the latest version of DeepSpeed, and now I see that activation checkpointing can no longer be configured prior to distributed initialization. So I'm unable to test those flags now.

Specifically, @samyam's addition of this logger.info() in the ZeRO-3 release last month is failing: https://github.com/microsoft/DeepSpeed/blame/master/deepspeed/runtime/activation_checkpointing/checkpointing.py#L762

And understandably so, because the previous approach to configuring DeepSpeed's activation checkpointing was to configure it PRIOR to deepspeed.initialize() (which internally initializes torch.distributed). But now we can no longer configure it prior to deepspeed.initialize().

Can you please fix this?

@g-karthik
Author

Also @stas00 are you free to do some profiling with HF models now?

@stas00
Collaborator

stas00 commented Apr 9, 2021

At the moment I don't have any resources to dedicate to this, as I'm trying to figure out why bfloat16-pretrained models don't do well on DeepSpeed or any other mixed-precision platform.

But when I'm done, absolutely!

But before we do that, we need to merge #910, and also #873, though that one currently breaks the transformers tests.

I suggest making the next DeepSpeed release a performance-improvement release.

@g-karthik
Author

> @tjruwase @jeffra can you please take a look at #598 (comment) and #598 (comment) and let me know what I am missing?
>
> It is my understanding that partition_activations, cpu_checkpointing and contiguous_memory_optimization are being supported for regular non-model-parallel modules with PyTorch Lightning - I am confused why these require model-parallelism (MP > 1).

> @g-karthik, regarding this issue, quite a number of improvements have gone into activation checkpointing support. Can you please check which of your concerns remain?

@tjruwase @jeffra @ShadenSmith this issue still remains. I ran two jobs with a Hugging Face GPT-2 model with DeepSpeed's activation checkpointing:

  1. partition_activations, cpu_checkpointing and contiguous_memory_optimization are set to False
  2. partition_activations, cpu_checkpointing and contiguous_memory_optimization are set to True

The embedding dimension I used was "n_embd": 3584.

Job 1 ran fine. Job 2 threw this error:

    loss = model_engine.backward(loss)
  File "/usr/local/lib/python3.6/dist-packages/deepspeed/runtime/engine.py", line 997, in backward
    self.optimizer.backward(loss)
  File "/usr/local/lib/python3.6/dist-packages/deepspeed/runtime/zero/stage2.py", line 1636, in backward
    self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
  File "/usr/local/lib/python3.6/dist-packages/deepspeed/runtime/fp16/loss_scaler.py", line 53, in backward
    scaled_loss.backward(retain_graph=retain_graph)
  File "/usr/local/lib/python3.6/dist-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
  File "/usr/local/lib/python3.6/dist-packages/torch/autograd/function.py", line 77, in apply
    return self._forward_cls.backward(self, *args)
  File "/usr/local/lib/python3.6/dist-packages/deepspeed/runtime/activation_checkpointing/checkpointing.py", line 660, in backward
    outputs = ctx.run_function(*detached_inputs)
  File "/modeling_gpt2.py", line 212, in custom_forward
    return tuple(output for output in module(*inputs, use_cache, output_attentions))
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_gpt2.py", line 287, in forward
    self.ln_1(hidden_states),
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/normalization.py", line 153, in forward
    input, self.normalized_shape, self.weight, self.bias, self.eps)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py", line 1696, in layer_norm
    torch.backends.cudnn.enabled)
RuntimeError: Given normalized_shape=[3584], expected input with shape [*, 3584], but got input of size[3670016]

Can you please help with this asap? I would greatly appreciate detailed and specific answers/explanations for #598 (comment) and #598 (comment).

Also, how are PyTorch Lightning and DeepSpeedExamples able to use these flags without model-parallelism (MP=1)? Does ZeRO-3 need to be enabled to get this to work?

@tjruwase
Contributor

@g-karthik, can you please describe how to repro this failure? I could not repro it with the default HF GPT-2 settings. It ran without problems for me with:

{
  "zero_optimization": {
    "stage": 2
  },
  "fp16": {
    "enabled": true,
    "initial_scale_power": 10
  },
  "activation_checkpointing": {
    "partition_activations": true,
    "contiguous_memory_optimization": true,
    "cpu_checkpointing": true,
    "synchronize_checkpoint_boundary": true
  }
}

How did you change n_embd?

@g-karthik
Author

@tjruwase I suspect you tried to set those specific activation checkpointing args to true without actually changing the model class. For you to be able to repro this, you need to change the underlying model class' checkpointing from torch's checkpointing to deepspeed's checkpointing here.

Change that to deepspeed.checkpointing.checkpoint. And also make sure you run:
deepspeed.checkpointing.configure(mpu_=None, deepspeed_config=args.deepspeed_config) before instantiating the model class.
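
A minimal sketch of those two steps, assuming a plain single-GPU training script where the DeepSpeed config path is available (the DS_CONFIG_PATH name below is an illustrative placeholder, not the exact repro code):

import deepspeed
from transformers import GPT2Config, GPT2LMHeadModel

DS_CONFIG_PATH = "ds_config.json"  # assumed path to the DeepSpeed config JSON

# Step 2: configure DeepSpeed activation checkpointing before instantiating the model.
deepspeed.checkpointing.configure(mpu_=None, deepspeed_config=DS_CONFIG_PATH)

# Step 1: instantiate the (modified) model; inside modeling_gpt2.py's forward loop,
# the torch checkpoint call
#   torch.utils.checkpoint.checkpoint(custom_forward, hidden_states, ...)
# has been replaced by DeepSpeed's drop-in equivalent:
#   deepspeed.checkpointing.checkpoint(custom_forward, hidden_states, ...)
model = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2-xl"))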

About changing model configurations such as n_embd, I did that by creating a separate config JSON off this and using that JSON to instantiate the GPT2Config class. However, I do not think that would be necessary for you to repro this issue. You could easily repro this with an existing GPT-2 config like XL.
