LLaVA-NPU: Running LLaVA on an NPU device, e.g., Ascend 910B

In this project, we port LLaVA from CUDA devices to NPU devices. If you find this project helpful, please give it a star. Thanks for your support!

Change history

[2024-09-02]: We support common visual projectors, e.g., LDPv2 and the Resampler. Try them!

[2024-08-27]: We add the SigLip encoder to LLaVA-NPU.

[2024-08-25]: We reproduce the results on MMBench. Weights are coming soon.

[2024-08-15]: We create the project and update the source code.

Installation

<1> Install LLaVA: pip install -e ".[train]"
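
Before training, it can help to confirm that PyTorch sees the Ascend device. Below is a minimal sketch, assuming torch_npu is installed; the printed values are illustrative.

import torch
import torch_npu  # registers the "npu" device type with PyTorch

# Verify that at least one NPU is visible and that a simple op runs on it.
assert torch.npu.is_available(), "No NPU device detected"
print("NPU count:", torch.npu.device_count())

x = torch.ones(2, 2).npu()  # move a tensor to the default NPU
print((x + x).cpu())        # run a trivial op to confirm the device works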

Train and Test

<1> You can train and test LLaVA just as in the official repo! If you are in China, you can download the models from ModelScope (see the sketch below).
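
A minimal sketch of downloading weights with the modelscope Python package. The model id below is a placeholder, not the project's actual repository name; substitute the real ModelScope id of the checkpoint you need.

from modelscope import snapshot_download

# Placeholder model id -- replace with the actual ModelScope repository
# hosting the LLaVA / Vicuna weights you want to use.
local_dir = snapshot_download("your-namespace/llava-v1.5-7b", cache_dir="./checkpoints")
print("Weights downloaded to:", local_dir)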

<2> Important!!!: LLaVA-NPU does not support LoRA tuning or ZeRO-3 offload. Please use full fine-tuning with a ZeRO-2 configuration (a config sketch is shown below). We train LLaVA on 8 Ascend 910B NPUs with 65 GB of memory each.
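
Since ZeRO-3 offload is not supported, a ZeRO stage-2 DeepSpeed configuration (as used by the official LLaVA training scripts) is the natural choice. Below is a hedged sketch written as a Python dict; the values and the output path are assumptions, not the project's exact settings.

import json

# Illustrative ZeRO-2 configuration; "auto" lets the HF Trainer fill in
# the values from its own TrainingArguments.
zero2_config = {
    "bf16": {"enabled": "auto"},
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
}

with open("scripts/zero2.json", "w") as f:  # path is illustrative
    json.dump(zero2_config, f, indent=2)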

<3> Training details. The hyper-parameters used in pretraining and visual instruction tuning are as follows (a sketch of how the global batch size maps to per-device settings follows the tables).

  1. Pretraining

| Model | Global batch size | Learning rate | Epochs | Max length | Weight decay |
| --- | --- | --- | --- | --- | --- |
| LLaVA-v1.5-7B | 256 | 1e-3 | 1 | 2048 | 0 |

  2. Visual instruction tuning

| Model | Global batch size | Learning rate | Epochs | Max length | Weight decay |
| --- | --- | --- | --- | --- | --- |
| LLaVA-v1.5-7B | 64 | 1e-5 | 1 | 2048 | 0 |
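
As an illustration of how the global batch sizes in the tables are reached on 8 NPUs, here is a hedged sketch of the corresponding transformers TrainingArguments. The per-device batch size and gradient accumulation steps are assumptions; only their product with the number of devices must match the table (8 x 32 x 1 = 256 for pretraining, 8 x 8 x 1 = 64 for instruction tuning).

from transformers import TrainingArguments

# Pretraining stage: global batch size 256 = 8 NPUs * 32 per device * 1 accumulation step.
pretrain_args = TrainingArguments(
    output_dir="./checkpoints/llava-v1.5-7b-pretrain",  # illustrative path
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    learning_rate=1e-3,
    num_train_epochs=1,
    weight_decay=0.0,
    bf16=True,
)

# Visual instruction tuning stage: global batch size 64 = 8 NPUs * 8 per device * 1 step.
finetune_args = TrainingArguments(
    output_dir="./checkpoints/llava-v1.5-7b-finetune",  # illustrative path
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    learning_rate=1e-5,
    num_train_epochs=1,
    weight_decay=0.0,
    bf16=True,
)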

<4> Model performance comparison.

| Model | Image encoder | Language model | Projector | MMBench | MMVet |
| --- | --- | --- | --- | --- | --- |
| LLaVA-v1.5-7B (official) | CLIP | Vicuna-7B | MLP | 66.7 | 31.1 |
| LLaVA-v1.5-7B (ours) | CLIP | Vicuna-7B | MLP | 67.7 | 32.2 |
| LLaVA-v1.5-7B (ours) | SigLip | Vicuna-7B | MLP | 66.4 | - |
| LLaVA-v1.5-7B (ours) | CLIP | Vicuna-7B | Adaptive Pool | 64.6 | - |
| LLaVA-v1.5-7B (ours) | CLIP | Vicuna-7B | Resampler | 63.1 | 27.1 |
| LLaVA-v1.5-7B (ours) | CLIP | Vicuna-7B | LDPv2 | 65.7 | 28.9 |
| LLaVA-v1.5-7B (ours) | CLIP | Vicuna-7B | TokenPacker | 63.1 | - |
| LLaVA-v1.5-7B (ours) | CLIP | Vicuna-7B | C-Abstract | 65.1 | 31.8 |

Core code

<1> LLaVA-NPU replaces the flash_attn implementation. The code can be found here.

<2> We modify the evaluation code in LLaVA. The code is here.

<3> We support an nn.MultiheadAttention-style fused attention module on the NPU device, shown below.

import math

import torch
import torch.nn as nn
import torch_npu


class MultiheadfusionAttention(nn.Module):
    """
    Multi-head attention built on torch_npu.npu_fusion_attention
    (the fused attention kernel) for Ascend NPU devices.
    """
    def __init__(self, d_model, h):
        super().__init__()
        assert d_model % h == 0

        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h

        # Separate projections for query, key and value, plus the output projection.
        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
        self.output_linear = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, attn_mask=None, keep_prob=1.0):
        # 1) Do all the linear projections in batch from d_model => h x d_k.
        query, key, value = [l(x) for l, x in zip(self.linear_layers, (query, key, value))]

        scale = 1 / math.sqrt(self.d_k)

        # 2) Fused attention. input_layout="SBH" means the tensors are laid out as
        #    (seq_len, batch, hidden). npu_fusion_attention returns a tuple whose
        #    first element is the attention output; keep_prob=1.0 disables dropout.
        attn_output = torch_npu.npu_fusion_attention(query, key, value,
                                                     self.h,
                                                     pse=None,
                                                     padding_mask=None,
                                                     atten_mask=attn_mask,
                                                     scale=scale,
                                                     keep_prob=keep_prob,
                                                     input_layout="SBH",
                                                     pre_tockens=65536,
                                                     next_tockens=0,
                                                     inner_precise=0)[0]

        # 3) Final output projection.
        return self.output_linear(attn_output)


# Usage: tensors are (seq_len, batch, hidden) to match input_layout="SBH",
# and both the module and the inputs must live on the NPU.
inputs = torch.rand(2304, 1, 1024).npu()
attention = MultiheadfusionAttention(1024, 8).npu()
outputs = attention(inputs, inputs, inputs)

Acknowledgement

<1> We would like to express our sincere thanks to this repo for its NPU implementation. Our project is based on it!

<2> Many thanks to LLaVA for its great contribution to the MLLM community!
