GithubHelp home page GithubHelp logo

llama2's Introduction

注:若要正常显示公式请安装google插件:MathJax Plugin for GitHub

llama2

利用Transformer库从0开始搭建llama2

模型结构图

遇到的问题求解:

  1. 为什么要左填充

llama2 用到的技术梳理

(基本原则)讲清楚缘由:为什么这样子这么做,涉及到模型本身的搭建:

旋转位置编码 RoPE

  • 绝对位置编码: $$attent(x_m, x_n, m, n) = f(x_m, x_n, m, n)$$
  • 相对位置编码: $$attent(x_m, x_n, m, n) = f(x_m, x_n, m - n)$$
  • 旋转位置编码:旋转位置编的巧妙之处在于通过绝对编码的方式实现了相对编码的效果。

绝对位置编码的优点是计算简单高效,缺点是一般效果不如相对位置编码。相对位置编码的优点是效果较好,缺点是计算效率不如绝对位置编码。在相对位置编码中,注意力权重的结果仅仅和参与注意力计算的token向量的相对位置有关,不和绝对位置直接关联。这符合NLP领域在序列长度方向上具有平移不变性的特点,所以相对位置编码一般效果会优于绝对位置编码。绝对位置编码只需要初始化时对序列的每个位置(数量正比于序列长度)赋予位置编码即可,后续无需干预。而相对位置编码要在计算过程中获取许多个(数量正比于序列长度平方)相对位置。因此绝对位置编码更加简单高效。

个人理解:需要注意的点就是,在《attention is all your need》中,位置编码是与embeding vector求和的方式,而在llama2中位置编码的实现方式是让embeding vector与旋转矩阵相乘实现的。即

$$(R(m)x_m)^TR(n)x_n = x_m^TR^T(m)*R(n)x_n = x_m^TR(m - n)x_n$$

由此可以看出,旋转矩阵的构造需要满足:

$$R(m)^T*R(n) = R(m - n)$$

NTK

  • 直接外推:直接外推就是继续沿用现有的位置编码公式,不做任何修改。直接外推对衰减规律在长距离情况下的使用容易出现问题,导致性能下降。原因,权重的学习已经拟合了原来训练时的前后文文本长度,所以当外推长度过长时,注意力可以会为0从而导致即时在推理时用更长的前后文文本也没有效果。
  • 线性内插:线性内插需要改变位置编码公式,等效于将位置序号等比例缩小。线性内插没有改变模型学习到的衰减规律的应用范围,不考虑微调的话,其效果一般好于直接外推方案。但是,扩展倍数非常大的时候,例如从2k扩展到32k,其性能也会明显的受到影响。因为在这种情况下,衰减规律在短距离情况下的使用会受到较严重的影响,本来距离为1的两个token,长度扩展后相当于变成了距离为1/16,衰减规律在短距离时可能具有非常大的变化率(短距离变得敏感),因此对相关性的评估可能会极端地偏离合理值。
  • NTK: 考虑到RoPE的性质,短距离之间的差异(例如1和5的差异),主要体现在高频分量(i比较大)上,长距离之间的差异(例如5000和10000的差异),主要体现在低频分量(i比较小)上。因此可以考虑乘上一个与频率i有关的因子,使得短距离时(主要差异体现在高频段)具有外推性质(尽量接近1,即在短范围时使之尽量保持不懂),使得长距离时(主要差异体现在低频段)具有内插性质(尽量接近1/s,使之缩放到扩展前的范围。)

group attention

定义: 作用:

RMSNorm

  • LayerNorm的作用:通过均值-方差归一化使得输入和权重矩阵具有重新居中和重新缩放的能力,帮助稳定训练并促进模型收敛。

  • RMSNorm的作用:LayerNorm的改进,他假设LayerNorm中的重新居中性质并不是必需的,于是RMSNorm根据均方根(RMS)对某一层中的神经元输入进行规范化,同时成一个可以学习的权重系数,赋予模型重新缩放的不变性属性和隐式学习率自适应能力。相比LayerNorm,RMSNorm在计算上更简单,因此更加高效。可以保证在层归一化不会改变hidden_states对应的词向量的方向,只会改变其模长。 个人理解:这样既可以将向量映射到一个梯度敏感的区间,又可以尽量的不改变向量的性质。

代码梳理

decoder layer 叠加

模型的主体部分就是通过decoder layer 堆叠起来的。

RoPE计算的加速

可以将旋转位置编码过程由矩阵乘法运算简化成两次向量的哈达玛积求和。

attention mask的上三角实现细节 (广播机制)

利用了torch的广播机制,就实现了上三角mask矩阵的构建。简化的示意图如下:

分类问题的head 映射层

loss计算问题

llama2's People

Contributors

648591461 avatar

Stargazers

Lingzheng Kong avatar

Watchers

Kostas Georgiou avatar  avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.