PyTorch冻结已训练网络参数

编程入门 行业动态 更新时间:2024-10-25 04:27:24

PyTorch冻结已训练网络<a href=https://www.elefans.com/category/jswz/34/1771441.html style=参数"/>

PyTorch冻结已训练网络参数

在训练多层神经网络中,我们发现由于网络参数过多,网络收敛的条件有点苛刻。因此,分层训练的方式在日常生活中常常被用到。所谓的分层训练,顾名思义,即多层网络中,我们先训练好第一层网络,固定其参数,去训练第二层网络,当第二层网络训练完毕,就固定前两层参数,去训练第三层网络,以此类推。下面展现代码的实现方式。

# 网络模型
class Test(torch.nn.Module):def __init__(self,): super(TwISTA, self).__init__()…………………def forward(self, y, max_itr):…………………network = Test()   
  • 首先,搭建一层神经网络将其参数调到最优,并把网络参数保存下来。
torch.save(network.state_dict(), '/home/data_ssd/Test.pth')  # 保存到指定目录:/home/data_ssd/,文件名称格式:Test.pth
  • 其次,搭建下一层网络,调通、运行代码。用以下代码查看网络模型名称。
model_dict = network.state_dict()
for k, v in model_dict.items():  # 查看自己网络参数各层名称、数值print(k)  # 输出网络参数名字# print(v)  # 输出网络参数数值

运行结果:

fcs.0.thr
fcs.0.beta
fcs.0.Tw_alpha
fcs.0._W.weight
fcs.0._S.weight
fcs.1.thr
fcs.1.beta
fcs.1.Tw_alpha
fcs.1._W.weight
fcs.1._S.weight

我们可以看到,以“fcs.0”开头的参数是第一层网络,以“fcs.1”开头的参数是第二层网络。

  • 加载第一层已训练好的参数,关闭梯度。
pretrained_dict = torch.load('/home/data_ssd/Test.pth')  # 到相应目录加载刚刚保存的文件(网络参数)
model_dict['fcs.0.thr'] = pretrained_dict['fcs.0.thr']
model_dict['fcs.0.beta'] = pretrained_dict['fcs.0.beta']
model_dict['fcs.0.Tw_alpha'] = pretrained_dict['fcs.0.Tw_alpha']
model_dict['fcs.0._W.weight'] = pretrained_dict['fcs.0._W.weight']
model_dict['fcs.0._S.weight'] = pretrained_dict['fcs.0._S.weight']
# 第一层网络fcs.0不再参与训练,关闭梯度
for name, param in network.named_parameters():# print(name)if "fcs.0" in name:param.requires_grad = False# 查看是否关闭成功
for name, param in network.named_parameters():if param.requires_grad:print("requires_grad: True ", name)else:print("requires_grad: False ", name)

运行结果:

requires_grad: False  fcs.0.thr
requires_grad: False  fcs.0.beta
requires_grad: False  fcs.0.Tw_alpha
requires_grad: False  fcs.0._W.weight
requires_grad: False  fcs.0._S.weight
requires_grad: True  fcs.1.thr
requires_grad: True  fcs.1.beta
requires_grad: True  fcs.1.Tw_alpha
requires_grad: True  fcs.1._W.weight
requires_grad: True  fcs.1._S.weight

从结果中可以看出,以“fcs.0”开头的参数(即第一层网络参数)均已关闭梯度。

  • 最后,在优化器中屏蔽掉第一层网络参数不再参与训练,仅仅训练新添加的网络参数。
opt = torch.optim.Adam(filter(lambda p: p.requires_grad, network.parameters()), lr=adam_lr)  # 过滤掉没有梯度的参数

以此类推,分别优化多层网络

更多推荐

PyTorch冻结已训练网络参数

本文发布于:2024-03-08 14:49:36,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1721181.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:参数   网络   PyTorch

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!