参数"/>
PyTorch冻结已训练网络参数
在训练多层神经网络中,我们发现由于网络参数过多,网络收敛的条件有点苛刻。因此,分层训练的方式在日常生活中常常被用到。所谓的分层训练,顾名思义,即多层网络中,我们先训练好第一层网络,固定其参数,去训练第二层网络,当第二层网络训练完毕,就固定前两层参数,去训练第三层网络,以此类推。下面展现代码的实现方式。
# 网络模型
class Test(torch.nn.Module):def __init__(self,): super(TwISTA, self).__init__()…………………def forward(self, y, max_itr):…………………network = Test()
- 首先,搭建一层神经网络将其参数调到最优,并把网络参数保存下来。
torch.save(network.state_dict(), '/home/data_ssd/Test.pth') # 保存到指定目录:/home/data_ssd/,文件名称格式:Test.pth
- 其次,搭建下一层网络,调通、运行代码。用以下代码查看网络模型名称。
model_dict = network.state_dict()
for k, v in model_dict.items(): # 查看自己网络参数各层名称、数值print(k) # 输出网络参数名字# print(v) # 输出网络参数数值
运行结果:
fcs.0.thr
fcs.0.beta
fcs.0.Tw_alpha
fcs.0._W.weight
fcs.0._S.weight
fcs.1.thr
fcs.1.beta
fcs.1.Tw_alpha
fcs.1._W.weight
fcs.1._S.weight
我们可以看到,以“fcs.0”开头的参数是第一层网络,以“fcs.1”开头的参数是第二层网络。
- 加载第一层已训练好的参数,关闭梯度。
pretrained_dict = torch.load('/home/data_ssd/Test.pth') # 到相应目录加载刚刚保存的文件(网络参数)
model_dict['fcs.0.thr'] = pretrained_dict['fcs.0.thr']
model_dict['fcs.0.beta'] = pretrained_dict['fcs.0.beta']
model_dict['fcs.0.Tw_alpha'] = pretrained_dict['fcs.0.Tw_alpha']
model_dict['fcs.0._W.weight'] = pretrained_dict['fcs.0._W.weight']
model_dict['fcs.0._S.weight'] = pretrained_dict['fcs.0._S.weight']
# 第一层网络fcs.0不再参与训练,关闭梯度
for name, param in network.named_parameters():# print(name)if "fcs.0" in name:param.requires_grad = False# 查看是否关闭成功
for name, param in network.named_parameters():if param.requires_grad:print("requires_grad: True ", name)else:print("requires_grad: False ", name)
运行结果:
requires_grad: False fcs.0.thr
requires_grad: False fcs.0.beta
requires_grad: False fcs.0.Tw_alpha
requires_grad: False fcs.0._W.weight
requires_grad: False fcs.0._S.weight
requires_grad: True fcs.1.thr
requires_grad: True fcs.1.beta
requires_grad: True fcs.1.Tw_alpha
requires_grad: True fcs.1._W.weight
requires_grad: True fcs.1._S.weight
从结果中可以看出,以“fcs.0”开头的参数(即第一层网络参数)均已关闭梯度。
- 最后,在优化器中屏蔽掉第一层网络参数不再参与训练,仅仅训练新添加的网络参数。
opt = torch.optim.Adam(filter(lambda p: p.requires_grad, network.parameters()), lr=adam_lr) # 过滤掉没有梯度的参数
以此类推,分别优化多层网络
更多推荐
PyTorch冻结已训练网络参数
发布评论