SCAR的pytorch实现

编程入门 行业动态 更新时间:2024-10-26 18:20:17

SCAR的<a href="https://www.elefans.com/category/jswz/34/1769961.html">pytorch</a>实现

SCAR的pytorch实现

本文所实现的网络来源于 SCAR: Spatial-/Channel-wise Attention Regression Networks for Crowd Counting (Neurocomputing 2019)。

import torch;from torchvision import models
from torchvision.models import vgg16
import warnings;from torch import nn
# Silence every warning process-wide (presumably to hide deprecation notices
# such as torchvision's `pretrained=` warning). NOTE(review): this also hides
# unrelated warnings from any library.
warnings.filterwarnings("ignore")
# NOTE(review): this builds ImageNet-pretrained VGG-16 at import time (a
# download on first use) and rebinds the name `vgg16`, imported above as a
# function, to the resulting model instance. Nothing visible in this file
# uses it (SCAR calls `models.vgg16` instead), so it looks removable. Confirm
# that no other module imports `vgg16` from here before deleting it.
vgg16 = vgg16(pretrained=True)
def initialize_weights(models):
    """Initialize, in place, every module yielded by ``models``.

    Each element is passed to ``real_init_weights``, which handles single
    layers, container modules and plain lists.

    Args:
        models: any iterable of modules, e.g. ``net.modules()``. Inside this
            function the name shadows the file-level ``torchvision.models``
            import.

    Returns:
        None. Parameters are modified in place.
    """
    remaining = iter(models)
    for current in remaining:
        real_init_weights(current)
# NOTE(review): duplicates the `warnings` import and the global "ignore"
# filter already set up near the top of the file. Re-adding the identical
# filter is redundant, so these two lines can likely be dropped.
import warnings
warnings.filterwarnings("ignore")
def real_init_weights(m):
    """Recursively (re)initialize the parameters of ``m`` in place.

    Handling by type:

    * ``list``: every element is initialized recursively.
    * ``nn.Conv2d``: weight ~ N(0, 0.01^2); bias (if present) set to 0.
    * ``nn.Linear``: weight ~ N(0, 0.01^2); bias left unchanged.
    * ``nn.BatchNorm2d``: weight set to 1 and bias to 0. A layer built with
      ``affine=False`` has no such parameters and is skipped.
    * any other ``nn.Module``: recurse into its direct children, so
      containers such as ``nn.Sequential`` are covered.
    * anything else: printed and otherwise ignored.

    Args:
        m: a module, or a (possibly nested) list of modules.
    """
    if isinstance(m, list):
        for mini_m in m:
            real_init_weights(mini_m)
    elif isinstance(m, nn.Conv2d):
        nn.init.normal_(m.weight, std=0.01)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, mean=0.0, std=0.01)
    elif isinstance(m, nn.BatchNorm2d):
        # With affine=False, weight/bias are None and constant_ would raise.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Module):
        for mini_m in m.children():
            real_init_weights(mini_m)
    else:
        print(m)
def _make_vgg10():
    """Build SCAR's front end: the first 10 conv layers of VGG-16.

    The layer sequence (3x3 conv + ReLU pairs and three 2x2 max-pools)
    mirrors ``torchvision.models.vgg16().features[0:23]`` index for index.
    That is what lets ``SCAR`` load that slice's ``state_dict`` directly.
    The input has 3 channels and the output has 512 channels at 1/8 of the
    input resolution. VGG-16's fourth max-pool is deliberately omitted, which
    keeps the output stride at 8.
    """
    layers = []
    in_channels = 3
    for spec in (64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512):
        if spec == "M":
            layers.append(torch.nn.MaxPool2d(2, 2))
        else:
            layers.append(torch.nn.Conv2d(in_channels, spec, 3, stride=1, padding=1))
            layers.append(torch.nn.ReLU(inplace=True))
            in_channels = spec
    return torch.nn.Sequential(*layers)


# Module-level instance kept so code importing ``vgg10`` from this file still
# works. SCAR itself builds a private front end per instance.
vgg10 = _make_vgg10()


class SCAR(torch.nn.Module):
    """SCAR crowd-counting network (Spatial-/Channel-wise Attention Regression).

    The pipeline is:

    1. VGG-16 front end (``vgg10``), output stride 8, 512 channels.
    2. Six dilated 3x3 convs, 512 -> 512 -> 512 -> 512 -> 256 -> 128 -> 64.
    3. Spatial (SAM) and channel (CAM) attention run in parallel.
    4. Their outputs are concatenated to 128 channels.
    5. A 1x1 conv maps this to a 1-channel map, which is upsampled 8x.

    Args:
        loadwieght: when False (the default), initialize every layer with
            ``real_init_weights`` and then copy ImageNet-pretrained VGG-16
            weights into the front end (fetched via torchvision). When True,
            both steps are skipped, presumably because a checkpoint will be
            loaded afterwards.
    """

    def __init__(self, loadwieght=False):
        super(SCAR, self).__init__()
        # Private front end: sharing the module-level ``vgg10`` would make all
        # SCAR instances share (and overwrite) the same weights.
        self.vgg10 = _make_vgg10()
        self.dconv1 = torch.nn.Conv2d(512, 512, 3, dilation=2, stride=1, padding=2)
        self.dconv2 = torch.nn.Conv2d(512, 512, 3, dilation=2, stride=1, padding=2)
        self.dconv3 = torch.nn.Conv2d(512, 512, 3, dilation=2, stride=1, padding=2)
        self.dconv4 = torch.nn.Conv2d(512, 256, 3, dilation=2, stride=1, padding=2)
        self.dconv5 = torch.nn.Conv2d(256, 128, 3, dilation=2, stride=1, padding=2)
        self.dconv6 = torch.nn.Conv2d(128, 64, 3, dilation=2, stride=1, padding=2)
        self.relu = torch.nn.functional.relu
        self.SAM = SAM()
        self.CAM = CAM()
        self.finalconv = torch.nn.Conv2d(128, 1, 1)
        # F.upsample is deprecated; F.interpolate is the same operation.
        self.upsample = torch.nn.functional.interpolate
        if not loadwieght:
            mod = models.vgg16(pretrained=True)
            # Initialize only after every layer exists; otherwise the back
            # end, attention modules and final conv keep PyTorch's default
            # init. Each direct child is visited once (recursion covers
            # nested layers).
            initialize_weights(self.children())
            # Overwrite the front end with the pretrained VGG-16 weights.
            # The indices of features[0:23] match vgg10's layout.
            self.vgg10.load_state_dict(mod.features[0:23].state_dict())

    def forward(self, x):
        """Predict a density map for the image batch ``x``.

        Args:
            x: tensor of shape (N, 3, H, W).

        Returns:
            Tensor of shape (N, 1, 8 * (H // 8), 8 * (W // 8)). This equals
            (N, 1, H, W) when H and W are multiples of 8, because the
            max-pools floor odd sizes.
        """
        y = self.vgg10(x)
        # Dilated back end: each of the six layers is applied exactly once.
        for dconv in (self.dconv1, self.dconv2, self.dconv3,
                      self.dconv4, self.dconv5, self.dconv6):
            y = self.relu(dconv(y))
        y_sa = self.SAM(y)
        y_ca = self.CAM(y)
        y = torch.cat((y_ca, y_sa), dim=1)  # 64 + 64 = 128 channels
        y = self.finalconv(y)
        # The front end pools three times (stride 8), so upsample by 8.
        return self.upsample(y, scale_factor=8)


class SAM(torch.nn.Module):
    """Spatial attention module (self-attention over H*W positions).

    Input and output both have shape (N, 64, H, W). The shape is unchanged;
    the convs fix the channel count at 64.
    """

    def __init__(self):
        super(SAM, self).__init__()
        self.q = torch.nn.Conv2d(64, 64, 1)
        self.k = torch.nn.Conv2d(64, 64, 1)
        self.v = torch.nn.Conv2d(64, 64, 1)
        # 1x1 conv applied to the attended features before the residual add.
        self.lamda = torch.nn.Conv2d(64, 64, 1)
        # Not used in forward; kept so existing checkpoints' state_dict keys
        # still match.
        self.bn = torch.nn.BatchNorm2d(64)

    def forward(self, x):
        N, C, H, W = x.size()
        q = self.q(x).view(N, -1, H * W).permute(0, 2, 1)  # N x HW x C
        k = self.k(x).view(N, -1, H * W)  # N x C x HW
        v = self.v(x).view(N, -1, H * W)  # N x C x HW
        energy = torch.bmm(q, k)  # N x HW x HW; energy[i, j] = q_i . k_j
        attention = torch.nn.functional.softmax(energy, dim=-1)  # rows sum to 1
        # out[:, :, i] = sum_j v[:, :, j] * attention[i, j]. Transposing makes
        # each output position a convex combination of value vectors, as in
        # the DANet position-attention module.
        y = torch.bmm(v, attention.permute(0, 2, 1))
        y = y.view(N, C, H, W)
        return self.lamda(y) + x


class CAM(torch.nn.Module):
    """Channel attention module (self-attention over the C=64 channels).

    Input and output both have shape (N, 64, H, W).
    """

    def __init__(self):
        super(CAM, self).__init__()
        self.conv1 = torch.nn.Conv2d(64, 64, 1)
        self.conv2 = torch.nn.Conv2d(64, 64, 1)
        # Not used in forward; kept so existing checkpoints' state_dict keys
        # still match.
        self.bn = torch.nn.BatchNorm2d(64)

    def forward(self, x):
        N, C, H, W = x.size()
        feat = self.conv1(x).view(N, C, -1)  # N x C x HW
        # A real transpose. Calling .view(N, -1, C) would only reinterpret
        # memory and scramble the channel/position pairing.
        feat_t = feat.permute(0, 2, 1)  # N x HW x C
        energy = torch.bmm(feat, feat_t)  # N x C x C
        attention = torch.nn.functional.softmax(energy, dim=-1)
        v = x.reshape(N, C, -1)  # N x C x HW
        cl2 = torch.bmm(attention, v).view(N, C, H, W)
        return self.conv2(cl2) + x

更多推荐

SCAR的pytorch实现

本文发布于:2023-11-15 19:52:00,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1605344.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:SCAR   pytorch

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!