Vision Transformer详解(附代码)

编程入门 行业动态 更新时间:2024-10-12 16:30:00

Vision Transformer<a href=https://www.elefans.com/category/jswz/34/1770044.html style=详解(附代码)"/>

Vision Transformer详解(附代码)

1 引言

T r a n s f o r m e r \mathrm{Transformer} Transformer在 N L P \mathrm{NLP} NLP中大获成功, V i s i o n T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer则将 T r a n s f o r m e r \mathrm{Transformer} Transformer模型架构扩展到计算机视觉的领域中,并且它可以很好的地取代卷积操作,在不依赖卷积的情况下,依然可以在图像分类任务上达到很好的效果。卷积操作只能考虑到局部的特征信息,而 T r a n s f o r m e r \mathrm{Transformer} Transformer中的注意力机制可以综合考量全局的特征信息。 V i s i o n T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer尽力做到在不改变 T r a n s f o r m e r \mathrm{Transformer} Transformer中 E n c o d e r \mathrm{Encoder} Encoder架构的前提下,直接将其从 N L P \mathrm{NLP} NLP领域迁移到计算机视觉领域中,目的是让原始的 T r a n s f o r m e r \mathrm{Transformer} Transformer模型开箱即用。如果想要了解 T r a n s f o r m e r \mathrm{Transformer} Transformer原理详细的介绍可以看我的上一篇文章《Transformer详解(附代码)》。

2 注意力机制应用

在正式详细介绍 V i s i o n T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer之前,先介绍两个注意力机制在计算机视觉中应用的例子。 V i s i o n T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer并不是第一个将注意力机制应用到计算机视觉的领域中去的,其中 S A G A N \mathrm{SAGAN} SAGAN和 A t t n G A N \mathrm{AttnGAN} AttnGAN就早已经在 G A N \mathrm{GAN} GAN的框架中引入了注意力机制,并且它们大大提高了图像生成的质量。

2.1 Self-Attention GAN

S A G A N \mathrm{SAGAN} SAGAN在 G A N \mathrm{GAN} GAN的框架中利用自注意力机制来捕获图像特征的长距离依赖关系,使得合成的图像中考量了所有的图像特征信息。 S A G A N \mathrm{SAGAN} SAGAN中自注意力机制的操作原理如下图所示。
给定一个 3 3 3通道的输入特征图 X = ( X 1 , X 2 , X 3 ) ∈ R 3 × 3 × 3 X=(X^1,X^2,X^3)\in \mathbb{R}^{3\times 3\times 3} X=(X1,X2,X3)∈R3×3×3,其中 X i ∈ R 3 × 3 X^{i}\in \mathbb{R}^{3\times 3} Xi∈R3×3, i ∈ { 1 , 2 , 3 } i\in\{1,2,3\} i∈{1,2,3}。将 X X X分别输入到三个不同的 1 × 1 1\times 1 1×1的卷积层中,并生成 q u e r y \mathrm{query} query特征图 Q ∈ R 3 × 3 × 3 Q\in \mathbb{R}^{3\times 3\times 3} Q∈R3×3×3, k e y \mathrm{key} key特征图 K ∈ R 3 × 3 × 3 K\in \mathbb{R}^{3\times 3\times 3} K∈R3×3×3和 v a l u e \mathrm{value} value特征图 V ∈ R 3 × 3 × 3 V\in \mathbb{R}^{3\times 3\times 3} V∈R3×3×3。生成 Q Q Q具体的计算过程为,给定三个卷积核 W q 1 W^{q1} Wq1, W q 2 W^{q2} Wq2和 W q 3 ∈ R 1 × 1 × 3 W^{q3}\in\mathbb{R}^{1\times1\times3} Wq3∈R1×1×3,并用这三个卷积核分别与 X X X做卷积运算得到 Q 1 Q^1 Q1, Q 2 Q^2 Q2和 Q 3 ∈ R 3 × 3 Q^3\in \mathbb{R}^{3 \times 3} Q3∈R3×3,即 { Q 1 = X ∗ W q 1 Q 2 = X ∗ W q 2 Q 3 = X ∗ W q 3 \left\{\begin{aligned}Q^1&=X * W^{q1}\\Q^2&=X * W^{q2}\\Q^3&=X*W^{q3}\end{aligned}\right. ⎩⎪⎨⎪⎧​Q1Q2Q3​=X∗Wq1=X∗Wq2=X∗Wq3​其中 ∗ * ∗表示卷积运算符号。同理生成 K K K和 V V V的计算过程与 Q Q Q的计算过程类似。然后再利用 Q Q Q和 K K K进行注意力分数的计算得到矩阵 A ∈ R 3 × 3 A\in \mathbb{R}^{3 \times 3} A∈R3×3,其中矩阵 A A A的元素 a m l a_{ml} aml​的计算公式为 a m l = Q m ∗ K l , m ∈ { 1 , 2 , 3 } , l ∈ { 1 , 2 , 3 } a_{ml}=Q^m * K^l,\quad m \in \{1,2,3\},l\in \{1,2,3\} aml​=Qm∗Kl,m∈{1,2,3},l∈{1,2,3}再对矩阵 A A A利用 s o f t m a x \mathrm{softmax} softmax函数进行注意力分布的计算得到注意力分布矩阵 S ∈ R 3 × 3 S\in \mathbb{R}^{3\times 3} S∈R3×3,其中矩阵 S S S的元素 s m l s_{ml} sml​的计算公式为 s m l = exp ⁡ ( a m l ) ∑ i = j 3 exp ⁡ ( a m j ) , m ∈ { 1 , 2 , 3 } , l ∈ { 1 , 2 , 3 } s_{ml}=\frac{\exp(a_{ml})}{\sum\limits_{i=j}^{3}\exp(a_{mj})},\quad m \in \{1,2,3\},l\in\{1,2,3\} sml​=i=j∑3​exp(amj​)exp(aml​)​,m∈{1,2,3},l∈{1,2,3}最后利用注意力分布矩阵 S S S和 v a l u e \mathrm{value} value特征图 V V V得到最后的输出 O = ( O 1 , O 2 , O 3 ) ∈ R 3 × 3 × 3 O=(O^1,O^2,O^3)\in \mathbb{R}^{3\times 3\times 3} O=(O1,O2,O3)∈R3×3×3,即 { O 1 = s 11 ⋅ V 1 + s 12 ⋅ V 2 + s 13 ⋅ V 3 O 2 = s 21 ⋅ V 1 + s 22 ⋅ V 2 + s 23 ⋅ V 3 O 3 = s 31 ⋅ V 1 + s 32 ⋅ V 2 + s 33 ⋅ V 3 \left\{\begin{aligned}O^1&=s_{11}\cdot V^1+s_{12}\cdot V^2+s_{13}\cdot V^3\\O^2&=s_{21}\cdot V^1+s_{22}\cdot V^2+s_{23}\cdot V^3\\O^3&=s_{31}\cdot V^1+s_{32}\cdot V^2+s_{33}\cdot V^3\end{aligned}\right. ⎩⎪⎨⎪⎧​O1O2O3​=s11​⋅V1+s12​⋅V2+s13​⋅V3=s21​⋅V1+s22​⋅V2+s23​⋅V3=s31​⋅V1+s32​⋅V2+s33​⋅V3​

2.2 AttnGAN

A t t n G A N \mathrm{AttnGAN} AttnGAN通过利用注意力机制来实现多阶段细颗粒度的文本到图像的生成,它可以通过关注自然语言中的一些重要单词来对图像的不同子区域进行合成。比如通过文本“一只鸟有黄色的羽毛和黑色的眼睛”来生成图像时,会对关键词“鸟”,“羽毛”,“眼睛”,“黄色”,“黑色”给予不同的生成权重,并根据这些关键词的引导在图像的不同的子区域中进行细节的丰富。 A t t n G A N \mathrm{AttnGAN} AttnGAN中注意力机制的操作原理如下图所示。
 给定输入图像特征向量 h = ( h 1 , h 2 , h 3 , h 4 ) ∈ R D ^ × 4 h=(h^1,h^2,h^3,h^4)\in\mathbb{R}^{\hat{D}\times 4} h=(h1,h2,h3,h4)∈RD^×4和词特征向量 e = ( e 1 , e 2 , e 3 , e 4 ) e=(e^1,e^2,e^3,e^4) e=(e1,e2,e3,e4),其中 h i ∈ R D ^ × 1 h^i\in \mathbb{R}^{\hat{D}\times 1} hi∈RD^×1, e i ∈ R D × 1 e^i\in \mathbb{R}^{D\times 1} ei∈RD×1, i ∈ { 1 , 2 , 3 , 4 } i\in \{1,2,3,4\} i∈{1,2,3,4}。首先利用矩阵 W W W进行线性变换将词特征空间 R D \mathbb{R}^{D} RD的向量转换成图像特征空间 R D ^ \mathbb{R}^{\hat{D}} RD^的向量,则有 e ^ = W ⋅ e = ( e ^ 1 , e ^ 2 , e ^ 3 , e ^ 4 ) ∈ R D ^ × 4 \hat{e}=W\cdot e=(\hat{e}^1,\hat{e}^2,\hat{e}^3,\hat{e}^4)\in \mathbb{R}^{\hat{D}\times 4} e^=W⋅e=(e^1,e^2,e^3,e^4)∈RD^×4然后再利用转换后的词特征 e ^ \hat{e} e^与图像特征 h h h进行注意力分数的计算得到注意力分数矩阵 S S S,其中的分量 s i j s_{ij} sij​的计算公式为 s i j = ( h i ) ⊤ ⋅ e ^ j , i ∈ { 1 , 2 , 3 , 4 } , j ∈ { 1 , 2 , 3 , 4 } s_{ij}=(h^i)^{\top}\cdot \hat{e}^j,\quad i\in \{1,2,3,4\},j\in\{1,2,3,4\} sij​=(hi)⊤⋅e^j,i∈{1,2,3,4},j∈{1,2,3,4} 再对矩阵 S S S利用 s o f t m a x \mathrm{softmax} softmax函数进行注意力分布的计算得到注意力分布矩阵 β ∈ R 4 × 4 \beta\in \mathbb{R}^{4\times 4} β∈R4×4,其中矩阵 β \beta β的元素 β i j \beta_{ij} βij​的计算公式为 β i j = exp ⁡ ( s i j ) ∑ k = 1 3 exp ⁡ ( s i k ) , i ∈ { 1 , 2 , 3 , 4 } , l ∈ { 1 , 2 , 3 , 4 } \beta_{ij}=\frac{\exp(s_{ij})}{\sum\limits_{k=1}^{3}\exp(s_{ik})},\quad i \in \{1,2,3,4\},l\in\{1,2,3,4\} βij​=k=1∑3​exp(sik​)exp(sij​)​,i∈{1,2,3,4},l∈{1,2,3,4}最后利用注意力分布矩阵 β \beta β和图像特征 h h h得到最后的输出 o = ( o 1 , o 2 , o 3 , o 4 ) ∈ R D ^ × 4 o=(o^1,o^2,o^3,o^4)\in \mathbb{R}^{\hat{D}\times 4} o=(o1,o2,o3,o4)∈RD^×4,即 { o 1 = β 11 ⋅ h 1 + β 12 ⋅ h 2 + β 13 ⋅ h 3 + β 14 ⋅ h 4 o 2 = β 21 ⋅ h 1 + β 22 ⋅ h 2 + β 23 ⋅ h 3 + β 24 ⋅ h 4 o 3 = β 31 ⋅ h 1 + β 32 ⋅ h 2 + β 33 ⋅ h 3 + β 34 ⋅ h 4 o 4 = β 41 ⋅ h 1 + β 42 ⋅ h 2 + β 43 ⋅ h 3 + β 44 ⋅ h 4 \left\{\begin{aligned}o^1&=\beta_{11}\cdot h^1+\beta_{12}\cdot h^2+\beta_{13}\cdot h^3+\beta_{14}\cdot h^4\\o^2&=\beta_{21}\cdot h^1+\beta_{22}\cdot h^2+\beta_{23}\cdot h^3+\beta_{24}\cdot h^4\\o^3&=\beta_{31}\cdot h^1+\beta_{32}\cdot h^2+\beta_{33}\cdot h^3+\beta_{34}\cdot h^4\\o^4&=\beta_{41}\cdot h^1+\beta_{42}\cdot h^2+\beta_{43}\cdot h^3+\beta_{44}\cdot h^4\end{aligned}\right. ⎩⎪⎪⎪⎪⎨⎪⎪⎪⎪⎧​o1o2o3o4​=β11​⋅h1+β12​⋅h2+β13​⋅h3+β14​⋅h4=β21​⋅h1+β22​⋅h2+β23​⋅h3+β24​⋅h4=β31​⋅h1+β32​⋅h2+β33​⋅h3+β34​⋅h4=β41​⋅h1+β42​⋅h2+β43​⋅h3+β44​⋅h4​

3 Vision Transformer

本节主要详细介绍 V i s i o n T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer的工作原理,3.1节是关于 V i s i o n T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer的整体框架,3.2节是关于 T r a n s f o r m e r E n c o d e r \mathrm{Transformer\text{ }Encoder} Transformer Encoder的内部操作细节。对于 T r a n s f o r m e r E n c o d e r \mathrm{Transformer\text{ }Encoder} Transformer Encoder中 M u l t i \mathrm{Multi} Multi- H e a d A t t e n t i o n \mathrm{Head\text{ }Attention} Head Attention的原理本文不会赘述,具体想了解的可以参考上一篇文章《Transformer详解(附代码)》中相关原理的介绍。不难发现,不管是自然语言处理中的 T r a n s f o r m e r \mathrm{Transformer} Transformer,还是计算机视觉中图像生成的 S A G A N \mathrm{SAGAN} SAGAN,以及文本生成图像的 A t t n G A N \mathrm{AttnGAN} AttnGAN,它们核心模块中注意力机制的主要目的就是求出注意力分布。

3.1 Vision Transformer整体框架

如果下图所示为 V i s i o n T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer的整体框架以及相应的训练流程

  • 给定一张图片 X ∈ R 3 n × 3 n X\in \mathbb{R}^{3n\times 3n} X∈R3n×3n,并将它分割成 9 9 9个 p a t c h \mathrm{patch} patch分别为 x 1 , ⋯ , x 9 ∈ R n × n x^1,\cdots,x^9\in\mathbb{R}^{n\times n} x1,⋯,x9∈Rn×n。然后再将这个 9 9 9个 p a t c h \mathrm{patch} patch拉平,则有 x 1 , ⋯ , x 9 ∈ R n 2 x^1,\cdots,x^9\in\mathbb{R}^{n^2} x1,⋯,x9∈Rn2
  • 利用矩阵 W ∈ R l × n 2 W\in \mathbb{R}^{l \times n^2} W∈Rl×n2将拉平后的向量 x i ∈ R n 2 , i ∈ { 1 , ⋯ , 9 } x^i\in\mathbb{R}^{n^2},i\in\{1,\cdots,9\} xi∈Rn2,i∈{1,⋯,9}经过线性变换得到图像编码向量 z i ∈ R l , i ∈ { 1 , ⋯ , 9 } z^i\in \mathbb{R}^{l},i\in\{1,\cdots,9\} zi∈Rl,i∈{1,⋯,9},具体的计算公式为 z i = W ⋅ x i , i ∈ { 1 , ⋯ 9 } z^i = W\cdot x^i,\quad i\in\{1,\cdots9\} zi=W⋅xi,i∈{1,⋯9}
  • 然后将图像编码向量 z i , i ∈ { 1 , ⋅ , 9 } z^{i},i\in\{1,\cdot,9\} zi,i∈{1,⋅,9}和类编码向量 z 0 z^0 z0分别与对应的位置编进行加和得到输入编码向量,则有 z i + p i ∈ R l , i ∈ { 0 , ⋯ 9 } z^{i}+p^{i}\in\mathbb{R}^l,\quad i\in\{0,\cdots 9\} zi+pi∈Rl,i∈{0,⋯9}
  • 接着将输入编码向量输入到 V i s i o n T r a n s f o r m e r E n c o d e r \mathrm{Vision\text{ }Transformer\text{ }Encoder} Vision Transformer Encoder中得到对应的输出 o i ∈ R l , i ∈ { 0 , ⋯ , 9 } o^i\in \mathbb{R}^l,i\in\{0,\cdots,9\} oi∈Rl,i∈{0,⋯,9}
  • 最后将类编码向量 o 0 o^0 o0输入全连接神经网络中 M L P \mathrm{MLP} MLP得到类别预测向量 y ^ ∈ R c \hat{y}\in\mathbb{R}^c y^​∈Rc,并与真实类别向量 y ∈ R c y\in\mathbb{R}^c y∈Rc计算交叉熵损失得到损失值 l o s s loss loss,利用优化算法更新模型的权重参数

注意事项: 看到这里可能会有一个疑问为什么预测类别的时候只用到了类别编码向量 o 0 o^0 o0, V i s i o n T r a n s f o r m e r E n c o d e r \mathrm{Vision\text{ }Transformer\text{ }Encoder} Vision Transformer Encoder其它的输出为什么没有输入到 M L P \mathrm{MLP} MLP中?为了回答这个问题,我们令函数 f 0 ( ⋅ ) f_0(\cdot) f0​(⋅)为 V i s i o n T r a n s f o r m e r E n c o d e r \mathrm{Vision\text{ }Transformer\text{ }Encoder} Vision Transformer Encoder,则类编码向量 o 0 o^{0} o0可以表示为 o 0 = f 0 ( z 0 + p 0 , ⋯ , z 9 + p 9 ) o^0=f_0(z^0+p^0,\cdots,z^9+p^9) o0=f0​(z0+p0,⋯,z9+p9)由上公式可以发现,类编码向量 o 0 o^{0} o0是属于高层特征,其实它综合了所有的图像编码信息,所以可以用它来进行分类,这个可以类比在卷积神经网络中最后的类别输出向量其实就是一层层卷积得到的高层特征。

3.2 Transformer Encoder操作原理

如下图所示分别为 V i s i o n T r a n s f o r m e r E n c o d e r \mathrm{Vision\text{ }Transformer\text{ }Encoder} Vision Transformer Encoder模型结构图和原始 T r a n s f o r m e r E n c o d e r \mathrm{Transformer\text{ }Encoder} Transformer Encoder的模型结构图。可以直观的发现 V i s i o n T r a n s f o r m e r E n c o d e r \mathrm{Vision\text{ }Transformer\text{ }Encoder} Vision Transformer Encoder和 T r a n s f o r m e r E n c o d e r \mathrm{Transformer\text{ }Encoder} Transformer Encoder都有层归一化,多头注意力机制,残差连接和线性变换这四个操作,只是在操作顺序有所不同。在以下的 T r a n s f o r m e r \mathrm{ \text{ }Transformer}  Transformer代码实例中,将以下两种 E n c o d e r \mathrm{Encoder} Encoder网络结构都进行了实现,可以发现两种网络结构都可以进行很好的训练。
下图左半部分 V i s i o n T r a n s f o r m e r E n c o d e r \mathrm{Vision\text{ }Transformer\text{ }Encoder} Vision Transformer Encoder具体的操作流程为

  • 给定输入编码矩阵 Z ∈ R l × n Z\in\mathbb{R}^{l\times n} Z∈Rl×n,首先将其进行层归一化得到 Z ′ ∈ R l × n Z^{\prime}\in\mathbb{R}^{l \times n} Z′∈Rl×n
  • 利用矩阵 W q , W k , W v ∈ R l × l W^{q},W^{k},W^{v}\in \mathbb{R}^{l\times l} Wq,Wk,Wv∈Rl×l对 Z ′ Z^{\prime} Z′进行线性变换得到矩阵 Q , K , W ∈ R l × n Q,K,W\in\mathbb{R}^{l\times n} Q,K,W∈Rl×n具体的计算过程为 { Q = W q ⋅ Z ′ K = W k ⋅ Z ′ V = W v ⋅ Z ′ \left\{\begin{aligned}Q &= W^{q}\cdot Z^{\prime}\\K&=W^{k}\cdot Z^{\prime}\\V&=W^v \cdot Z^{\prime}\end{aligned}\right. ⎩⎪⎨⎪⎧​QKV​=Wq⋅Z′=Wk⋅Z′=Wv⋅Z′​再将这三个矩阵输入到 M u l t i \mathrm{Multi} Multi- H e a d A t t e n t i o n \mathrm{Head\text{ }Attention} Head Attention(该原理参考《Transformer详解(附代码)》)中得到矩阵 Z ′ ′ ∈ R l × n Z^{\prime\prime}\in \mathbb{R}^{l \times n} Z′′∈Rl×n将最原始的输入矩阵 Z Z Z与 Z ′ ′ Z^{\prime\prime} Z′′进行残差计算得到 Z + Z ′ ′ ∈ R l × n Z+Z^{\prime\prime}\in \mathbb{R}^{l\times n} Z+Z′′∈Rl×n
  • 将 Z + Z ′ ′ Z+Z^{\prime\prime} Z+Z′′进行第二次层归一化得到 Z ′ ′ ′ ∈ R l × n Z^{\prime\prime\prime}\in\mathbb{R}^{l\times n} Z′′′∈Rl×n,然后再将 Z ′ ′ ′ Z^{\prime\prime\prime} Z′′′输入到全连接神经网络中进行线性变换得到 Z ′ ′ ′ ′ ∈ R l × n Z^{\prime\prime\prime\prime}\in\mathbb{R}^{l\times n} Z′′′′∈Rl×n。最后将 Z + Z ′ ′ Z+Z^{\prime\prime} Z+Z′′与 Z ′ ′ ′ ′ Z^{\prime\prime\prime\prime} Z′′′′进行残差操作得到该 B l o c k \mathrm{Block} Block的输出 Z + Z ′ ′ + Z ′ ′ ′ ′ ∈ R l × n Z+Z^{\prime\prime}+Z^{\prime\prime\prime\prime}\in\mathbb{R}^{l\times n} Z+Z′′+Z′′′′∈Rl×n。一个 E n c o d e r \mathrm{Encoder} Encoder可以将 N N N个 B l o c k \mathrm{Block} Block进行堆叠,最后得到的输出为 O ∈ R l × n O\in\mathbb{R}^{l\times n} O∈Rl×n。

4 程序代码

V i s i o n T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer的代码示例如下所示。该代码是由上一篇《Transformer详解(附代码)》的代码的基础上改编而来。 V i s i o n T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer的作者的本意就是想让在 N L P \mathrm{NLP} NLP中的 T r a n s f o r m e r \mathrm{Transformer} Transformer模型架构做尽可能少的修改可以直接迁移到 C V \mathrm{CV} CV中,所以以下程序尽可能保持作者的原意,并在代码实现了两种 E n c o d e r \mathrm{Encoder} Encoder的网络结构,即3.2节图片所示的两个网络结构,一种是最原始的 E n c o d e r \mathrm{Encoder} Encoder网络结构,一种是 V i s i o n T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer论文里的 E n c o d e r \mathrm{Encoder} Encoder的网络结构。这里需要注意的是, V i s i o n T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer里并能没有 D e c o d e r \mathrm{Decoder} Decoder模块,所以不需要计算 E n c o d e r \mathrm{Encoder} Encoder和 D e c o d e r \mathrm{Decoder} Decoder的交叉注意力分布,这就进一步给 V i s i o n T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer的编程带来了简便。 V i s i o n T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer的开源代码的网址为

import torch
import torch.nn as nn
import os
from einops import rearrange
from einops import repeat
from einops.layers.torch import Rearrangedef inputs_deal(inputs):return inputs if isinstance(inputs, tuple) else(inputs, inputs)class SelfAttention(nn.Module):def __init__(self, embed_size, heads):super(SelfAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)self.fc_out = nn.Linear(heads * self.head_dim, embed_size)def forward(self, values, keys, query):N =query.shape[0]value_len , key_len , query_len = values.shape[1], keys.shape[1], query.shape[1]# split embedding into self.heads piecesvalues = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = query.reshape(N, query_len, self.heads, self.head_dim)values = self.values(values)keys = self.keys(keys)queries = self.queries(queries)energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)# queries shape: (N, query_len, heads, heads_dim)# keys shape : (N, key_len, heads, heads_dim)# energy shape: (N, heads, query_len, key_len)attention = torch.softmax(energy/ (self.embed_size ** (1/2)), dim=3)out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)# attention shape: (N, heads, query_len, key_len)# values shape: (N, value_len, heads, heads_dim)# (N, query_len, heads, head_dim)out = self.fc_out(out)return outclass TransformerBlock(nn.Module):def __init__(self, embed_size, heads, dropout, forward_expansion):super(TransformerBlock, self).__init__()self.attention = SelfAttention(embed_size, heads)self.norm = nn.LayerNorm(embed_size)self.feed_forward = nn.Sequential(nn.Linear(embed_size, forward_expansion*embed_size),nn.ReLU(),nn.Linear(forward_expansion*embed_size, embed_size))self.dropout = nn.Dropout(dropout)def forward(self, value, key, query, x, type_mode):if type_mode == 'original':attention = self.attention(value, key, query)x = self.dropout(self.norm(attention + x))forward = self.feed_forward(x)out = self.dropout(self.norm(forward + x))return outelse:attention = self.attention(self.norm(value), self.norm(key), self.norm(query))x =self.dropout(attention + x)forward = self.feed_forward(self.norm(x))out = self.dropout(forward + x)return outclass TransformerEncoder(nn.Module):def __init__(self,embed_size,num_layers,heads,forward_expansion,dropout = 0,type_mode = 'original'):super(TransformerEncoder, self).__init__()self.embed_size = embed_sizeself.type_mode = type_modeself.Query_Key_Value = nn.Linear(embed_size, embed_size * 3, bias = False)self.layers = nn.ModuleList([TransformerBlock(embed_size,heads,dropout=dropout,forward_expansion=forward_expansion,)for _ in range(num_layers)])self.dropout = nn.Dropout(dropout)def forward(self, x):for layer in self.layers:QKV_list = self.Query_Key_Value(x).chunk(3, dim = -1)x = layer(QKV_list[0], QKV_list[1], QKV_list[2], x, self.type_mode)return xclass VisionTransformer(nn.Module):def __init__(self, image_size, patch_size, num_classes, embed_size, num_layers, heads, mlp_dim, pool = 'cls',channels = 3,dropout = 0,emb_dropout = 0.1,type_mode = 'vit'):super(VisionTransformer, self).__init__()img_h, img_w = inputs_deal(image_size)patch_h, patch_w = inputs_deal(patch_size)assert img_h % patch_h == 0 and img_w % patch_w == 0, 'Img dimensions can be divisible by the patch dimensions'num_patches = (img_h // patch_h) * (img_w // patch_w)patch_size = channels * patch_h * patch_wself.patch_embedding = nn.Sequential(Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_h, p2=patch_w),nn.Linear(patch_size, embed_size, bias=False))self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_size))self.cls_token = nn.Parameter(torch.randn(1, 1, embed_size))self.dropout = nn.Dropout(emb_dropout)self.transformer = TransformerEncoder(embed_size, num_layers, heads, mlp_dim,dropout)self.pool = poolself.to_latent = nn.Identity()self.mlp_head = nn.Sequential(nn.LayerNorm(embed_size),nn.Linear(embed_size, num_classes))def forward(self, img):x = self.patch_embedding(img)b, n, _ = x.shapecls_tokens = repeat(self.cls_token, '() n d ->b n d', b = b)x = torch.cat((cls_tokens, x), dim = 1)x += self.pos_embedding[:, :(n + 1)]x = self.dropout(x)x = self.transformer(x)x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]x = self.to_latent(x)return self.mlp_head(x)if __name__ == '__main__':vit = VisionTransformer(image_size = 256,patch_size = 16,num_classes = 10,embed_size = 256,num_layers = 6,heads = 8,mlp_dim = 512,dropout = 0.1,emb_dropout = 0.1)img = torch.randn(3, 3, 256, 256)pred = vit(img)print(pred)

以下代码是利用 V i s i o n T r a n s f o r m e r \mathrm{Vision \text{ }Transformer} Vision Transformer网络结构训练一个分类 m n i s t \mathrm{mnist} mnist数据集的主程序代码。

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import VIT
import osdef train():batch_size = 4device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')epoches = 20mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor())train_loader = torch.utils.data.DataLoader(mnist_train, batch_size= batch_size, shuffle=True)mnist_model = VIT.VisionTransformer(image_size = 28,patch_size = 7,num_classes = 10,channels = 1,embed_size = 512,num_layers = 1,heads = 2,mlp_dim =1024,dropout = 0,emb_dropout = 0)loss_fn = nn.CrossEntropyLoss()mnist_model = mnist_model.to(device)opitimizer = optim.Adam(mnist_model.parameters(), lr=0.00001)mnist_model.train()for epoch in range(epoches):total_loss = 0 corrects = 0 num = 0for batch_X, batch_Y in train_loader:batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)opitimizer.zero_grad()outputs = mnist_model(batch_X)_, pred = torch.max(outputs.data, 1)loss = loss_fn(outputs, batch_Y)loss.backward()opitimizer.step()total_loss += loss.item()corrects = torch.sum(pred == batch_Y.data)num += batch_sizeprint(epoch, total_loss/float(num), corrects.item()/float(batch_size))if __name__ == '__main__':train()

训练的过程如下所示,可以发现损失函数可以稳定下降。但是训练一个 V i s i o n T r a n s f o r m e r \mathrm{Vision \text{ }Transformer} Vision Transformer模型真的是很烧硬件,跟训练一个普通的 C N N \mathrm{CNN} CNN模型相比,训练一个 V i s i o n T r a n s f o r m e r \mathrm{Vision \text{ }Transformer} Vision Transformer模型更加耗时耗力。

更多推荐

Vision Transformer详解(附代码)

本文发布于:2024-02-13 10:06:40,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1758218.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:详解   代码   Vision   Transformer

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!