所有文章 > 正文

ICLR 2020 | reformer高效处理长序列,单机能跑,计算资源贫困人士的福音

作者: 刘杰鹏

时间: 2020-05-08 21:06

基于Transformer的各种巨型模型在各种自然语言处理任务中常常能够取得最优结果,但这些模型的训练成本往往过高,在针对长序列文本上尤甚。为此,本文提出两种技术以改善基于Transformer的这类模型,名为Reformer。第一,使用局部敏感hash,替换原始的点乘方式的attention,从而将其空间复杂度从O(L^2)降低到O(Llog L),其中L表示文本序列的长度。第二,使用逆残差层代替标准的残差,这使得训练过程中只需存储一次激活值,而无需N次,其中N表示网络层数。最终的结果表明Reformer性能与Transformer相当,同时在长序列上具有更高的内存效率和更快的速度。

Reformer:高效的Transformer

机构:Google Research 、U.C. Berkeley

作者:Nikita Kitaev、Łukasz Kaiser、Anselm Levskaya

收录会议:ICLR2020

1. 介绍

那训练Transformer模型是否真需要很多资源且很低效?以现有的最大Transformer层为例,该Transformer层中参数量是0.5B,这需要2GB的内存。(1M=1024KB,1KB=1024Byte。所以1GB=1024M=1024x1024KB=1024x1024x1024Byte=1073741824Byte。float占用4个Byte。0.5B即5亿参数,需要的内存量为5亿*4字节=20亿字节。这差不多是1.86GB即约为2GB)对于由64Ktokens组成的序列,如果嵌入层的尺寸是1024,batch size是8,那么激活值需要64K * 1K * 8=0.5B个浮点数来存储,这又需要2GB的内存。如果每层的内存占用只有上述提到的这些的话,那么在单加速器上使用Transformer处理64K长度的序列也是轻而易举。此外,如此前提下训练BERT的整个语料库也只需17GB的内存。然而,现实并非如此,真实环境下为何甚至不能在单台机器上对这些模型进行微调呢?

这是因为上述仅仅考虑单层参数的内存占用和输入激活值的内存消耗,而忽略了 Transformer 在内存占用上的主要问题:

- 需要存储激活值用于反向传播,那么N层模型内存占用是单层的N倍;

- 由于中间全连接层的深度d_{ff}通常远大于注意力激活层的深度d_{model},而这需要占用很大的内存;

- 长度为L的序列的 attention 的时间和空间复杂度是O(L^2),那么对于64K tokens的序列就会耗尽内存。

为此,本文提出Reformer模型以解决上述问题,具体采用如下方案:

- 可逆层(Reversible layer),在整个模型中只使用单个副本,可以消除层数因子N。

- 前馈层(feed-forward layer)分开激活和分块处理,从而消除d_{ff}因子的影响,降低前馈层的内存占用。

- 采用基于局部敏感哈希(locality-sensitive hashing,LSH)的近似注意力计算,让注意力层的O(L^2)因子变为O(L log L) ,这使得在长序列上的处理成为可能。

Reformer模型在以下3个任务上进行实验:合成任务、文本任务(enwik8,序列长度为64K)和图像生成任务(imagenet-64,序列长度为12K)。实验结果表明Reformer结果与Transformer相当,但是更快、内存也更高效。

2. 局部敏感哈希ATTENTION

点乘attention:

标准的Transformer使用点乘的attention,queries和keys的维度都是d_k,values的维度是d_v。query先与key做点乘,再除以根号d_k,再输入到softmax中得到value的权重,最后权重再与value相乘,得到最终的结果。在实际操作过程中是以矩阵方式进行批量操作,queries组成矩阵Q,keys组成矩阵K,values组成矩阵V,上述流程概况如下:

点乘.png

多头attention:

上述的attention操作并行地进行h次,再输出维度为d_v的输出结果。再将这些结果拼接,再做一次投射操作得到最终的结果。即所谓的多头attention。

高效内存attention:

先来算下上述attention机制消耗的内存。假设Q,K,V的尺寸为[batch_size,length,d_model]。QK^T的尺寸为[batch_size,length,length]。当length=64k,即使batch_size=1,那么64k*64k大小的矩阵,如果用32位浮点数来存储的话,需要16GB内存。鉴于此,在长序列上使用Transformer显得不切实际。但是需要注意的是,QK^T矩阵可以不必全部放在内存中,可以对每个query分别计算attention。反向传播计算梯度时再重新计算一次。这种方式计算attention虽然低效,但是所占用的内存与length成正比。这种方法在本文这里作为一种全attention的baseline。

Q,K,V从何处来?

上述讨论了Q、K、V,但是一般我们只会得到大小为[batch_size,length,d_model]的激活值A,这些值是token的嵌入所组成的句向量。那么为了从A中得到Q、K、V,Transformer使用了3个不同的线性层(参数不同)将A投射为Q、K、V。对于使用局部敏感哈希attention的模型,我们希望queries和keys(即Q和K)相同。只需要A投射到Q和A投射到K时采用相同线性变换参数即可,而A投射到V时采用不同参数。这种方式成为共享QK-Transformer。实验表明共享QK并不会影响Transformer的性能,即使添加一项d_k的归一化项。

Hashing attention:

在LSH attention中,假设Q、K、V的尺寸为[batch_size,length,d_model],同时仍然使用此前介绍的多头attention机制。那么QK^T的尺寸为[batch_size,length,length]。由于softmax(QK^T)的计算结果主要取决于值最大的部分,对于每个query只需关注K中与query最接近的点。当K的长度是64k,那么对个每个query,本文仅仅考虑其最近的的32或64个keys。如此会更加高效,那么如何找寻最近的那些keys呢?

局部敏感哈希(LSH):

在高纬空间中找寻最近邻可以使用局部敏感哈希(LSH)。将每个向量x通过hash函数h(x)进行映射,如果近处的向量获得相同的hash,且具有高概率,而远处的向量没有,那么这样的hash称为位置敏感型hash。在此处例子中,我们实际上只要求近邻的向量以高概率具有相同的hash值,并且hash桶也以高概率具有相同的大小。

具体是使用如Figure 1所示的随机投射方法:

局部.png

上图的angular LSH是一种常用的LSH算法,它将点投射到一个单位球上,这个单位球被划分为预定义的区域,每个区域都有一个特定的代码。然后一系列随机旋转的点定义了这些点所归属的桶。

LSH attention:

综合考虑上述的LSH策略和hashing attention,先重写单个query在位置i的常规attention:

lsh.png

其中P_i表示query在位置i所需要attend的集合,z表示配分函数(partition function)比如softmax中的归一化项。为了书写清楚,这里省略了缩放项根号d_k。

对于批量操作,当遮蔽掉不在P_i中的元素,此时常规attention定义如下:

atten.png

即对于不能attend到的位置,m(j, P_i)为正无穷,那么q_i* k_j减去正无穷再去exp操作,其结果为0。这样就不需要对于每个位置i都有单独的P_i。

在LSH attention中,query中位置i所能够attend的限制集合P_i被限制到一个hash桶中。Figure 2(a-b)展示的是全attention和hash attention的对比。

对比.png

图a:常规的attention机制中,黑点代表的是softmax中占主导的位置。注意这边的attention使用的是encoder的attention,否则q_3无法attend到k_6。另外,这种全attention(即encoder中的attention)的attention矩阵一般是稀疏的,但计算中并没有利用这种稀疏性,所以可以利用这个降低时间空间复杂度。

图b:计算query和key所归属的hash桶。再按照桶进行排序,同一个桶又按照原本的位置进行排序得到图b。可以看到,同一个桶,可以出现多个query但keys很少的情况,例如图中蓝色的桶query有3个,都attend到同一个key中。由于相似的item很有可能落在同一个桶里,所以只在每个桶内部进行attention就可以近似全attention。

图c:为了缓解桶中q和k不均衡问题,本文通过令$k_{j}=\frac{q_{j}}{\left\|q_{j}\right\|}$使得h(k_j)=h(q_j),即使用了share-QK attention。然后先按照桶序号对queries排序,每个桶中,仍按照原本的position 位置大小排序。得到图c。对比b图和c图可以看出,纵轴的k已经变成了q。这时就能保证对角线都是attend 到的而且q和k在桶中的个数一样(因为Q=K)。排序后的attention矩阵,相同桶的值会在对角线附近聚集。注意到图中对角线的点为空心,这是因为虽然在正常情况下,q会attend to本身位置的value,但是在share-QK的实现下,如果attend to本身,会导致其值特别大,其他的值特别小,经过softmax之后,其他都是0,就自己本身是1。所以为了避免这种情况,q不会去attend 自身位置的值,除非只有自己本身可以attend。

图d:即使Q=K,还是会出现一个问题:有的桶中个数多,有的桶中个数少。比如一个极端情况,2个桶,其中一个桶占据了所有的keys,另一个桶为空,那么LSH attention就没有起作用。于是在图c的基础上,增加了chunk的操作。对输入进行排序之后(即图c中先桶排序,同个桶内按照token 的 position排序)得到新的序列顺序s_i,比如图中原来的序列顺序是[q_1,q_2,q_3,q_4,q_5,q_6],新的序列顺序是[q_1,q_2,q_4,q_3,q_6,q_5] 。每个chunk内query的上限个数为$m=\frac{2 l}{n_{\text {buckets}}}$, (l为输入query的长度) ,每个桶平均大小为$m=\frac{l}{n_{\text {buckets}}}$,这里假设桶中数量增加到均值两倍的概率足够低。对于桶中的每个query,都可以attend to自己以及前一个桶中相同hash 值的key。

小结下,LSH attention做了以下两个事情:

第一,找到Q、K矩阵的LSH hashes。

第二,在同一个hash桶内计算k和q向量的标准attention。

更具体来说可分为以下5个步骤:

第一,令输入序列queries=keys

第二,做LSH bucketing,即进行hash计算,得到每个query和key所归属的桶(不同颜色表示不同的桶)。

第三,根据桶编号对query进行排序,同个桶中,按照query原本的位置进行排序。

第四,对于排序后的新序列,进行 chunk 拆分

第五,对于每个query只attend自己以及自己之前的chunk,对于这些候选集中相同桶的key进行attend。

多轮LSH attention:

LSH 有近似性,即不能保证相似的输入能在同一个桶中。为了减轻这个问题,采用了multi-round LSH attention。即重复上述过程多次,以使类似的item以尽可能高的概率落入相同的桶中,尽量避免相似item落入不同桶。更多的细节参考附件A。

3. 可逆层

如上所述,attention的复杂度可以被减少为与序列长度成线性正比,但是,参数量占的复杂度依旧很高,如何进一步减少呢?这里就开始尝试解决前文介绍部分所提到的第二和第三个问题,即大量的encoder和decoder层、全连接层FFN的深度问题。

Reversible residual Network (RevNet)

RevNet的思想是每一层的activations可以根据下一层的activations推导获得,从而不需要在内存中储存activations。在原本的residual layer中,由公式y=x+F(x)输出得到activations。其中F是residual 函数。在RevNet中,先将输入x分为两个部分x_1和x_2,然后通过不同residual functions: F()和G()得到输出y_1和y_2:

可逆1.png

再根据以下结构,从输出获得输入:

可你2.png

Reversible Transformer

那么如何在Transformer中引入RevNet?将attention layer和 FFN layer通过ResNet 连接,从而减少内存的消耗。具体是令F函数为attention 层,G函数作为FFN层。需要注意的一点是layer normalization是包含在residual blocks中的。

trrr.png

如此,使用可逆的Transformer在每一层中就无需存储激活值,也就避免了n_l这一项。可逆层代替标准的残差层,可以在训练过程中只存储一次激活,而不是N次。

Chunking

上述消除了n_l项的影响,深层的网络仍然占有大量内存。在FFN中中间隐藏层的纬度通常非常大,比如d_{ff}=4k或者更大。由于FFN的计算与序列中的位置完全无关,因此计算可以被分割成c个块,以降低内存的使用。虽然该操作其实可并行处理,但是每次只计算一个chunk,通过时间换取内存空间。

ch.png

另外,可逆操作和反向传播操作也分块处理。除FFN之外,对于词汇量大的模型(单词类型>d_{model}),还对输出处的log- probability分块,并一次计算序列各部分的损失。

4. 实验结果

对图像生成任务imagenet64(序列长度为12K)和文本任务enwik8-64K(即序列长度为64K)进行了实验,评价了可逆层、共享query-key、LSH attention对内存、精度和速度的影响。

可逆层和共享query-key的影响:

结果.png

Figure 3中的左部分验证共享query-key的影响。从perplexity曲线结果可以看出,共享QK attention并不会明显逊色于常规attention。且在enwik8数据集中收敛更快。换句话说,使用共享QK attention并不会牺牲准确性。

Figure 3中的右部分验证的是可逆层的影响。实验中对比的可逆层和常规Transformer参数量相同,且学习曲线看起来也几乎相同。这些结果表明,可逆Transformer在节省内存的同时并不会牺牲精度。

LSH attention的影响:

如Figure 4所示,可以看出随着hash数的增多精度也提升了。

giuy.png

更大的Reformer模型:

Figure 5展示了不同层数的Reformer在envik8和imagenet64上的表现。下图(左)是Big Reformer随层数变化指标结果,20层依然无压力。而下图(右)是普通attention和LSH attention在不同序列长度的速度比较,当序列很长的时候,LSH具有显著的优势。

结果5.png

5. 总结

Reformer将Transformer的建模能力与能够在长序列上高效执行的体系结构相结合,使其即使处理大模型时,也可以使用较小的内存。这将有助于大型、海量参数化的Transformer模型变得更广泛可用。此外,处理长序列的能力为Reformer在许多生成任务上的使用开辟了道路。除了生成非常长的连贯文本外,Reformer可以把Transformer模型的能力应用到其他领域,如时间序列预测、音乐、图像等。

点击”代码“查看论文代码

作者:刘杰鹏(微信号:onepieceand)

毕业院校:华中科技大学

研究方向:机器阅读理解、文本生成等。

相关阅读:

ICLR 2020 | 知识图谱推理框架:基于向量空间的推理和数值逻辑推理

ICLR 2020 | PairNorm: Tackling Oversmoothing in GNNs

ICLR 2020 | 预训练图神经网络模型

ICLR 2020 | 探索新的图表征学习思路

ICLR 2020 | NLP预训练模型的全新范式:对比学习

ICLR 2020 | 互信息视角下的表征学习

ICLR 2020 | 反事实因果理论如何帮助深度学习?

ICLR 2020 | 浅谈GNN:能力与局限

ICLR 2020 | 一种高效、表达能力强的可微分归纳逻辑推理模型

ICLR 2020 | 基于谱方法的高效多级图嵌入框架

ICLR 2020 满分论文解读| 一种镜像生成式机器翻译模型:MGNMT

[关于转载]:本文为“AMiner”官网文章。转载本文请联系原作者获取授权,转载仅限全文转载并保留文章标题及内容,不得删改、添加内容绕开原创保护,且文章开头必须注明:转自“AMiner”官网。谢谢您的合作。

推荐阅读 更多