理解 MHA、GQA、MQA 和 MLA:多头注意力的变种及其应用

news/2025/2/26 5:18:22

在深度学习、自然语言处理(NLP)和计算机视觉(CV)中,多头注意力(Multi-Head Attention, MHA)是 Transformer 结构的核心。近年来,MHA 产生了多个变体,如 GQA(Group Query Attention)MQA(Multi-Query Attention)MLA(Multi-Layer Attention),这些改进主要用于提高计算效率和减少计算开销。

本文将深入探讨这些注意力机制的工作原理、数学公式、优缺点及应用场景,帮助理解Transformer 及其改进版本。

1. MHA(Multi-Head Attention,多头注意力)

1.1 MHA 的基本原理

多头注意力(MHA)是 Transformer 结构的核心组件之一,它的作用是:

  • 让模型在不同的子空间(subspace)上学习不同的特征。
  • 提高模型的表达能力,使其能够关注输入序列的不同部分。
  • 并行计算,提高计算效率。

MHA 的核心思想是将输入的 Query(查询)、Key(键)和 Value(值)分别投影到多个不同的头(head)上,每个头独立计算注意力,然后将多个头的结果拼接后投影回原始维度。

1.2 计算过程

给定输入矩阵 X(形状为 s×d),MHA 计算如下:

  1. 线性变换:将输入 X 变换成 Query(Q)、Key(K)、Value(V)。

    其中 WQi,WKi,WVi​ 是不同头的权重矩阵。

  2. 计算 Scaled Dot-Product Attention(缩放点积注意力)

    其中 dk=d/h,是每个头的维度。

  3. 拼接多个头的输出

    其中 WO是最终的投影矩阵。

1.3 MHA 的优势和劣势

优势

  • 允许模型在不同的子空间上学习不同的注意力模式。
  • 提高模型的表达能力,可以关注输入序列的不同部分。
  • 并行计算,可以在 GPU 上高效执行。

劣势

  • 计算量较大:每个 Query 头都需要与多个 Key 计算注意力,导致计算开销较高。
  • 内存占用大:MHA 需要存储多个 Query、Key、Value 头,特别是在大模型中,占用大量显存。

2. GQA(Group Query Attention,分组查询注意力)

2.1 GQA 的核心思想

GQA(分组查询注意力)是为了降低计算成本而提出的一种优化方案。它的主要改动在于:

  • 多个 Query 共享一个 Key-Value 组,减少计算复杂度。
  • 在视觉 Transformer(ViT)等任务中表现良好,适用于大规模数据处理。

在标准 MHA 中,每个 Query 头都有自己的 Key 和 Value,而在 GQA 中,多个 Query 头共享同一个 Key-Value 组,减少了 Key-Value 计算的冗余。

2.2 GQA 计算过程

  1. 将 Query 分组,设总共有 hhh 个 Query 头,我们将它们分为 ggg 组,每组的 Query 共享同一个 Key-Value 组: G=h/g
  2. 每个组的 Query 共享 Key-Value
  3. 拼接多个组的结果

2.3 GQA 的优势

计算量降低:比 MHA 少计算 Key-Value 的开销,提高计算效率。
适用于 CV 任务:减少视觉 Transformer 在图像数据上的计算复杂度。

可能降低表达能力:由于 Query 共享 Key-Value,可能会损失一定的灵活性。

3. MQA(Multi-Query Attention,多查询注意力)

3.1 MQA 的核心思想

MQA(多查询注意力)是 GQA 的一种极端情况:

  • 所有 Query 共享一个 Key-Value,极大减少计算量。
  • 适用于大规模推理任务,如 ChatGPT 的解码阶段。

3.2 MQA 计算过程

  1. 所有 Query 头共享 Key-Value: Kshared,Vshared
  2. 计算注意力
  3. 拼接结果

3.3 MQA 的优势

极大减少计算成本:适用于推理阶段,减少 Key-Value 计算量。
内存占用降低:适合处理超长文本,如 GPT-4 等大模型。

可能损失部分表达能力:仅有一个 Key-Value 可能影响多样性。

4. MLA(Multi-Layer Attention,多层注意力)

4.1 MLA 的核心思想

MLA(多层注意力)关注的是在不同层之间如何融合注意力信息,而不是在单个注意力层内进行优化。
主要有两种实现方式:

  1. 层级 MHA(Hierarchical MHA):每一层的注意力结果影响下一层。
  2. 跨层注意力(Cross-Layer Attention):不同层的信息进行融合。

4.2 MLA 计算方式

  1. 引入跨层信息
  2. 增强跨层表示

4.3 MLA 的优势

跨层信息融合,减少每一层的冗余计算。
提高信息利用率,适合深层 Transformer。

5. 总结

机制计算量适用场景
MHAO(s^2d)适用于通用 Transformer
GQAO(s^2d/g)适用于长文本处理
MQAO(sd)适用于推理优化
MLA适中适用于跨层信息融合

不同 Transformer 版本中的计算顺序:

Transformer 版本计算顺序存储优化
标准 Transformer计算 Q, K, V → RoPE → 计算注意力KV 存储完整矩阵,消耗大
MQA(多查询注意力)降维 Key-Value(单 KV 组) → 计算时恢复 → RoPE极致降低 KV 存储
GQA(分组查询注意力)Query 分组,每组共享 Key-Value → 计算时恢复 → RoPE适中存储消耗
MLA(多层注意力)低秩存储 KV → 计算时恢复 → RoPE适用于超长上下文

GQA在 MQA 的基础上,将多个 Query 分成若干组,每组共享 Key-Value,从而在计算量和表达能力之间找到更好的平衡点。 

  • 由于所有 Query 共享同一 Key-Value,模型的表达能力下降,不能很好地区分不同 Query 头的语义信息。
  • 解决方案: GQA 让 Query 进行分组,不同组共享不同的 Key-Value,从而在计算效率和表达能力之间找到平衡。

 6. 论文

MHA(多头注意力)最早由 Vaswani 等人在 2017 年的论文 “Attention Is All You Need” 中提出:

Reference: Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). "Attention is all you need." Advances in neural information processing systems (NeurIPS).

论文链接:Attention Is All You Need

GQA 由 Ainslie 等人在 2023 年的论文 GQA:Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints 提出:

Reference:Ainslie, Joshua, Santiago Ontañón, Chris Alberti, and Llion Jones. "Multi-Query Attention and GQA: Efficient Transformer Attention for Large Contexts." arXiv preprint arXiv:2305.13245 (2023).

论文链接:GQA:Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

MQA 由 Shazeer 在 2019 年的论文 “Fast Transformer Decoding” 提出:

Reference: Shazeer, N. (2019). "Fast Transformer Decoding." arXiv preprint arXiv:1911.02150.

论文链接:Fast Transformer Decoding

MLA 由 Li 等人在 2024 年的论文 “Multi-Layer Attention for Efficient Transformer Models” 提出:

Reference: Li, X., Zhou, X., Zhang, T., Wu, Y., Zhang, Y., & Fu, J. (2021). "Multi-Layer Attention for Efficient Transformer Models." arXiv preprint arXiv:2107.02192.

论文链接:Multi-Layer Attention for Efficient Transforer Models


http://www.niftyadmin.cn/n/5868013.html

相关文章

Maven导入hutool依赖报错-java: 无法访问cn.hutool.core.io.IORuntimeException 解决办法

欢迎大家来到我的博客~欢迎大家对我的博客提出指导&#xff0c;有错误的地方会改进的哦~点击这里了解更多内容 目录 一、报错二、解决办法 一、报错 <dependency><groupId>cn.hutool</groupId><artifactId>hutool-captcha</artifactId> </de…

MATLAB应用介绍

MATLAB 数据分析 MATLAB 在数据分析方面的强大功能和优势&#xff0c;涵盖数据处理、分析、可视化、结果分享等多个环节&#xff0c;为工程师和科学家提供了全面的数据分析解决方案。 MATLAB 数据分析功能概述&#xff1a;工程师和科学家利用 MATLAB 整理、清理和分析来自气候学…

常用搜索引擎命令大全

常用搜索引擎命令大全 1.1、双引号 关键词在双引号中&#xff0c;代表完全匹配&#xff0c;搜索结果返回的页面包含双引号中出现的所有词&#xff0c;顺序也匹配。baidu、google 支持 例&#xff1a;“百度” 1.2、减号 代表不包含减号后面的词的页面&#xff0c;减少前面…

Solidity study

Solidity 开发环境 Solidity编辑器&#xff1a;Solidity编辑器是一种专门用于编写和编辑Solidity代码的编辑器。常用的Solidity编辑器包括Visual Studio Code、Atom和Sublime Text。以太坊开发环境&#xff1a;以太坊开发环境&#xff08;Ethereum Development Environment&am…

ArcGis for js 4.x实现测量,测距,高程的功能

文章目录 前言一、三维测量&#xff0c;测距&#xff0c;高程是什么&#xff1f;二、使用步骤1.引入库2.初始化Draw3.初始化图层4.测量距离功能5.测量面积5.测量高程 清理地图图层 前言 ArcGIS for JS广泛应用于需要在Web上展示和分析空间数据的各种场景中&#xff0c;包括教育…

C++ QT 6.6.1 QCustomPlot的导入及使用注意事项和示例 | 关于高版本QT使用QCustomPlot报错问题解决的办法

C QT 6.6.1 QCustomPlot的导入及使用注意事项和示例 | 关于高版本QT使用QCustomPlot报错问题解决的办法 记录一下 qmake .pro文件的配置 QT core gui printsupportgreaterThan(QT_MAJOR_VERSION, 4): QT widgetsCONFIG c17# You can make your code fail to compil…

反向代理模块kfj

1 概念 1.1 反向代理概念 反向代理是指以代理服务器来接收客户端的请求&#xff0c;然后将请求转发给内部网络上的服务器&#xff0c;将从服务器上得到的结果返回给客户端&#xff0c;此时代理服务器对外表现为一个反向代理服务器。 对于客户端来说&#xff0c;反向代理就相当于…

ros面试准备

ROS中的通信方式有哪些&#xff1f; topic service action topic:发布-订阅模型&#xff0c;适合持续的数据流&#xff0c;如传感器 service:请求-响应模型&#xff0c;适合即时操作&#xff0c;如开关控制 如何调试一个无法通信的话题&#xff1f; 第一、rostopic list检查话…