Implementing Common Attention Mechanisms in Seq2Seq

Introduction

This article implements the attention scoring methods commonly used in Seq2Seq models with PyTorch.

Attention Scoring Methods

$$
\mathrm{score}(h_t, \overline{h}_s) = \begin{cases} h_t^T \overline{h}_s & \text{dot} \\ h_t^T W_a \overline{h}_s & \text{general} \\ v_a^T \tanh (W_a[h_t; \overline{h}_s]) & \text{concat} \\ v_a^T \tanh (W_a \overline{h}_s + U_a h_t) & \text{bahdanau} \end{cases}
$$

Combining the papers Effective Approaches to Attention-based Neural Machine Translation and Neural Machine Translation by Jointly Learning to Align and Translate, we obtain the four attention scoring methods above.
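Before the full implementation below, the following minimal sketch (toy sizes made up for illustration) checks the shapes involved in the dot score $h_t^T \overline{h}_s$, computed for one decoder state against every encoder time step:

import torch

# Toy sizes, chosen only for illustration
seq_len, batch_size, hidden_size = 5, 2, 8
last_hidden = torch.randn(1, batch_size, hidden_size)            # h_t, the current decoder state
encoder_outputs = torch.randn(seq_len, batch_size, hidden_size)  # h_s for s = 1..seq_len

# Broadcasting gives [seq_len, batch_size, hidden_size]; summing over dim 2 yields
# one dot score per (encoder time step, batch element)
dot_scores = torch.sum(last_hidden * encoder_outputs, dim=2)
print(dot_scores.shape)  # torch.Size([5, 2])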

The weight $\alpha_{ij}$ assigned to each encoder output $h_j$ is computed as

$$
\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T_x} \exp(e_{ik})} \tag{6}
$$

where

$$
e_{ij} = a(s_{i-1}, h_j)
$$
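As a tiny numerical illustration of equation (6) (the scores here are made up, not taken from the papers), the softmax turns raw scores $e_{ij}$ into weights that sum to 1:

import torch

# Made-up scores e_ij for one decoder step i over three encoder time steps
e = torch.tensor([1.0, 2.0, 0.5])
alpha = torch.softmax(e, dim=0)  # equation (6)
print(alpha)        # ≈ tensor([0.2312, 0.6285, 0.1402])
print(alpha.sum())  # ≈ tensor(1.) — the weights sum to 1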

See also the (translated) paper Neural Machine Translation by Jointly Learning to Align and Translate.

Code Implementation

import torch
import torch.nn as nn


class Attention(nn.Module):
    def __init__(self, hidden_size, method='dot'):
        super(Attention, self).__init__()
        self.method = method
        self.hidden_size = hidden_size
        if self.method not in ['dot', 'general', 'concat', 'bahdanau']:
            raise ValueError(f"{self.method} is not an appropriate attention method.")
        if self.method == 'general':
            self.Wa = nn.Linear(hidden_size, hidden_size, bias=False)
        elif self.method == 'concat':
            self.Wa = nn.Linear(hidden_size * 2, hidden_size, bias=False)
            self.va = nn.Parameter(torch.rand(1, hidden_size))  # learnable vector v_a
        elif self.method == 'bahdanau':
            self.Wa = nn.Linear(hidden_size, hidden_size, bias=False)
            self.Ua = nn.Linear(hidden_size, hidden_size, bias=False)
            self.va = nn.Parameter(torch.rand(1, hidden_size))  # learnable vector v_a

    def _score(self, last_hidden, encoder_outputs):
        '''
        :param last_hidden: output of the decoder's last layer (if there are several),
            shape [1, batch_size, hidden_size]; the decoder handles one time step at a
            time and is unidirectional, so D = 1
        :param encoder_outputs: encoder hidden states at every time step,
            shape [seq_len, batch_size, hidden_size]
        '''
        if self.method == 'dot':
            # last_hidden * encoder_outputs -> [seq_len, batch_size, hidden_size]
            # sum(x, dim=2) collapses dim 2, giving [seq_len, batch_size]:
            # for every batch element, the score e_i between the current decoder
            # time step and each encoder time step
            return torch.sum(last_hidden * encoder_outputs, dim=2)  # [seq_len, batch_size]
        elif self.method == 'general':
            energy = self.Wa(last_hidden)  # [1, batch_size, hidden_size]
            # [seq_len, batch_size, hidden_size] * [1, batch_size, hidden_size]
            #   -> [seq_len, batch_size, hidden_size]
            return torch.sum(encoder_outputs * energy, dim=2)  # [seq_len, batch_size]
        elif self.method == 'concat':
            # last_hidden.expand(encoder_outputs.size(0), -1, -1) repeats last_hidden
            # along dim 0 (seq_len copies) so it can be concatenated with encoder_outputs
            # cat(*, dim=2)             -> [seq_len, batch_size, hidden_size * 2]
            # energy = tanh(self.Wa(*)) -> [seq_len, batch_size, hidden_size]
            energy = torch.tanh(self.Wa(torch.cat(
                (encoder_outputs, last_hidden.expand(encoder_outputs.size(0), -1, -1)),
                dim=2)))
            return torch.sum(self.va * energy, dim=2)  # [seq_len, batch_size]
        else:  # method == 'bahdanau'
            # self.Wa(last_hidden)      [1, batch_size, hidden_size]
            # self.Ua(encoder_outputs)  [seq_len, batch_size, hidden_size]
            # torch.tanh(*)             [seq_len, batch_size, hidden_size]
            energy = torch.tanh(self.Wa(last_hidden) + self.Ua(encoder_outputs))
            return torch.sum(self.va * energy, dim=2)  # [seq_len, batch_size]

    def forward(self, last_hidden, encoder_outputs):
        # Attention scores from _score; shape [seq_len, batch_size]
        attn_energies = self._score(last_hidden, encoder_outputs)
        # Transpose to [batch_size, seq_len]
        attn_energies = attn_energies.t()
        # softmax along the time-step dimension turns the scores into the weights
        # alpha_i of equation (6); the shape is unchanged
        # unsqueeze(1) then adds a dimension at dim=1, giving [batch_size, 1, seq_len]
        return torch.softmax(attn_energies, dim=1).unsqueeze(1)
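A short usage sketch (the tensor sizes below are assumptions chosen for illustration): the weights returned by forward can be multiplied with the encoder outputs via torch.bmm to form the context vector for the current decoder step.

import torch

seq_len, batch_size, hidden_size = 10, 4, 16
attn = Attention(hidden_size, method='general')

last_hidden = torch.randn(1, batch_size, hidden_size)            # current decoder hidden state
encoder_outputs = torch.randn(seq_len, batch_size, hidden_size)  # all encoder hidden states

weights = attn(last_hidden, encoder_outputs)  # [batch_size, 1, seq_len]
# [batch_size, 1, seq_len] x [batch_size, seq_len, hidden_size] -> [batch_size, 1, hidden_size]
context = torch.bmm(weights, encoder_outputs.transpose(0, 1))
print(weights.shape, context.shape)  # torch.Size([4, 1, 10]) torch.Size([4, 1, 16])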
