pytorch预训练、预训练时报错misssing key(s) in state

编程入门 行业动态 更新时间:2024-10-14 00:26:04

pytorch预训练、预训练<a href=https://www.elefans.com/category/jswz/34/1770019.html style=时报错misssing key(s) in state"/>

pytorch预训练、预训练时报错misssing key(s) in state

使用:

1. 直接加载预训练模型

推荐方式:

net.load_state_dict(torch.load("model.pth"))

2. 加载一部分预训练模型

见上面的链接

3. 微调经典网络

见上面的链接

4. 修改经典网络

见上面的链接

 

另外,对于pytorch来说,如果需要更改字典:

载入变量和权重组成的字典:torch.load

pretrain_dict=torch.load(r'../pretrained_model/se_resnext101_32x4d-3b2fe3d8.pth')

载入新的模型self.basemodel = tvm.resnet50(pretrained=False)

获取模型的变量和权重组成的字典:basemodel.state_dict()

model_dict = self.basemodel.state_dict()

 

如果预训练时报错misssing key(s) in state_dict:

以resnet50为例

原本的模型:

在使用预训练模型时,pytorch的机制会导致模型每层前面加了一个模型名字:

如下的模型多了一个basemodel的字样

因此新旧两个模型的字典对不上号

预训练时就会报错misssing key(s) in state_dict

 

需要更改预训练模型权重字典里面每一层的名字,删除basemodel字样

 

名称不同的时候更换dict中的变量名:

k就是每一层的名字,v是权重

方法1: 

           for k, v in pretrained_dict.items():print("pretrained k,v:",k,v)if not k.find("basemodel") == -1: #if find pretrain model name, delete itname = k[(len("basemodel")+1):]   # remove `module.`model_dict[name] = velse:name = kprint("delete last layer without pretrained model name")print("new_name:",name)

方法2:

            pretrained_dict = {k[(len("basemodel")+1):]: v for k, v in pretrained_dict.items() if k[(len("basemodel")+1):] in model_dict}  #去除上次预训练时模型的变量前面添加的”basemodel”字样

 

完成后可以查看变量和权重字典中的某一个:

下面的代码查看了pretrainmodel的字典和newmodel的字典

            #-----------------------------------------------------self.basemodel = tvm.resnet50(pretrained=False)#             self.basemodel.load_state_dict(torch.load(r'../pretrained_model/resnet50-19c8e357.pth'))count=0model_dict=torch.load(r'../pretrained_model/resnet50-19c8e357.pth')for k, v in model_dict.items():count+=1if count==10:               print("resnet50_dict resnet50 k,v:",k,v)#-----------------------------------------------------#-----------------------------------------------------count=0for k, v in pretrained_dict.items():count+=1if count==10:               print("pretrained_dict resnet50 k,v:",k,v)#-----------------------------------------------------

 

也可以打印整个字典:

for k, v in pretrained_dict.items():print("pretrained k,v:",k,v)

 

参考文献:

更多推荐

pytorch预训练、预训练时报错misssing key(s) in state

本文发布于:2024-03-23 18:34:27,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1741443.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:时报   pytorch   misssing   state   key

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!