cover_image

知识蒸馏技术简介及在BERT上的实践应用

vivo AI Lab vivo人工智能技术
2019年11月28日 13:00

图片





  

JEEK

容易脸红的编程大神

图片



本文简单介绍了知识蒸馏技术的原理和概念,并详细讲解了如何在BERT上进行知识蒸馏,其中包括一些技巧和指导。通过本文,你将了解到:

(1) 知识蒸馏的基本原理;

(2) 基于BERT进行知识蒸馏的技巧;

(3) 如何选择合适的蒸馏BERT结构;













背景



2018年NLP领域中诞生了很多绝妙的想法,其中最为著名的BERT的提出逐渐将多项NLP应用带入了新的范式,既 语言模型预训练 + 特定任务微调。不用再多说该范式下的最新研究成果以及傲人的成绩,现在只要是真正从事NLP的工作者都在从这种方式中受益。不需要再做模型的设计和调优,甚至对数据的依赖也有所降低,只需要在合适的开源模型上微调便能取得不错的应用指标。(如此看来以后做NLP的将不再需要大量专业的算法工程师,只要数据标的好指标自然好。Just kidding…☺)

图片

图1:在最新研究中的模型参数量变化


然而,BERT取得的成功依赖于巨大的参数量和更深的网络层数,在享受它带来的便利的同时,也为实际部署工作带来新的挑战。根据目前使用情况来看,在普通的服务器CPU上用BERT-base计算一个长度为32的序列需要70+毫秒,并且是在应用了现有可用加速技术(原生TF、MKL-DNN等)的基础上。而在微服务架构下一次服务请求的最佳超时时间需要控制在50ms内,以满足服务的吞吐量要求。由此便引出了模型压缩的需求,如何去压缩?主要有这么几种方法:参数共享或剪枝、低秩矩阵分解、量化和知识蒸馏(Knowledge Distillation)。在这几种方法中,本文将主要集中在知识蒸馏上面。不得不说每种方法都有自己成功的应用案例,但在实际中尝试其余方法时遇到一些问题导致最终难以应用。


知识蒸馏



知识蒸馏(Knowledge Distillation,有时也称为Teacher-Student learning)是一种压缩技术,在这种技术中,一个小模型被训练来重现一个大模型(或几个模型的集成)的行为。 它由Bucila等人(2006年)引入,并由Hinton等人(2015年)推广,目前大多数蒸馏的研究都基于后者的论文。本文将简单介绍该方法,更多信息可以参考论文和带幻灯片的视频讲座。
在监督学习中,一个分类模型通常被训练来预测一个类别,方法是利用对数似然来最大化正确预测的概率。 在许多情况下,一个良好的性能模型将预测输出分布,正确的类具有很高的概率,而其他类的概率接近于零。但是,其中一些“接近于零”的概率比其他的大,这在一定程度上反映了模型的泛化能力。例如,一张办公椅可能会被误认为扶手椅,但通常不应被误认为蘑菇。 这种不确定性有时被称为“黑暗知识”。理解蒸馏的另一种方法是防止模型对其预测过于确定(类似于标记平滑)。
下面是一个例子,看看这个想法在实践中。 在语言建模中,通过观察词汇表上的分布,我们可以很容易地观察到这种不确定性。 以下是BERT补全卡萨布兰卡电影中经典台词的20次猜想:

图片

图二:BERT-base对[MASK]字符top20预测


蒸馏训练过程中,小模型通过拟合大模型预测的概率分布来学习大模型中蕴含的知识。在分类问题中,神经网络用softmax激活层的输出来模拟对应每个类别的概率分布。通过蒸馏方法训练学生网络,其实就是用学生网络来模拟老师网络的概率输出。然而在实际中,很多模型对正确类别的预测概率都非常高,导致其余类别的概率值接近于0,导致提供的信息和one-hot标签比没有太大提升。

图片



其中,图片是概率输出;图片是softmax的输入,既logit; 图片是老师网络的softmax输出; 图片是学生网络的softmax输出;图片 是一个控制概率分布形状的参数,通常设置为1。图片 越大则产生的概率分布更加平滑,当图片时分布趋向于均匀分布,当图片,分布趋向于one-hot分布。原始标签和平滑后的老师模型输出分别记作“硬标签”和“软标签”。
Hinton发现在蒸馏过程中,在使用软标签作为预测目标的同时使用硬标签作预测目标能够使得模型得到更好的效果。因此蒸馏过程中需要计算两部分损失,分别是正常损失和蒸馏损失。在计算蒸馏损失的时候,teacher网络和student网络使用同样的 图片参数,而在计算正常损失的时候将图片设为1。
所以,最终损失是由正常损失和蒸馏损失两部分组成,并用一个参数图片控制两者之间的权重。

细心的朋友一定会发现,蒸馏损失中多了一个图片的系数。因为蒸馏loss中加入了温度参数图片,导致计算梯度的时候会产生一个的图片参数,所以加入图片可以使得蒸馏损失和正常损失处于相同的竞争条件。

通过图3可以更直观的看到知识蒸馏的过程和原理。

图片

图3:蒸馏过程示意图


蒸馏对数据集没有过多的要求,可以使用训练大模型时使用的数据集,也可以是新的数据集。知识 蒸馏的另外一个好处是可以在无标签数据上训练蒸馏模型,既只用蒸馏损失,如果标签数据较少的时候这一优点显得更加重要。很多研究中使用数据扩充的思想增强蒸馏的效果,但是Hinton的论文中提到只需要和老师模型相同的数据集上进行蒸馏训练就能达到很好的效果。



基于BERT的知识蒸馏



·  Hint损失

该方法最早来自于FitNet(Mar 2015),论文提出一种知识蒸馏的变种,该方法在图像分类任务上获得了更少参数的小模型却得到新的SOTA效果。最新研究TinyBERT(Sep 2019)中,该方法被应用在BERT上,得到了可观的蒸馏模型效果提升和7倍的模型推理速度提升。
这种方法的关键点在于不仅仅使用了老师模型的最终输出作为小模型的拟合目标,同时加入老师模型中间层的输出作为小模型训练的提示,该方法能加快训练过程同时提升小模型蒸馏的效果。同时,因为小模型和大模型的隐藏层节点数往往是不相同的,所以引入了多余的参数将两个模型对应隐藏层的输出映射起来。


为什么这种方法能训练出更好的蒸馏模型,本文尝试从以下两个角度进行解释。首先BERT在下游任务中的良好表现得益于在预训练时学到的不同层次的语义知识,而这种对语义的捕获能力很大程度由BERT较深的网络结构来保障(参考BERT学到了什么)。

所以,选择层数更少的BERT作为学生网络进行蒸馏训练时,产生的问题是浅层的结构难以学习到深层语义信息。因此丢失了BERT预训练学习到的知识,成为了制约蒸馏效果导致学生网络产生较大准确率损失的主要因素。第二点是从模型结构的角度来看,BERT-base中包含了上亿的参数量,具有多层次的注意力机制(12或24层,具体取决于模型),并且在每一层中包含多个(12或16)注意力“ 头”。由于模型参数不在层之间共享,因此一个BERT-base模型中包含有144种不同的注意力机制。Jesse Vig用开源工具bertviz在BERT中一共提取出6种直观可解释的注意力模型,这6种模式在BERT的不同层和不同注意力头之间多次出现,也就是说BERT中的注意力机制存在大量的冗余信息。基于这个发现,一定程度上保证了在缩减层数和隐藏节点数的情况下不会产生太多的性能损失,而基于层提示的蒸馏训练方式能够使得小模型更好的学习大模型中的注意力机制。


Hint损失的具体计算方式如下:

图片


其中,图片分别表示学生BERT和老师BERT的层数;图片分别表示学生和老师的中间层输出,图片下标表示第图片层的输出。Hint损失本质上是学生和老师对应层输出的图片损失。

在实际中将Embedding层和Logits层的图片损失加入到损失函数中能使得网络更加稳定的训练。需要注意的是,虽然在公式里面学生和老师网络的中间层是按固定间隔映射的(假设学生网络是层数=4的情况下,学生的图片老师图片层),但这并不是固不变的,在实际应用中可以进行不同的组合尝试,找到最合适的映射关系。然而,我们分别比较了另外两种极端情况:从开头取和从末尾取,这种固定间隔的方式是表现最好的。


·  Teacher Assistant蒸馏(TAKD)
 
知识蒸馏虽然是一种广为使用的模型压缩技术,但是并没有完整的指导理论体系,有时在尝试一些新的组合时达不到预期的效果也是常有的事情。TAKD(Feb 2019)是DeepMind的一篇论文,论文在大量试验的基础上总结出知识蒸馏时网络大小选择的指导性原则。
文中指出知识蒸馏的效果会随着学生和老师模型之间大小差距太大而发生衰减,对于一个给定大小的老师网络,并不是任何大小的模型都能通过蒸馏提升训练效果。也就是说,如果要蒸馏一个巨大的老师网络,选择学生网络时不应该太小。论文中提出教师助理TA来克服模型大小差异带来的蒸馏效果退化问题,与现实中的助教相似,助理模型或许效果没有老师模型好,但是在结构上与学生模型更接近所以能更好的传递知识。下面简单介绍以下论文中的两个试验。
实验一在CIFAR-10和CIFAR-100两个数据集上进行知识蒸馏,并给定了学生模型为一个典型的2层CNN网络,然后选用不同大小4, 6, 8, 10层的CNN网络作为老师网络。最终学习到的学生模型表现如图4所示,可以看到刚开始随着老师网络变大,学生模型的效果先是提升然后开始下降。具体的试验设置可以查看论文,通过这个试验可以看到在知识蒸馏时,如果给定了学生网络大小,并不能一味的通过提高老师模型的大小来提升蒸馏效果。

图片


图4:老师网络变大时的知识蒸馏效果


试验二在CIFAR-100上进行,主要通过选择不同的蒸馏路径验证了TAKD的有效性,试验结果如图5所示。实验中,选择2层CNN作学生模型和10层CNN作老师模型,然后选择不同的蒸馏路径逐级进行知识蒸馏。可以从结果中看出,选择TA可以增强最终的蒸馏效果图片,并且TA的个数越多,即蒸馏路径越平滑,最终蒸馏出的2层CNN的效果越好。


图片

图5:蒸馏路径与效果(老师模型为10层的CNN)



·  选择合适的蒸馏BERT结构

前面介绍了两种可以提升蒸馏BERT效果的方法,那么如何选择合适的BERT结构作为学生模型呢,主要考虑的因素有两点:推理时间 和 模型指标。
BERT结构中影响推理时间的主要有三个因素,层数、节点、Heads数。表1展示了部分BERT结构对应的推理时间,由于层数、节点、Heads数可以看作是三个相互独立的因素,从表中可以看出:层数与推理时间基本成正比关系;而节点数在大于256的时候影响因子大于1,所以选择最好落在256~768之间;头数对推理时间影响较弱。

表1:BERT推理时间表

图片


本文中提供一份完整的推理时间指引表(附录),记录了不同的BERT结构下在CPU和GPU上的推理速度表现。
最终在一份中文的自然语言理解数据集测试了蒸馏BERT的效果,如表2所示。

表2:蒸馏BERT的效果

图片


最终,蒸馏BERT以1.5%的精度和2.4%的召回为代价得到10倍的加速比。


结论



知识蒸馏模型作为一种压缩模型方法,被证明了是一种有效的手段,并且成本很小(本文中在一张v100上训练2个小时左右),在日后是可以大量使用的。在目前NLP的算法研究和应用过程中,如何解决巨大的网络模型的训练和推理问题是首要问题。事实上对于一个特定问题的解决,真的需要有这么多的网络节点吗,现有的网络到底学习到了哪些知识提取模式,以及如何将有用的知识保留下来,这是模型压缩的问题,同时也是算法性能提升的问题。

参考文献



[1]Distilling the Knowledge in a Neural Network;
[2]Smaller, faster, cheaper, lighter: Introducing DistilBERT, a distilled version of BERT;
[3]Distilling Task-Specific Knowledge from BERT into Simple Neural Networks;
[4]FitNets: Hints for Thin Deep Nets;
[5]TinyBERT: Distilling BERT for Natural Language Understanding;
[6]Improved Knowledge Distillation via Teacher Assistant: Bridging the Gap Between Student and Teacher


附件:

BERT推理时间对照表

图片





回顾上篇:

Learning to rank 简介


图片



图片




继续滑动看下一个
vivo人工智能技术
向上滑动看下一个