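Implementing Common Attention Mechanisms in Seq2Seq
Introduction
This article implements the attention methods commonly used in Seq2Seq models in PyTorch.
Attention methods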
$$
\mathrm{score}(h_t, \overline{h}_s) =
\begin{cases}
h_t^T \overline{h}_s & \text{dot} \\
h_t^T W_a \overline{h}_s & \text{general} \\
v_a^T \tanh(W_a[h_t; \overline{h}_s]) & \text{concat} \\
v_a^T \tanh(W_a \overline{h}_s + U_a h_t) & \text{bahdanau}
\end{cases}
$$
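Two papers give the four score functions above: Effective Approaches to Attention-based Neural Machine Translation and NEURAL MACHINE TRANSLATION BY JOINTLY LEARNING TO ALIGN AND TRANSLATE. The dot, general and concat forms come from the first paper. The additive form, labeled bahdanau, comes from the second. Here $h_t$ is the current decoder hidden state and $\overline{h}_s$ is an encoder hidden state.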
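The following pure-Python sketch makes the four formulas concrete. It evaluates each score for one decoder state $h_t$ and one encoder output $\overline{h}_s$ of dimension 2.

- The vectors and the matrices $W_a$, $U_a$ and $v_a$ are hypothetical hand-picked values, not learned parameters.
- The helper names (`score_dot`, `score_general`, `score_concat`, `score_bahdanau`) are illustrative only.
- For concat, $W_a$ has shape [d, 2d] because it acts on the concatenation $[h_t; \overline{h}_s]$.

```python
import math

def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

def matvec(m, v):
    """Multiply matrix m (a list of rows) by column vector v."""
    return [dot(row, v) for row in m]

def score_dot(h_t, h_s):
    return dot(h_t, h_s)

def score_general(h_t, h_s, W_a):
    return dot(h_t, matvec(W_a, h_s))

def score_concat(h_t, h_s, W_a, v_a):
    # W_a has shape [d, 2d]; h_t + h_s is list concatenation, i.e. [h_t; h_s]
    return dot(v_a, [math.tanh(x) for x in matvec(W_a, h_t + h_s)])

def score_bahdanau(h_t, h_s, W_a, U_a, v_a):
    # W_a acts on the encoder output, U_a on the decoder state
    summed = [a + b for a, b in zip(matvec(W_a, h_s), matvec(U_a, h_t))]
    return dot(v_a, [math.tanh(x) for x in summed])

# Hypothetical example values
h_t = [1.0, 2.0]
h_s = [3.0, -1.0]
eye2 = [[1.0, 0.0], [0.0, 1.0]]

print(score_dot(h_t, h_s))                                   # 1.0
print(score_general(h_t, h_s, [[1.0, 0.0], [0.0, 2.0]]))     # -1.0
print(score_concat(h_t, h_s, [[1.0, 0.0, 0.0, 0.0],
                              [0.0, 0.0, 0.0, 1.0]], [1.0, 0.0]))  # ≈ 0.7616, tanh(1)
print(score_bahdanau(h_t, h_s, eye2, eye2, [1.0, 0.0]))      # ≈ 0.9993, tanh(4)
```

Only dot needs no parameters. The other three add learned matrices, and concat and bahdanau also pass the result through tanh before projecting with $v_a$.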
The weight $\alpha_{ij}$ of each encoder output $h_j$ is computed as:

$$\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T_x} \exp(e_{ik})} \tag{6}$$

where

$$e_{ij} = a(s_{i-1}, h_j)$$

Here $T_x$ is the length of the source sequence, $s_{i-1}$ is the decoder's previous hidden state, and $a$ is the alignment (score) function, i.e. one of the forms above.
For details, see (Paper Translation) NEURAL MACHINE TRANSLATION BY JOINTLY LEARNING TO ALIGN AND TRANSLATE.
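Equation (6) is a softmax over the scores $e_{i1}, \dots, e_{iT_x}$ of a single decoder step. The pure-Python sketch below assumes the scores are already computed, and the function name `attention_weights` is hypothetical. It subtracts the maximum score before exponentiating, a standard stability trick. The shared factor $e^{-\max}$ cancels between numerator and denominator, so the weights are unchanged.

```python
import math

def attention_weights(scores):
    """Softmax of equation (6): alpha_ij = exp(e_ij) / sum_k exp(e_ik).

    Subtracting max(scores) first cancels out in the ratio, but keeps
    math.exp from overflowing when scores are large.
    """
    m = max(scores)
    exps = [math.exp(e - m) for e in scores]
    total = sum(exps)
    return [x / total for x in exps]

e_i = [1.0, 2.0, 3.0]                      # hypothetical scores e_i1, e_i2, e_i3
alpha_i = attention_weights(e_i)
print([round(a, 4) for a in alpha_i])      # [0.09, 0.2447, 0.6652]

# Without the max shift, math.exp(1000.0) would raise OverflowError
print(attention_weights([1000.0, 1000.0])) # [0.5, 0.5]
```

The weights are positive and sum to 1, so a higher score means a larger share of attention on that encoder step.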
Implementation
```python
import torch
import torch.nn as nn


class Attention(nn.Module):
    def __init__(self, hidden_size, method='dot'):
        super(Attention, self).__init__()
        self.method = method
        self.hidden_size = hidden_size
        if self.method not in ['dot', 'general', 'concat', 'bahdanau']:
            raise ValueError(f"{self.method} is not an appropriate attention method.")
        if self.method == 'general':
            self.Wa = nn.Linear(hidden_size, hidden_size, bias=False)
        elif self.method == 'concat':
            self.Wa = nn.Linear(hidden_size * 2, hidden_size, bias=False)
            # torch.FloatTensor(1, hidden_size) would leave v_a uninitialized,
            # so initialize it explicitly
            self.va = nn.Parameter(torch.rand(1, hidden_size))
        elif self.method == 'bahdanau':
            self.Wa = nn.Linear(hidden_size, hidden_size, bias=False)
            self.Ua = nn.Linear(hidden_size, hidden_size, bias=False)
            self.va = nn.Parameter(torch.rand(1, hidden_size))

    def _score(self, last_hidden, encoder_outputs):
        '''
        :param last_hidden: output of the decoder's last layer (if it has several layers),
            [1, batch_size, hidden_size]. The decoder processes one time step at a time
            and is unidirectional, so D=1.
        :param encoder_outputs: encoder hidden states at all time steps,
            [seq_len, batch_size, hidden_size]
        '''
        if self.method == 'dot':
            # last_hidden * encoder_outputs broadcasts over dim 0 -> [seq_len, batch_size, hidden_size]
            # sum(x, dim=2) sums over the hidden dimension and removes it -> [seq_len, batch_size]
            # For each batch element this is the score between the current decoder step
            # and every encoder step, i.e. e_ij
            return torch.sum(last_hidden * encoder_outputs, dim=2)  # [seq_len, batch_size]
        elif self.method == 'general':
            # Equivalent to h_t^T W_a h_s up to a transpose of the learned matrix
            energy = self.Wa(last_hidden)  # [1, batch_size, hidden_size]
            # Element-wise [seq_len, batch_size, hidden_size] * [1, batch_size, hidden_size]
            # broadcasts to [seq_len, batch_size, hidden_size]
            return torch.sum(encoder_outputs * energy, dim=2)  # [seq_len, batch_size]
        elif self.method == 'concat':
            # Repeat last_hidden seq_len times along dim 0 so it can be concatenated:
            # expand -> [seq_len, batch_size, hidden_size]
            # cat([h_t; h_s], dim=2) -> [seq_len, batch_size, hidden_size * 2]
            # tanh(self.Wa(...)) -> [seq_len, batch_size, hidden_size]
            expanded = last_hidden.expand(encoder_outputs.size(0), -1, -1)
            energy = torch.tanh(self.Wa(torch.cat((expanded, encoder_outputs), dim=2)))
            return torch.sum(self.va * energy, dim=2)  # [seq_len, batch_size]
        else:  # method == 'bahdanau'
            # self.Wa(encoder_outputs) -> [seq_len, batch_size, hidden_size]
            # self.Ua(last_hidden)     -> [1, batch_size, hidden_size], broadcast over seq_len
            # torch.tanh(...)          -> [seq_len, batch_size, hidden_size]
            energy = torch.tanh(self.Wa(encoder_outputs) + self.Ua(last_hidden))
            return torch.sum(self.va * energy, dim=2)  # [seq_len, batch_size]

    def forward(self, last_hidden, encoder_outputs):
        # Attention scores (see _score); every method returns [seq_len, batch_size]
        attn_energies = self._score(last_hidden, encoder_outputs)
        # Transpose -> [batch_size, seq_len]
        attn_energies = attn_energies.t()
        # Softmax along the time-step dimension gives one weight per encoder step
        # and keeps the shape unchanged: this is alpha_ij from equation (6).
        # unsqueeze(1) inserts a dimension at dim=1 -> [batch_size, 1, seq_len]
        return torch.softmax(attn_energies, dim=1).unsqueeze(1)
```
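The module returns weights of shape [batch_size, 1, seq_len], and the original code stops there. One common way to consume these weights, assumed here and not part of the code above, is to form the context vector $c_i = \sum_j \alpha_{ij} h_j$ with `torch.bmm`. This requires the encoder outputs to be batch-first first. So that the sketch runs on its own, it repeats the 'dot' steps inline, and random tensors stand in for real encoder and decoder states.

```python
import torch

torch.manual_seed(0)
seq_len, batch_size, hidden_size = 5, 3, 4

# Hypothetical inputs in the layout the Attention class expects
encoder_outputs = torch.randn(seq_len, batch_size, hidden_size)  # [seq_len, batch, hidden]
last_hidden = torch.randn(1, batch_size, hidden_size)            # [1, batch, hidden]

# Same steps as Attention(method='dot').forward
scores = torch.sum(last_hidden * encoder_outputs, dim=2)         # [seq_len, batch]
weights = torch.softmax(scores.t(), dim=1).unsqueeze(1)          # [batch, 1, seq_len]

# Context vector c_i = sum_j alpha_ij * h_j for every batch element:
# [batch, 1, seq_len] bmm [batch, seq_len, hidden] -> [batch, 1, hidden]
context = torch.bmm(weights, encoder_outputs.transpose(0, 1))

print(weights.shape)   # torch.Size([3, 1, 5])
print(context.shape)   # torch.Size([3, 1, 4])
```

`torch.bmm` expects 3-D tensors of shapes [b, n, m] and [b, m, p]. The `unsqueeze(1)` at the end of `forward` is what lets the weights play the [b, n, m] role with n = 1.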