Support device mapping for Paged Attention #1011

Merged: 55 commits, Jan 1, 2025

Conversation

@cdoko (Contributor) commented Dec 28, 2024

Added support for device mapping in Paged Attention by passing the device list from the mapper to the cache engine. Manual moving was required for certain unmapped tensors in the paged attention forward pass. I have tested the device mapping support on several text models and it appears to be functional.
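
For concreteness, here is a minimal sketch (using candle's Tensor API, with a hypothetical helper name, not the actual mistral.rs code) of the kind of manual move this refers to: metadata tensors created by the inputs processor sit on the first device, so they are copied onto the query's device inside the paged-attention forward pass.

```rust
use candle_core::{Result, Tensor};

// Hypothetical helper: move a metadata tensor (e.g. the slot mapping) onto the
// same device as q/k/v for the current layer.
fn align_to_query_device(q: &Tensor, slot_mapping: &Tensor) -> Result<Tensor> {
    if slot_mapping.device().same_device(q.device()) {
        // Already co-located: a cheap clone, no cross-device traffic.
        Ok(slot_mapping.clone())
    } else {
        // Cross-device copy; this is the per-layer overhead discussed below.
        slot_mapping.to_device(q.device())
    }
}
```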

Memory allocation is currently calculated as if all memory is available on a single device, resulting in that memory being split across devices; ideally, we should calculate available memory per GPU. If this PR is fine, I will work on this next.

Additionally, this feature currently only supports GPU devices due to an error when attempting to use reshape_and_cache() with non-CUDA tensors.

Please let me know if you'd like me to revise anything!

github-actions bot commented Dec 28, 2024

Code Metrics Report
===============================================================================
 Language            Files        Lines         Code     Comments       Blanks
===============================================================================
 C Header                2           35           28            0            7
 Dockerfile              1           41           22           10            9
 JSON                   12          105          104            0            1
 Python                 63         2706         2338           71          297
 Shell                   1           57           22           18           17
 Plain Text              3         3723            0         2413         1310
 TOML                   18          605          539            2           64
 YAML                    2           21           19            2            0
-------------------------------------------------------------------------------
 Jupyter Notebooks       4            0            0            0            0
 |- Markdown             2           77           32           31           14
 |- Python               2          205          178            1           26
 (Total)                            282          210           32           40
-------------------------------------------------------------------------------
 Markdown               43         3333            0         2526          807
 |- BASH                 6          103          100            0            3
 |- JSON                 1           12           12            0            0
 |- Python               7          121          109            0           12
 |- Rust                12          406          344            0           62
 |- TOML                 2           75           63            0           12
 (Total)                           4050          628         2526          896
-------------------------------------------------------------------------------
 Rust                  296        89600        80403         1861         7336
 |- Markdown           143         1593           25         1448          120
 (Total)                          91193        80428         3309         7456
===============================================================================
 Total                 445       100226        83475         6903         9848
===============================================================================
  

@cdoko (Contributor, Author) commented Dec 29, 2024

Currently, device-mapped paged attention is approximately 10% slower compared to single device paged attention. I found the slowdown is at least partially due to the overhead of moving tensors to the device on every layer forward pass.

I implemented a temporary workaround in the ModelWeights::forward method by creating copies of the tensors, one for each device, before iterating over the layers, and passing the correct tensor to each layer. This reduced the slowdown to around 5% compared to single device paged attention.
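
As a sketch of that workaround (names like `precopy_per_device` are illustrative, not the actual ModelWeights code, and this assumes candle's DeviceLocation can serve as a hash key), the shared tensor is copied at most once per distinct device before the layer loop, and each layer then looks up the copy on its own device:

```rust
use std::collections::HashMap;

use candle_core::{Device, DeviceLocation, Result, Tensor};

// Illustrative only: build at most one copy of `t` per distinct device, so the
// per-layer cost becomes a HashMap lookup instead of a device transfer.
fn precopy_per_device(
    t: &Tensor,
    layer_devices: &[Device],
) -> Result<HashMap<DeviceLocation, Tensor>> {
    let mut copies: HashMap<DeviceLocation, Tensor> = HashMap::new();
    for dev in layer_devices {
        let loc = dev.location();
        if !copies.contains_key(&loc) {
            copies.insert(loc, t.to_device(dev)?);
        }
    }
    Ok(copies)
}

// In the layer loop: let t_for_layer = &copies[&layer_device.location()];
```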

Ideally, we would avoid making copies of the tensors in the model's forward and instead perform this operation in the inputs_processor. However, I have not yet found a way to pass the layer_devices information there.

@EricLBuehler (Owner) left a comment

Hi @cdoko! Thanks for the PR.

> Currently, device-mapped paged attention is approximately 10% slower compared to single device paged attention. I found the slowdown is at least partially due to the overhead of moving tensors to the device on every layer forward pass.

With the comments I made, hopefully this should be addressed.

> However, I have not yet found a way to pass the layer_devices information there.

Please feel free to make whatever changes you find necessary to get this to work!

I think a nice way to do this would be to add an API to the mapper (device_map.rs) to extract all the devices which will be mapped to (including normal_loading_metadata.real_device). One of my comments suggests a way to utilize this information, and I think this method would work nicely with that. What do you think?
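
A rough sketch of what such a mapper API could look like (the function name and signature are illustrative, not the actual device_map.rs interface): collect every device that layers are mapped to, plus the real device, without duplicates.

```rust
use candle_core::Device;

// Illustrative helper: all unique target devices for a device-mapped model,
// including the pipeline's `real_device`.
pub fn unique_mapped_devices(per_layer: &[Device], real_device: &Device) -> Vec<Device> {
    let mut out = vec![real_device.clone()];
    for dev in per_layer {
        if !out.iter().any(|d| d.same_device(dev)) {
            out.push(dev.clone());
        }
    }
    out
}
```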

mistralrs-core/src/pipeline/gguf.rs (review comment, outdated, resolved)
mistralrs-core/src/pipeline/gguf.rs (review comment, outdated, resolved)
@@ -70,6 +70,34 @@ impl PagedAttention {
input_metadata.slot_mappings.clone()
};

// When device mapping, these Tensors are fixed on the first device, and must be moved to the same device as q,k,v
@EricLBuehler (Owner) commented:

I see you mentioned a performance penalty:

> Currently, device-mapped paged attention is approximately 10% slower compared to single device paged attention. I found the slowdown is at least partially due to the overhead of moving tensors to the device on every layer forward pass.

To avoid this, can you please update PagedAttentionInputMetadata to store all tensors as hashmaps of device location to the actual tensor? Do you think this is a good solution?

This takes up more memory on each GPU but requires only one copy (in the inputs processor) and enables us to remove this section. I'm thinking something similar to this where we create multiple RoPE instantiations on different devices.
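
A minimal sketch of that shape (field and type names are illustrative stand-ins, assuming candle's DeviceLocation works as a hash key): the inputs processor materializes one tensor per mapped device, and each layer fetches the copy that already lives on its own device.

```rust
use std::collections::HashMap;

use candle_core::{Device, DeviceLocation, Tensor};

// Illustrative stand-in for the suggested PagedAttentionInputMetadata layout.
pub struct PagedMetadataSketch {
    // One slot-mapping tensor per device the model is mapped across.
    pub slot_mappings: HashMap<DeviceLocation, Tensor>,
}

impl PagedMetadataSketch {
    // A layer picks the tensor already on its device; no copy happens here.
    pub fn slot_mapping_for(&self, device: &Device) -> Option<&Tensor> {
        self.slot_mappings.get(&device.location())
    }
}
```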

Additionally, I just merged #1014. Can you please merge with master to get these new changes? Otherwise, a conflict will occur with the changes I requested above.

@@ -284,6 +284,7 @@ impl Attention {
v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?
};

let start_offsets_kernel = start_offsets_kernel.to_device(q.device())?;
@EricLBuehler (Owner) commented:

Can we merge these into RotaryEmbedding in layers.rs?
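
As a hedged sketch of that idea (the wrapper type is illustrative; the real change would live in mistralrs' RotaryEmbedding in layers.rs): the layer itself ensures the offsets tensor is on the query's device, so call sites no longer need the explicit to_device shown in the diff above.

```rust
use candle_core::{Result, Tensor};

// Illustrative wrapper only, not the actual layers.rs API.
pub struct RopeInputs {
    pub start_offsets_kernel: Tensor,
}

impl RopeInputs {
    // Return the offsets on q's device, copying only when the devices differ.
    pub fn offsets_on_query_device(&self, q: &Tensor) -> Result<Tensor> {
        if self.start_offsets_kernel.device().same_device(q.device()) {
            Ok(self.start_offsets_kernel.clone())
        } else {
            self.start_offsets_kernel.to_device(q.device())
        }
    }
}
```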

@cdoko marked this pull request as draft on December 31, 2024, 08:22
@cdoko (Contributor, Author) commented Dec 31, 2024

I updated PagedAttentionInputMetadata to store tensors as hashmaps of device location to the actual tensor. The performance penalty has decreased, but a small penalty of a few percent remains. I also tested a similar optimization for the attention mask in the model forward pass, since it is copied every layer as well, but it did not give any further improvement; I suspect the remaining cost comes from the unavoidable mapping of xs in the layer forwards. Tensor parallelism might help mitigate the remaining penalty and potentially provide a performance gain.

To implement the updated PagedAttentionInputMetadata, I made some changes to the pipeline trait. Since the struct now uses hashmaps to store tensors, we need to have the mapper available wherever the struct is instantiated. I added the mapper to the MetadataMixin trait, allowing it to be passed from the pipeline to the input processor, and then used to create the PagedAttentionInputMetadata instance. I also noted that most pipelines don't use paged attention, so I used an Option to represent the mapper and passed None for those pipelines.
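
As an illustration of that pattern (the trait and method names are hypothetical, not the real MetadataMixin definition), pipelines expose an optional mapper-derived piece of device information, and pipelines without paged attention simply return None:

```rust
use candle_core::Device;

// Hypothetical trait, sketching the Option-based design described above.
pub trait PagedDeviceInfo {
    /// `None` for pipelines that do not use paged attention.
    fn mapped_devices(&self) -> Option<Vec<Device>>;
}

struct PlainPipeline;

impl PagedDeviceInfo for PlainPipeline {
    fn mapped_devices(&self) -> Option<Vec<Device>> {
        None // this pipeline does not use paged attention
    }
}
```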

As a side note, I noticed that Qwen2's and Quantized Llama's RotaryEmbedding use candle directly. Was this intentional, or should they be using the mistralrs one instead, which now includes the Tensor::to_device addition? Other models use the mistralrs one.

@cdoko marked this pull request as ready for review on December 31, 2024, 10:12
@EricLBuehler (Owner) commented:

Thanks for the updates. I think this is close to merge.

> I updated PagedAttentionInputMetadata to store tensors as hashmaps of device location to the actual tensor. The performance penalty has decreased, but a small penalty of a few percent remains. I also tested a similar optimization for the attention mask in the model forward pass, since it is copied every layer as well, but it did not give any further improvement; I suspect the remaining cost comes from the unavoidable mapping of xs in the layer forwards. Tensor parallelism might help mitigate the remaining penalty and potentially provide a performance gain.

Sounds great! I agree, TP is something we should look into. I think the hard part is integrating it nicely with the existing codebase.

> To implement the updated PagedAttentionInputMetadata, I made some changes to the pipeline trait. Since the struct now uses hashmaps to store tensors, we need to have the mapper available wherever the struct is instantiated. I added the mapper to the MetadataMixin trait, allowing it to be passed from the pipeline to the input processor, and then used to create the PagedAttentionInputMetadata instance. I also noted that most pipelines don't use paged attention, so I used an Option to represent the mapper and passed None for those pipelines.

Sounds good.

> As a side note, I noticed that Qwen2's and Quantized Llama's RotaryEmbedding use candle directly. Was this intentional, or should they be using the mistralrs one instead, which now includes the Tensor::to_device addition? Other models use the mistralrs one.

Yes, can you please update it to use the ones in mistralrs?

@cdoko (Contributor, Author) commented Dec 31, 2024

> Yes, can you please update it to use the ones in mistralrs?

I already did, just wanted to confirm.

> Sounds great! I agree, TP is something we should look into. I think the hard part is integrating it nicely with the existing codebase.

Personally I'm interested in speculative decoding for the much higher T/s. I took a look at the SpeculativePipeline implementation some time ago. From what I recall, the cache implementation is incomplete:
https://github.com/EricLBuehler/mistral.rs/blob/master/mistralrs-core/src/pipeline/speculative.rs#L247
Is resolving these cache issues the main blocker for getting speculative decoding working? And what about with PA?

If the PR is ok, I'll probably be working on the VRAM calculations for mistralrs-server next, because currently the flags like --pa-gpu-mem assume single device and don't account for multi-device setups.

@EricLBuehler (Owner) left a comment

@cdoko thanks for the PR!

> Is resolving these cache issues the main blocker for getting speculative decoding working?

Yes, it's just that using the Normal cache isn't supported yet.

> And what about with PA?

I think the main problem is that we need some extensive management of the KV cache (in particular, rolling back the cache) for PA, which I haven't implemented yet.

> If the PR is ok, I'll probably be working on the VRAM calculations for mistralrs-server next, because currently the flags like --pa-gpu-mem assume single device and don't account for multi-device setups.

Sounds great!

@EricLBuehler merged commit c345954 into EricLBuehler:master on Jan 1, 2025
12 checks passed