蓝图分离卷积BSConv 学习笔记 (附代码)

编程入门 行业动态 更新时间:2024-10-28 21:28:45

蓝图分离<a href=https://www.elefans.com/category/jswz/34/1765938.html style=卷积BSConv 学习笔记 (附代码)"/>

蓝图分离卷积BSConv 学习笔记 (附代码)

论文地址:.13549

代码地址:

1.是什么?

BSConv是深度可分离卷积DSConv的升级版本,它更好地利用内核内部相关性来实现高效分离。具体而言,BSConvU是将一个标准的卷积分解为1x1卷积(PW)和一个逐通道卷积,是深度可分离卷积(DSConv—逐通道、逐点)的逆向版本。此外,BSConv还有一个变体操作—BSConvS。

2.为什么?

受启发于预训练模型的核属性的量化分析:深度方向的强相关性。作者提出一种“蓝图分离卷积”(blueprint separable convolutions, BSConv)作为高效CNN的构建模块。

基于该发现,作者构建了一套理论基础并用于推导如何采用标准OP进行高效实现。更进一步,所提方法为深度分离卷积的应用(深度分离卷积已成为当前主流网络架构的核心模块)提供了系统的理论推导、可解释性以及原因分析。最后,作者揭示了基于深度分离卷积的网络架构(如MobileNet)隐式的依赖于跨核相关性;而所提BSConv则基于核内相关性,故可以为常规卷积提供一种更有效的拆分。

作者通过充分的实验(大尺度分类与细粒度分类)验证了所提BSConv可以明显的提升MobileNet以及其他基于深度分离卷积的架构的性能,而不会引入额外的复杂度。对于细粒度问题,所提方法取得13.7%的性能提升;在ImageNet分类任务,BSConv在ResNet的“即插即用”取得了9.5%的性能提升。

3.怎么样?

3.1网络结构

在标准卷积中,每个卷积层对输入张量进行变化得到输出张量,相应的卷积核,每个卷积核的尺寸为M*K*K。相应的公式可以描述为(图示见下图):

这些卷积核将通过反向传播方式进行优化训练。

预训练CNN中的卷积核可以通过一个模板以及M个因子进行近似。该发现也是本文提的(blueprint separable convolutions,BSConv)的驱动源泉,它滤波器卷积提供另一种定义方式。

尽管上述定义为滤波器添加了硬约束,但作者通过实验表明:相比标准卷积,所提方法可以达到相同甚至更优的性能。另外,需要注意的是:标准卷积的可训练参数为,而所提方法仅具有个可训练参数。

3.2 Variants and Implementations

前面已经介绍了BSConv的卷积核信息,它的权值可以组合为矩阵。此时根据W的学习方式不同,又有两种不同的变种。

  • BSConv-U:在大多场景下,权值W可以不进行任何约束进行训练学习。此时,公式(1)可以转换为如下公式。此时,常规卷积1*1可以解耦为卷积K*K深度卷积,见下图。

 对于这种形式的CNN架构,作者发现:权值W在行方向存在高度相关性。这为进一步的正则化与参数降低提供了可能。也就引出了下面将要介绍的BSConv-S变种。

  • BSConv-S:基于前述发现,作者对权值W进行低秩分解:。其中.而后,经过一些列的变换处理,最终BSConv的公式转换为下面的公式。此时,常规卷积可以解耦为1*1卷积+1*1卷积+K*K深度卷积,见上图。

3.3  Discussion

前面已经介绍了BSConv的两种变种,这里将对比分析一下上述两种变种与已有模块的区别和联系。

  • BSConv-U是一种逆深度分类卷积。两者的出发点有一些区别:DSConv实施了跨核相关性,而BSConv-U则实施了核内相关性。已有研究表明:尽管跨核相关性与核内相关性都是有效假设,但核内相关性更有优势,对于高效分离更具潜力。需要注意的是:卷积后不跟激活函数或者规范化函数。

  • BSConv-S是一种具有正交正则化功能的转移线性瓶颈模块。线性瓶颈层是当前高效网络MobileNet的核心模块,它由pointwise、depthwise、pointwise级联构成,而BSConv-S则是由pointwise, pointwise, depthwise级联构成。从中可以看到两者之间的紧密联系。此外,需要注意的是:与前者相同,激活函数与规范化函数不在模块内添加

3.4代码实现

class BSConvU(torch.nn.Sequential):def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True, padding_mode="zeros", with_bn=False, bn_kwargs=None):super().__init__()# check argumentsif bn_kwargs is None:bn_kwargs = {}# pointwiseself.add_module("pw", torch.nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=(1, 1),stride=1,padding=0,dilation=1,groups=1,bias=False,))# batchnormif with_bn:self.add_module("bn", torch.nn.BatchNorm2d(num_features=out_channels, **bn_kwargs))# depthwiseself.add_module("dw", torch.nn.Conv2d(in_channels=out_channels,out_channels=out_channels,kernel_size=kernel_size,stride=stride,padding=padding,dilation=dilation,groups=out_channels,bias=bias,padding_mode=padding_mode,))class BSConvS(torch.nn.Sequential):def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True, padding_mode="zeros", p=0.25, min_mid_channels=4, with_bn=False, bn_kwargs=None):super().__init__()# check argumentsassert 0.0 <= p <= 1.0mid_channels = min(in_channels, max(min_mid_channels, math.ceil(p * in_channels)))if bn_kwargs is None:bn_kwargs = {}# pointwise 1self.add_module("pw1", torch.nn.Conv2d(in_channels=in_channels,out_channels=mid_channels,kernel_size=(1, 1),stride=1,padding=0,dilation=1,groups=1,bias=False,))# batchnormif with_bn:self.add_module("bn1", torch.nn.BatchNorm2d(num_features=mid_channels, **bn_kwargs))# pointwise 2self.add_module("pw2", torch.nn.Conv2d(in_channels=mid_channels,out_channels=out_channels,kernel_size=(1, 1),stride=1,padding=0,dilation=1,groups=1,bias=False,))# batchnormif with_bn:self.add_module("bn2", torch.nn.BatchNorm2d(num_features=out_channels, **bn_kwargs))# depthwiseself.add_module("dw", torch.nn.Conv2d(in_channels=out_channels,out_channels=out_channels,kernel_size=kernel_size,stride=stride,padding=padding,dilation=dilation,groups=out_channels,bias=bias,padding_mode=padding_mode,))def _reg_loss(self):W = self[0].weight[:, :, 0, 0]WWt = torch.mm(W, torch.transpose(W, 0, 1))I = torch.eye(WWt.shape[0], device=WWt.device)return torch.norm(WWt - I, p="fro")class BSConvS_ModelRegLossMixin():def reg_loss(self, alpha=0.1):loss = 0.0for sub_module in self.modules():if hasattr(sub_module, "_reg_loss"):loss += sub_module._reg_loss()return alpha * loss

参考:

深度分离卷积重思考:BSConv

轻量化神经网络卷积设计研究进展

更多推荐

蓝图分离卷积BSConv 学习笔记 (附代码)

本文发布于:2023-11-17 11:10:17,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1643633.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:卷积   蓝图   学习笔记   代码   BSConv

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!