from: https://zhuanlan.zhihu.com/p/463052305
参考:
attention-is-all-you-need-pytorch
Transformer代码详解-pytorch版
Transformer模型结构
Transformer模型结构如下图:
- Transformer的整体结构就是分成Encoder和Decoder两部分,并且两部分之间是有联系的,可以注意到Encoder的输出是Decoder第二个Multi-head Attention中和的输入。
- Encoder和Decoder分别由N个EncoderLayer和DecoderLayer组成。N默认为6个。
- EncoderLayer由两个SubLayers组成,分别是Multi-head Attention和Feed Forward。DecoderLayer则是由三个SubLayers组成,分别是Masked Multi-head Attention,Multi-head Attention和Feed Forward。
- Multi-head Attention是用ScaledDotProductAttention和Linear组成。Feed Forward是由Linear组成。
- Add & Norm指的是残差连接之后再进行LayerNorm。
各模块结构结构
Multi-head Attention结构
Feed Forward结构
EncoderLayer结构
DecoderLayer结构
Encoder结构
Decoder结构
ScaledDotProductAttention模块
ScaledDotProductAttention做的是一个attention计算。公式如下:
输入q k v,可以q先除以根号d_k(d_k默认为64,根号d_k就为8),再与k的转置相乘,再经过softmax,最后与v相乘。下图的操作和公式所做的东西是一样的。
class ScaledDotProductAttention(nn.Module): ''' Scaled Dot-Product Attention ''' def __init__(self, temperature, attn_dropout=0.1): super().__init__() # 其实就是论文中的根号d_k self.temperature = temperature self.dropout = nn.Dropout(attn_dropout) def forward(self, q, k, v, mask=None): # sz_b: batch_size 批量大小 # len_q,len_k,len_v: 序列长度 在这里他们都相等 # n_head: 多头注意力 默认为8 # d_k,d_v: k v 的dim(维度) 默认都是64 # 此时q的shape为(sz_b, n_head, len_q, d_k) (sz_b, 8, len_q, 64) # 此时k的shape为(sz_b, n_head, len_k, d_k) (sz_b, 8, len_k, 64) # 此时v的shape为(sz_b, n_head, len_k, d_v) (sz_b, 8, len_k, 64) # q先除以self.temperature(论文中的根号d_k) k交换最后两个维度(这样才可以进行矩阵相乘) 最后两个张量进行矩阵相乘 # attn的shape为(sz_b, n_head, len_q, len_k) attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) if mask is not None: # 用-1e9代替0 -1e9是一个很大的负数 经过softmax之后接近与0 # 其一:去除掉各种padding在训练过程中的影响 # 其二,将输入进行遮盖,避免decoder看到后面要预测的东西。(只用在decoder中) attn = attn.masked_fill(mask == 0, -1e9) # 先在attn的最后一个维度做softmax 再dropout 得到注意力分数 attn = self.dropout(F.softmax(attn, dim=-1)) # 最后attn与v进行矩阵相乘 # output的shape为(sz_b, 8, len_q, 64) output = torch.matmul(attn, v) # 返回 output和注意力分数 return output, attn
MultiHeadAttention和PositionwiseFeedForward模块
MultiHeadAttention做的是将q k v先经过线性层投影,再做ScaledDotProductAttention ,最后经过一个线性层。也就是下图的操作:
对应着Transformer的模块是:
PositionwiseFeedForward其实就是MLP。对应着Transformer的模块是:
# q k v 先经过不同的线性层 再用ScaledDotProductAttention 最后再经过一个线性层
class MultiHeadAttention(nn.Module): ''' Multi-Head Attention module ''' def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): # 这里的n_head, d_model, d_k, d_v分别默认为8, 512, 64, 64 super().__init__() self.n_head = n_head self.d_k = d_k self.d_v = d_v self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) self.fc = nn.Linear(n_head * d_v, d_model, bias=False) self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5) self.dropout = nn.Dropout(dropout) self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) def forward(self, q, k, v, mask=None): d_k, d_v, n_head = self.d_k, self.d_v, self.n_head # len_q, len_k, len_v 为输入的序列长度 sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) # 用作残差连接 residual = q # Pass through the pre-attention projection: b x lq x (n*dv) # Separate different heads: b x lq x n x dv # q k v 分别经过一个线性层再改变维度 # 由(sz_b, len_q, n_head*d_k) => (sz_b, len_q, n_head, d_k) (sz_b, len_q, 8*64) => (sz_b, len_q, 8, 64) q = self.w_qs(q).view(sz_b, len_q, n_head, d_k