
[REQUEST] Some questions about deepspeed sequence parallel #6708

Open
yingtongxiong opened this issue Nov 4, 2024 · 10 comments
Labels
enhancement New feature or request

Comments

@yingtongxiong

Hello, I want to run sequence parallelism with the pure DeepSpeed repo. However, I found that it seems the developer has to create the sequence parallel process group themselves; is that right? I would like to know whether there is any way to use sequence parallelism or MoE (which also requires expert_data_process_group and so on) with pure DeepSpeed.

yingtongxiong added the enhancement (New feature or request) label on Nov 4, 2024
@samadejacobs
Contributor

@yingtongxiong, the recommended way to use DeepSpeed sequence parallelism (DeepSpeed Ulysses) is to call it from a client framework/script. Please take a look at these two examples: Megatron-DeepSpeed, Hugging Face Transformers.

@yingtongxiong
Author

Okay, thank you very much.

@yingtongxiong
Author

https://github.com/microsoft/DeepSpeedExamples/blob/uly-hf/post_training/sequence_parallelism/test_ulysses.py#L113: I see that the mesh_param argument is commented out here, so I think that if I want to use SP this parameter needs to be passed in. Is that right? @samadejacobs

@yingtongxiong
Author

Also, when I use the SP all2all overlap, I found a small bug:

`assert ctx.stream != None`

Even when the stream is not None, the assertion still fails, so I think it should be `assert ctx.stream is not None`.
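
For clarity, this is the change I am suggesting (a minimal sketch, with the surrounding DeepSpeed code omitted):

```python
# current check: goes through __ne__ on the stream object
assert ctx.stream != None

# suggested check: identity comparison against None, as PEP 8 recommends
assert ctx.stream is not None
```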

@ronald-d-rogers

@yingtongxiong I think mesh_param is not needed as it is inferred from the config dict parameters (i.e. sequence_parallel_size/data_parallel_size).

@yingtongxiong
Author

> @yingtongxiong I think mesh_param is not needed as it is inferred from the config dict parameters (i.e. sequence_parallel_size/data_parallel_size).

Thank you. But how do I set sequence_parallel_size? In the config? How do I pass it to deepspeed.initialize? Could you maybe give me an example?

@ronald-d-rogers

ronald-d-rogers commented Nov 26, 2024

Sure, in the test it is here:
https://github.com/microsoft/DeepSpeedExamples/blob/uly-hf/post_training/sequence_parallelism/test_ulysses.py#L57

data_parallel_size is typically num_gpus // sequence_parallel_size, where num_gpus is the total number of GPUs (or the number you want to use).

Note that the number of processes you launch is still the number of GPUs you are using, i.e. data_parallel_size * sequence_parallel_size.

So, for example, if you're using torchrun to run the script, then you want --nproc-per-node set to 4 if you have 4 GPUs in total, whatever your arrangement of SP and DP is (so long as sp * dp = 4), as in the sketch below.
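
Here is a rough sketch of what I mean, assuming a 4-GPU node with sp = 2 and that your DeepSpeed version accepts the sequence_parallel_size / data_parallel_size config keys as in test_ulysses.py (double-check the exact fields against that script):

```python
import deepspeed

num_gpus = 4                      # total GPUs / processes you launch
sequence_parallel_size = 2        # hypothetical choice
data_parallel_size = num_gpus // sequence_parallel_size  # -> 2

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "sequence_parallel_size": sequence_parallel_size,
    "data_parallel_size": data_parallel_size,
    # ... optimizer / ZeRO settings as usual ...
}

# `model` is your torch.nn.Module
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
```

launched with something like `torchrun --nproc-per-node 4 train.py`, so that dp * sp = 4.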

Each data parallel group takes in a separate copy of the full training dataset, and each sequence parallel rank within a data parallel group needs to be given the slice of the data corresponding to that rank. You need to do this splitting yourself. You must also let the model know which part of the sequence it is processing by providing a position_ids tensor, i.e. the values 0 ... seq_len - 1, split according to rank (see the sketch below).
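
Something along these lines for the splitting (a sketch only; how you obtain the SP rank and world size depends on how you set up the process groups, so treat sp_rank / sp_world_size as placeholders):

```python
import torch

def shard_for_sp_rank(input_ids, sp_rank, sp_world_size):
    """Give each sequence-parallel rank its contiguous slice of the sequence,
    plus the matching position_ids so the model knows which positions it holds."""
    batch, seq_len = input_ids.shape
    assert seq_len % sp_world_size == 0, "seq_len must divide evenly across SP ranks"
    shard_len = seq_len // sp_world_size
    start = sp_rank * shard_len
    end = start + shard_len

    # full-range positions 0 .. seq_len-1, then the slice belonging to this rank
    position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
    return input_ids[:, start:end], position_ids[:, start:end].expand(batch, -1)
```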

AFAIK, the above test_ulysses.py example does not work without a PR to Transformers that has not yet been merged:
huggingface/transformers#32305

There is another attempt at integrating Ulysses, made by @zeyugao, that requires changes across Accelerate and Transformers:
huggingface/accelerate#2877
huggingface/transformers#31525

It seems likely that the changes to modeling_llama.py made in the above PR would also be needed.

The main question I have is whether we need to use this new loss if we're using Transformers, and whether that would entail updating all of the trainers.

@ronald-d-rogers

@yingtongxiong: I will try to put together a working example soon (if what I have is actually working)

@yingtongxiong
Author

Thank you very much

@yingtongxiong
Author

> @yingtongxiong: I will try to put together a working example soon (if what I have is actually working)

Okay, thank you.
