时报错misssing key(s) in state"/>
pytorch预训练、预训练时报错misssing key(s) in state
使用:
1. 直接加载预训练模型
推荐方式:
net.load_state_dict(torch.load("model.pth"))
2. 加载一部分预训练模型
见上面的链接
3. 微调经典网络
见上面的链接
4. 修改经典网络
见上面的链接
另外,对于pytorch来说,如果需要更改字典:
载入变量和权重组成的字典:torch.load
pretrain_dict=torch.load(r'../pretrained_model/se_resnext101_32x4d-3b2fe3d8.pth')
载入新的模型self.basemodel = tvm.resnet50(pretrained=False)
获取模型的变量和权重组成的字典:basemodel.state_dict()
model_dict = self.basemodel.state_dict()
如果预训练时报错misssing key(s) in state_dict:
以resnet50为例
原本的模型:
在使用预训练模型时,pytorch的机制会导致模型每层前面加了一个模型名字:
如下的模型多了一个basemodel的字样
因此新旧两个模型的字典对不上号
预训练时就会报错misssing key(s) in state_dict
需要更改预训练模型权重字典里面每一层的名字,删除basemodel字样
名称不同的时候更换dict中的变量名:
k就是每一层的名字,v是权重
方法1:
for k, v in pretrained_dict.items():print("pretrained k,v:",k,v)if not k.find("basemodel") == -1: #if find pretrain model name, delete itname = k[(len("basemodel")+1):] # remove `module.`model_dict[name] = velse:name = kprint("delete last layer without pretrained model name")print("new_name:",name)
方法2:
pretrained_dict = {k[(len("basemodel")+1):]: v for k, v in pretrained_dict.items() if k[(len("basemodel")+1):] in model_dict} #去除上次预训练时模型的变量前面添加的”basemodel”字样
完成后可以查看变量和权重字典中的某一个:
下面的代码查看了pretrainmodel的字典和newmodel的字典
#-----------------------------------------------------self.basemodel = tvm.resnet50(pretrained=False)# self.basemodel.load_state_dict(torch.load(r'../pretrained_model/resnet50-19c8e357.pth'))count=0model_dict=torch.load(r'../pretrained_model/resnet50-19c8e357.pth')for k, v in model_dict.items():count+=1if count==10: print("resnet50_dict resnet50 k,v:",k,v)#-----------------------------------------------------#-----------------------------------------------------count=0for k, v in pretrained_dict.items():count+=1if count==10: print("pretrained_dict resnet50 k,v:",k,v)#-----------------------------------------------------
也可以打印整个字典:
for k, v in pretrained_dict.items():print("pretrained k,v:",k,v)
参考文献:
更多推荐
pytorch预训练、预训练时报错misssing key(s) in state
发布评论