Differential Transformer: 通过差分注意力机制提升大语言模型性能
itomcoil 2024-12-22 18:53 17 浏览
Transformer模型已经成为大语言模型(LLMs)的标准架构,但研究表明这些模型在准确检索关键信息方面仍面临挑战。今天介绍一篇名叫Differential Transformer的论文,论文的作者观察到一个关键问题:传统Transformer模型倾向于过分关注不相关的上下文信息,这种"注意力噪声"会影响模型的性能。
在这篇论文中,作者注意到transformer模型倾向于关注不相关的上下文。为了放大相关上下文的注意力分数,他们提出了一个新的注意力模型,称为差分注意力模型。在这个模型中,他们将查询和键值向量分成两组,并计算两个子注意力分数。
差分注意力机制
差分注意力机制(Differential Attention)的核心思想是通过计算两个独立的注意力图谱之差来消除注意力噪声。这种设计借鉴了电气工程中差分放大器的原理,通过对比两个信号的差异来消除共模噪声。
让我们看看论文中的第一个方程:
方程(1)
方程(1)显示,我们首先像标准注意力计算一样计算Q、K和V张量。关键点是我们将Q和K张量分成Q1、Q2和K1、K2子张量。
论文中输入X、Q1、Q2、K1、K2和V张量的形状
根据论文,Q和K张量的形状应该是Nx2d,因为Q1、Q2、K1和K2将是Nxd。输入X的形状是Nxd_model,这是论文中的嵌入维度。这就是为什么W_Q、W_K和W_V的可学习参数的形状必须是d_modelx2d。
论文中用于lambda计算的方程(2)
方程(2)展示了如何计算可学习参数lambda。在这个方程中有一个初始lambda参数。lambda是一个标量参数,但lambda_q1、lambda_k1、lambda_q2和lambda_k2是向量。这一点很关键。向量lambda_q和lambda_k的运算是点积。
用于lambda初始化的方程(3)
实验结果与性能提升
论文的实验表明,相比传统Transformer:
DIFF Transformer只需要约65%的模型参数量即可达到相同的性能,在训练token数量方面也只需要约65%就能达到相同效果
在Needle-In-A-Haystack测试中:4K上下文长度:DIFF Transformer在多目标检索任务中保持85%准确率;64K上下文长度:在深度为25%的位置检测时,比传统Transformer提升了76%的准确率
Python实现
下面我们根据论文的公式来做一个简单的实现,首先方程(3)展示了我们如何计算lambda_initial变量。现在让我们把方程转换成Python代码:
def lambda_init_fn(depth):
return 0.8 - 0.6 * math.exp(-0.3 * depth)
然后再写一个简单的Python函数,使用方程(3)。
class DifferentialAttention(nn.Module):
def __init__(self, dim_model: int, head_nums: int, depth: int):
super().__init__()
self.head_dim = dim_model // head_nums
self.Q = nn.Linear(dim_model, 2 * self.head_dim, bias=False)
self.K = nn.Linear(dim_model, 2 * self.head_dim, bias=False)
self.V = nn.Linear(dim_model, 2 * self.head_dim, bias=False)
self.scale = self.head_dim ** -0.5
self.depth = depth
self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
self.rotary_emb = RotaryEmbedding(self.head_dim * 2)
在DifferentialAttention类中,我们实现了一个多头差分注意力机制。有dim_model(嵌入维度)、head_nums和depth参数。为Q1、Q2、K1和K2声明了四个lambda可学习参数,并使用均值为0、标准差为0.1的随机正态分布初始化它们。
def forward(self, x):
lambda_init = lambda_init_fn(self.depth)
Q = self.Q(x)
K = self.K(x)
seq_len = x.shape[1]
cos, sin = self.rotary_emb(seq_len, device=x.device)
Q, K = apply_rotary_pos_emb(Q, K, cos, sin)
Q1, Q2 = Q.chunk(2, dim=-1)
K1, K2 = K.chunk(2, dim=-1)
V = self.V(x)
A1 = Q1 @ K1.transpose(-2, -1) * self.scale
A2 = Q2 @ K2.transpose(-2, -1) * self.scale
lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(Q1)
lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(Q2)
lambda_ = lambda_1 - lambda_2 + lambda_init
return (F.softmax(A1, dim=-1) - lambda_ * F.softmax(A2, dim=-1)) @ V
forward方法很直观。我分别实现了方程(1)和方程(2)。forward方法直接实现了论文中的伪代码。
多头差分注意力架构和伪代码
class MultiHeadDifferentialAttention(nn.Module):
def __init__(self, dim_model: int, head_nums: int, depth: int):
super().__init__()
self.heads = nn.ModuleList([DifferentialAttention(dim_model, head_nums, depth) for _ in range(head_nums)])
self.group_norm = RMSNorm(dim_model)
self.output = nn.Linear(2 * dim_model, dim_model, bias=False)
self.lambda_init = lambda_init_fn(depth)
def forward(self, x):
o = torch.cat([self.group_norm(h(x)) for h in self.heads], dim=-1)
o = o * (1 - self.lambda_init)
return self.output(o)
MultiHeadDifferentialAttention类是根据论文中的伪代码编写的。这里使用了RMSNorm而不是GroupNorm。
论文中使用多头差分注意力机制的语言模型的方程
最后使用实现的MultiHeadDifferentialAttention机制构建一个transformer解码器。
class DifferentialTransformer(nn.Module):
def __init__(self, dim: int, depth: int, heads: int = 8, head_dim: int = 64, vocab_size: int = 10000):
super().__init__()
self.vocab_size = vocab_size
self.layers = nn.ModuleList([
MultiHeadDifferentialAttention(dim, heads, depth_idx)
for depth_idx in range(depth)
])
self.ln1 = RMSNorm(dim)
self.ln2 = RMSNorm(dim)
self.ffn = FeedForward(dim, (dim // 3) * 8)
self.output = nn.Linear(dim, self.vocab_size)
def forward(self, x):
for attn in self.layers:
y = attn(self.ln1(x)) + x
x = self.ffn(self.ln2(y)) + y
return self.output(x)
性能优化
论文还提供了两种FlashAttention实现方式:
1、支持不同维度的实现:
def FlashDiffAttn_1(X, W_q, W_k, W_v, λ):
Q1, Q2 = split(X @ W_q)
K1, K2 = split(X @ W_k)
V = X @ W_v
A1 = flash_attn(Q1, K1, V)
A2 = flash_attn(Q2, K2, V)
return A1 - λ A2
固定维度的实现:
def FlashDiffAttn_2(X, W_q, W_k, W_v, λ):
Q1, Q2 = split(X @ W_q)
K1, K2 = split(X @ W_k)
V1, V2 = split(X @ W_v)
A11 = flash_attn(Q1, K1, V1)
A12 = flash_attn(Q1, K1, V2)
A1 = Concat(A11, A12)
A21 = flash_attn(Q2, K2, V1)
A22 = flash_attn(Q2, K2, V2)
A2 = Concat(A21, A22)
return A1 - λ A2
Differential Transformer论文提出的两种FlashAttention实现方案各有特色。第一种实现(FlashDiffAttn_1)采用直接计算策略,允许Q、K、V具有不同的维度,这种灵活性使其更适合需要动态调整维度的场景,但可能在某些硬件上的优化效果不如第二种方案。第二种实现(FlashDiffAttn_2)通过将计算分解为多个相同维度的子运算,虽然计算步骤增多,但每个步骤都能充分利用硬件优化,特别是在支持张量核心的现代GPU上表现更好。
这两种实现的选择主要取决于具体应用场景:如果模型架构需要频繁调整维度或者需要更灵活的注意力机制,建议使用第一种实现;如果追求极致的计算效率且维度相对固定,第二种实现可能是更好的选择。从工程实践角度看,第二种实现与现有的FlashAttention优化库的兼容性更好,更容易在现有基础设施上部署和优化。
局限性和未来研究方向
Differential Transformer虽然在多个方面展现出优秀的性能,但仍然存在一些值得关注的局限性。首要的挑战来自其计算效率方面。由于模型需要同时计算两个独立的注意力图谱,这不可避免地增加了计算开销。在实际测试中,相比传统Transformer,DIFF Transformer在3B规模模型上的计算吞吐量降低了约9%,这种性能损失虽然可以通过更少的参数量来部分抵消,但在大规模部署场景中仍然需要认真考虑。
内存使用是另一个重要的局限性。模型需要存储两组独立的查询和键值向量,这导致了更高的内存占用。尽管这种设计对于提升模型性能是必要的,但在资源受限的环境下可能会造成部署困难。特别是在处理超长序列时,内存压力会进一步加大。
训练稳定性也是一个需要特别关注的问题。模型中λ参数的初始化策略对训练过程的稳定性有显著影响。研究发现,不同的λinit取值会导致训练收敛速度和最终性能的差异。虽然论文提出了一个基于层深度的初始化策略,但这种方案并非在所有场景下都能取得最优效果,有时需要根据具体任务进行调整。
基于这些局限性,论文提出未来的研究可以沿着几个重要方向展开。首先在计算效率优化方面,可以探索更高效的注意力核心实现。这包括研究如何更好地利用现代硬件特性,例如开发专门的CUDA核心来加速差分注意力的计算。同时考虑到模型产生的稀疏注意力模式,可以设计特定的稀疏计算优化策略,这不仅能提升计算效率,还能减少内存占用。
λ参数的动态调整机制是另一个值得深入研究的方向。当前的参数计算方案虽然有效,但仍有优化空间。可以考虑设计更灵活的自适应机制,使λ参数能够根据输入内容和任务特点动态调整,从而在不同场景下都能获得最佳性能。这可能需要引入额外的上下文感知机制,或者设计新的参数更新策略。
在内存优化方面,量化技术提供了一个有前景的研究方向。考虑到DIFF Transformer在处理激活值异常方面的优势,可以探索专门的量化策略。比如,研究如何在保持模型性能的同时,对注意力权重和中间状态进行更激进的量化,从而减少内存占用。这对于模型在边缘设备上的部署具有重要意义。
长文本建模能力的进一步提升也是一个重要研究方向。虽然当前模型在64K长度的实验中表现出色,但随着应用需求的增长,可能需要处理更长的序列。这要求研究如何在更长序列上保持模型的效率和性能,可能需要开发新的注意力机制变体或优化策略。
总结
DIFF Transformer通过创新的差分注意力机制成功提升了模型性能,特别是在长文本理解、关键信息检索和模型鲁棒性等方面。虽然存在一些计算效率和内存使用的权衡,但考虑到显著的性能提升和更少的参数需求,这是一个非常有价值的改进。这项工作为大语言模型的架构设计提供了新的思路,也为后续研究指明了几个重要的优化方向。
相关推荐
- Excel新函数TEXTSPLIT太强大了,轻松搞定数据拆分!
-
我是【桃大喵学习记】,欢迎大家关注哟~,每天为你分享职场办公软件使用技巧干货!最近我把WPS软件升级到了版本号:12.1.0.15990的最新版本,最版本已经支持文本拆分函数TEXTSPLIT了,并...
- Excel超强数据拆分函数TEXTSPLIT,从入门到精通!
-
我是【桃大喵学习记】,欢迎大家关注哟~,每天为你分享职场办公软件使用技巧干货!今天跟大家分享的是Excel超强数据拆分函数TEXTSPLIT,带你从入门到精通!TEXTSPLIT函数真是太强大了,轻松...
- 看完就会用的C++17特性总结(c++11常用新特性)
-
作者:taoklin,腾讯WXG后台开发一、简单特性1.namespace嵌套C++17使我们可以更加简洁使用命名空间:2.std::variant升级版的C语言Union在C++17之前,通...
- plsql字符串分割浅谈(plsql字符集设置)
-
工作之中遇到的小问题,在此抛出问题,并给出解决方法。一方面是为了给自己留下深刻印象,另一方面给遇到相似问题的同学一个解决思路。如若其中有写的不好或者不对的地方也请不加不吝赐教,集思广益,共同进步。遇到...
- javascript如何分割字符串(javascript切割字符串)
-
javascript如何分割字符串在JavaScript中,您可以使用字符串的`split()`方法来将一个字符串分割成一个数组。`split()`方法接收一个参数,这个参数指定了分割字符串的方式。如...
- TextSplit函数的使用方法(入门+进阶+高级共八种用法10个公式)
-
在Excel和WPS新增的几十个函数中,如果按实用性+功能性排名,textsplit排第二,无函数敢排第一。因为它不仅使用简单,而且解决了以前用超复杂公式才能搞定的难题。今天小编用10个公式,让你彻底...
- Python字符串split()方法使用技巧
-
在Python中,字符串操作可谓是基础且关键的技能,而今天咱们要重点攻克的“堡垒”——split()方法,它能将看似浑然一体的字符串,按照我们的需求进行拆分,极大地便利了数据处理与文本解析工作。基本语...
- go语言中字符串常用的系统函数(golang 字符串)
-
最近由于工作比较忙,视频有段时间没有更新了,在这里跟大家说声抱歉了,我尽快抽些时间整理下视频今天就发一篇关于go语言的基础知识吧!我这我工作中用到的一些常用函数,汇总出来分享给大家,希望对...
- 无规律文本拆分,这些函数你得会(没有分隔符没规律数据拆分)
-
今天文章来源于表格学员训练营群内答疑,混合文本拆分。其实拆分不难,只要规则明确就好办。就怕规则不清晰,或者规则太多。那真是,Oh,mygod.如上图所示进行拆分,文字表达实在是有点难,所以小熊变身灵...
- Python之文本解析:字符串格式化的逆操作?
-
引言前面的文章中,提到了关于Python中字符串中的相关操作,更多地涉及到了字符串的格式化,有些地方也称为字符串插值操作,本质上,就是把多个字符串拼接在一起,以固定的格式呈现。关于字符串的操作,其实还...
- 忘记【分列】吧,TEXTSPLIT拆分文本好用100倍
-
函数TEXTSPLIT的作用是:按分隔符将字符串拆分为行或列。仅ExcelM365版本可用。基本应用将A2单元格内容按逗号拆分。=TEXTSPLIT(A2,",")第二参数设置为逗号...
- Excel365版本新函数TEXTSPLIT,专攻文本拆分
-
Excel中字符串的处理,拆分和合并是比较常见的需求。合并,当前最好用的函数非TEXTJOIN不可。拆分,Office365于2022年3月更新了一个专业函数:TEXTSPLIT语法参数:【...
- 站长在线Python精讲使用正则表达式的split()方法分割字符串详解
-
欢迎你来到站长在线的站长学堂学习Python知识,本文学习的是《在Python中使用正则表达式的split()方法分割字符串详解》。使用正则表达式分割字符串在Python中使用正则表达式的split(...
- Java中字符串分割的方法(java字符串切割方法)
-
技术背景在Java编程中,经常需要对字符串进行分割操作,例如将一个包含多个信息的字符串按照特定的分隔符拆分成多个子字符串。常见的应用场景包括解析CSV文件、处理网络请求参数等。实现步骤1.使用Str...
- 因为一个函数strtok踩坑,我被老工程师无情嘲笑了
-
在用C/C++实现字符串切割中,strtok函数经常用到,其主要作用是按照给定的字符集分隔字符串,并返回各子字符串。但是实际上,可不止有strtok(),还有strtok、strtok_s、strto...
- 一周热门
- 最近发表
- 标签列表
-
- ps像素和厘米换算 (32)
- ps图案在哪里 (33)
- super().__init__ (33)
- python 获取日期 (34)
- 0xa (36)
- super().__init__()详解 (33)
- python安装包在哪里找 (33)
- linux查看python版本信息 (35)
- python怎么改成中文 (35)
- php文件怎么在浏览器运行 (33)
- eval在python中的意思 (33)
- python安装opencv库 (35)
- python div (34)
- sticky css (33)
- python中random.randint()函数 (34)
- python去掉字符串中的指定字符 (33)
- python入门经典100题 (34)
- anaconda安装路径 (34)
- yield和return的区别 (33)
- 1到10的阶乘之和是多少 (35)
- python安装sklearn库 (33)
- dom和bom区别 (33)
- js 替换指定位置的字符 (33)
- python判断元素是否存在 (33)
- sorted key (33)