
Towards Economical Inference: Enabling DeepSeek's Multi-Head Latent Attention in Any Transformer-based LLMs


JT-Ushio/MHA2MLA


MHA2MLA

This repo contains the code for the paper "Towards Economical Inference: Enabling DeepSeek's Multi-Head Latent Attention in Any Transformer-based LLMs".


News

  • [2025.06.13] Released the refactored code and added support for Qwen models.
  • [2025.03.12] Released the inference code, implemented in PyTorch (FlashMLA inference support needs additional development time).
  • [2025.03.04] Four MLA checkpoints ($d_{kv}$=8/16/32/128) derived from SmolLM-135M/360M/1B7 are now publicly available.
  • [2025.03.03] Four MLA checkpoints ($d_{kv}$=16/32/64/256) derived from Llama-2-7B are now publicly available.
  • [2025.02.21] The MHA2MLA paper is publicly available: https://arxiv.org/abs/2502.14837
  • [2025.02.19] Released the first version of the MHA2MLA code, with example scripts for fine-tuning and evaluating Llama.

TO-DO

  • Provide the code for incorporating the projection matrix and inference.
  • Thanks to DeepSeek for open-sourcing the FlashMLA inference framework. In theory, this framework can save even more GPU memory. Let’s see how economical MHA2MLA + FlashMLA (+ KV quanto) can be!
  • Release the code of MHA2MLA based on HuggingFace Transformers.

Datasets

First, download the datasets.

Then, process them following https://github.com/huggingface/nanotron/blob/main/docs/nanoset.md.

Environment

Install PyTorch and the remaining dependencies inside a fresh conda environment:

conda create -n mha2mla python=3.11
conda activate mha2mla
pip install torch==2.4.0 torchvision==0.19.0
pip install -r requirements.txt

Fine-Tuning

First, prepare a configuration file, using 135M_4GPU.yaml as a reference.

The MHA2MLA-specific options are documented in the arguments.py file.

Then, launch fine-tuning with the following command:

torchrun --nproc_per_node 4 \
    ./src/mha2mla/run_train.py \
    --cfg_file ./cfgs/SmolLM1-135M-4GPU.yml

To use the 2-norm variant of partial RoPE, you first need to compute the qk_tensor with the following command:

torchrun --nproc_per_node 1 \
    ./src/mha2mla/2_norm.py \
    --config_file ./cfgs/SmolLM1-135M-8GPU.yaml

Lighteval Evaluation

For evaluation, you can use the following command:

accelerate launch --multi_gpu --num_processes=4 \
    ./eval/eval.py \
    accelerate \
    --model_args "pretrained=${model_name_or_path},revision=main,dtype=bfloat16,max_length=2048" \
    --override_batch_size 48 \
    --custom_tasks "./eval/tasks.py" \
    --tasks "./eval/smollm1_base.txt" \
    --output_dir "./eval_results/"
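The Lighteval command above and the LongBench commands below all expand ${model_name_or_path}, so the shell has to define it first. A minimal sketch; the checkpoint path is a hypothetical placeholder, and whether both a local directory and a Hugging Face Hub model id are accepted by pretrained= is an assumption here:

```shell
# Hypothetical placeholder: replace with your own MHA2MLA checkpoint
# (for example a local directory or a Hugging Face Hub model id).
model_name_or_path="./checkpoints/mha2mla-smollm1-135m"

# Abort with an error message if the variable is unset or empty, instead of
# launching a multi-GPU evaluation with a blank model path.
: "${model_name_or_path:?set model_name_or_path before running the evaluation}"
echo "evaluating ${model_name_or_path}"
```

The `${var:?message}` expansion makes the shell stop with that message when the variable is unset or empty, which is useful if the guard line is copied into a script that expects the path from the environment.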

LongBench Evaluation

For the baseline evaluation, you can use the following command:

torchrun --nproc_per_node=4 \
    ./eval/longbench.py \
    --model_path ${model_name_or_path} \
    --tokenizer_path ${model_name_or_path} \
    --longbench True \
    --lb_max_tokens 2048 \
    --lb_batch_size 16 \
    --output_dir ./longbench/bf16 \
    --dtype "bfloat16"

To evaluate with a quantized KV cache (4-bit, HQQ backend), use the following command:

torchrun --nproc_per_node=4 \
    ./eval/longbench.py \
    --model_path ${model_name_or_path} \
    --tokenizer_path ${model_name_or_path} \
    --longbench True \
    --lb_max_tokens 2048 \
    --lb_batch_size 16 \
    --output_dir ./longbench/${model_name_or_path}_hqq_int4 \
    --dtype "bfloat16" \
    --cache_implementation "quantized" \
    --backend "HQQ" \
    --nbits 4 \
    --residual_length 128
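For a rough sense of what --nbits and --residual_length trade off: assuming these flags are forwarded to the Hugging Face Transformers quantized KV cache, up to residual_length of the most recent tokens stay in the original precision and older tokens are stored at nbits. The sketch below estimates the fraction of the bf16 cache size that remains for an illustrative 2048-token context. It ignores the per-group scale and zero-point metadata that HQQ also stores, so actual usage is somewhat higher.

```shell
# Rough estimate only; assumes the flags map onto Transformers' quantized cache
# and ignores HQQ's per-group scale/zero-point metadata.
seq_len=2048          # illustrative context length (tokens)
residual_length=128   # --residual_length: newest tokens kept in bf16 (worst case)
full_bits=16          # bf16 storage per cached element
quant_bits=4          # --nbits 4

# Both caches hold the same number of elements per token, so the ratio does not
# depend on the number of layers, heads, or head dimension.
bf16_cost=$(( seq_len * full_bits ))
quant_cost=$(( residual_length * full_bits + (seq_len - residual_length) * quant_bits ))
ratio_pct=$(( quant_cost * 100 / bf16_cost ))
echo "int4 KV cache: about ${ratio_pct}% of the bf16 cache size"
```

Shell arithmetic truncates, so the script reports 29% for an exact value of about 29.7%. Raising residual_length keeps more recent tokens at full precision at the cost of a larger cache.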

Citation

@misc{ji2025economicalinferenceenablingdeepseeks,
      title={Towards Economical Inference: Enabling DeepSeek's Multi-Head Latent Attention in Any Transformer-based LLMs}, 
      author={Tao Ji and Bin Guo and Yuanbin Wu and Qipeng Guo and Lixing Shen and Zhan Chen and Xipeng Qiu and Qi Zhang and Tao Gui},
      year={2025},
      eprint={2502.14837},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2502.14837}, 
}
