Gaze360代码解读

编程入门 行业动态 更新时间:2024-10-10 19:19:36

Gaze360<a href=https://www.elefans.com/category/jswz/34/1771412.html style=代码解读"/>

Gaze360代码解读

代码链接 :

论文链接:

Gaze360模型

注视是自然的连续信号。凝视注视和过渡产生一系列凝视方向。为了利用这一点,论文提出了一个基于视频的凝视跟踪模型使用双向长期短期记忆胶囊(LSTM),它提供了一种对序列进行建模的方法,其中一个元素的输出取决于过去和将来的输入。在该论文中,作者利用7个帧的序列来预测中心帧的视线。注意,仅包括单个中央框架的其他序列长度也是可能的。

上图说明了Gaze360模型的体系结构。卷积神经网络(主干)分别处理每个帧中的头部作物,该神经网络产生具有256维的高级特征。这些特征被馈送到具有两层的双向LSTM,这些LSTM消化前向和后向向量中的序列。最后,将这些向量连接起来并通过一个完全连接的层,以产生两个输出:凝视预测和误差分位数估计。

Gaze360模型代码拆分

GazeLSTM

class GazeLSTM(nn.Module):def __init__(self):super(GazeLSTM, self).__init__()self.img_feature_dim = 256  # the dimension of the CNN feature to represent each frameself.base_model = resnet18(pretrained=True)self.base_model.fc2 = nn.Linear(1000, self.img_feature_dim)self.lstm = nn.LSTM(self.img_feature_dim, self.img_feature_dim,bidirectional=True,num_layers=2,batch_first=True)# The linear layer that maps the LSTM with the 3 outputsself.last_layer = nn.Linear(2*self.img_feature_dim, 3)def forward(self, input):base_out = self.base_model(input.view((-1, 3) + input.size()[-2:]))base_out = base_out.view(input.size(0),7,self.img_feature_dim)lstm_out, _ = self.lstm(base_out)lstm_out = lstm_out[:,3,:]output = self.last_layer(lstm_out).view(-1,3)angular_output = output[:,:2]angular_output[:,0:1] = math.pi*nn.Tanh()(angular_output[:,0:1])angular_output[:,1:2] = (math.pi/2)*nn.Tanh()(angular_output[:,1:2])var = math.pi*nn.Sigmoid()(output[:,2:3])var = var.view(-1,1).expand(var.size(0),2)

 首先是model.py中的GazeLSTM部分,首先在初始化函数中定义了图片的特征维度为256,主干网络是resnet18,将输入数据的shape通过view函重塑为(-1,3,input.size()[-1],input.size()[-2])的形状,并输入到resnet1得到base_out,改变其shape为(-1,7,256),传入至双向LSTM中保存t时刻的输出lstm_out[:,3,:],将该输出传入至全连接层并将输出的shape改为(-1,3)。

PinBallLoss

使用神经网络做回归任务,我们使用MSE、MAE作为损失函数,最终得到的输出y通常会被近似为y的期望值,但有些情况下目标值y的空间可能会比较大,只预测一个期望值并不能帮助我们做进一步的决策。

这里介绍一个特殊的损失函数——分位数损失,利用分位数损失我们不需要对数据进行任何先验的处理,就可以轻松做到预测输出y的某一分位数水平值,例如5%分位数或95%分位数,利用这个输出很自然就完成预测输出范围的回归模型。

分位数损失函数的表达式如下图:

 代码中以一个简明的表达方式来表达上式:

class PinBallLoss(nn.Module):def __init__(self):super(PinBallLoss, self).__init__()self.q1 = 0.1self.q9 = 1-self.q1def forward(self, output_o,target_o,var_o):q_10 = target_o-(output_o-var_o)q_90 = target_o-(output_o+var_o)loss_10 = torch.max(self.q1*q_10, (self.q1-1)*q_10)loss_90 = torch.max(self.q9*q_90, (self.q9-1)*q_90)loss_10 = torch.mean(loss_10)loss_90 = torch.mean(loss_90)return loss_10+loss_90

 Gaze360训练函数理解(run.py)

def main():global args, best_errormodel_v = GazeLSTM()model = torch.nn.DataParallel(model_v).cuda()model.cuda()cudnn.benchmark = Trueimage_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])train_loader = torch.utils.data.DataLoader(ImagerLoader(source_path,train_file,transforms.Compose([transforms.RandomResizedCrop(size=224,scale=(0.8,1)),transforms.ToTensor(),image_normalize,])),batch_size=batch_size, shuffle=True,num_workers=workers, pin_memory=True)val_loader = torch.utils.data.DataLoader(ImagerLoader(source_path,val_file,transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),image_normalize,])),batch_size=batch_size, shuffle=True,num_workers=workers, pin_memory=True)criterion = PinBallLoss().cuda()optimizer = torch.optim.Adam(model.parameters(), lr)if test==True:test_loader = torch.utils.data.DataLoader(ImagerLoader(source_path,test_file,transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),image_normalize,])),batch_size=batch_size, shuffle=True,num_workers=workers, pin_memory=True)checkpoint = torch.load(checkpoint_test)model.load_state_dict(checkpoint['state_dict'])angular_error = validate(test_loader, model, criterion)print('Angular Error is',angular_error)for epoch in range(0, epochs):# train for one epochtrain(train_loader, model, criterion, optimizer, epoch)# evaluate on validation setangular_error = validate(val_loader, model, criterion)# remember best angular error in validation and save checkpointis_best = angular_error < best_errorbest_error = min(angular_error, best_error)save_checkpoint({'epoch': epoch + 1,'state_dict': model.state_dict(),'best_prec1': best_error,}, is_best)

 在main函数中调用预先设计的模型,加载训练数据集和验证数据集,使用PinBallLoss损失函数和Adam优化器。判断是否为测试模式,如果是测试模型还需加载测试数据集。在接下来的for循环中则是对每一个epoch进行一次训练,对验证集进行评估,记住最好的角度错误在验证和保存检查点。最后输出每个epoch的凝视估计和分位数误差估计。

更多推荐

Gaze360代码解读

本文发布于:2024-02-14 07:24:59,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1762189.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:代码

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!