模型训练

编程入门 行业动态 更新时间:2024-10-09 16:31:30

<a href=https://www.elefans.com/category/jswz/34/1771358.html style=模型训练"/>

模型训练

遇到报错one of the variables needed for gradient computation has been modified by an inplace operation。意思是对输入x原地操作(inplace operation),一个变量在反向传播过程中被修改了,而不是按照预期的版本(version 0)更新,导致梯度不正确。

使用这句代码定位报错位置

torch.autograd.set_detect_anomaly(True)

定位到报错后可以修改代码,这是我原来的forword代码,可以看到 x[i] = self.dilate_attention[i](qkv[i][0], qkv[i][1], qkv[i][2])这一句代码,将原来的x值原地替换,不能这样做

for i in range(self.num_dilation):x[i] = self.dilate_attention[i](qkv[i][0], qkv[i][1], qkv[i][2])
x = x.permute(1, 2, 3, 0, 4).reshape(B, H, W, C) #3,1,24,24,171-->1,24,24,513
x = self.proj(x)
x = self.proj_drop(x)

我们需要新建变量将这些值存起来最后赋值。新建一个列表,将值存入,最后使用cat统一(会丢失一个维度补上),然后就不报错了

x_i=[]
for i in range(self.num_dilation):x_i.append(self.dilate_attention[i](qkv[i][0], qkv[i][1], qkv[i][2]))
x =torch.cat(x_i,dim=0)#3,24,24,171
x = x.unsqueeze(1)#3,1,24,24,171
x = x.permute(1, 2, 3, 0, 4).reshape(B, H, W, C) #3,1,24,24,171-->1,24,24,513
x = self.proj(x)
x = self.proj_drop(x)

更多推荐

模型训练

本文发布于:2023-11-16 17:59:45,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1630753.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:模型

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!