admin管理员组文章数量:1582046
在做的实验基础代码是用的 Pytorch-Lightning 中的训练器 Trainer 进行训练
- 首先需要保存的训练后的模型参数,保存 checkpoint 断点
checkpoint_callback = pl.callbacks.ModelCheckpoint(
dirpath=args.ckpt_dir + "/" + args.model_type,
filename=model_savename + "---{epoch}---" + dt_string +'-'+str(args.use_img)+str(args.use_att)+str(args.use_date)+str(args.use_trends)+'RNN3_5',#str(note)
monitor="val_mae",
mode="min",#这里实验效果是越小越好,所以是“min”
save_top_k=5,#1
)
print(checkpoint_callback.best_model_path)#打印出效果最好的模型参数存储的路径
这里保存了效果前五的模型,这里的实验效果是越小越好,并打印出效果最好的模型参数存储路径
- 在训练器 Trainer 里加载之前保存的最佳模型
trainer.fit(model, train_dataloaders=trainloader, val_dataloaders=testloader,ckpt_path='自己替换成最佳模型参数所存在的路径.ckpt')
主要是 trainer.fit() 函数里,ckpt_path 参数所提供的效果,输入 ckpt 文件路径(从这里文件恢复训练)
参考博客:https://blog.csdn/qq_27135095/article/details/122635743?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522167583461916800180668936%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fall.%2522%257D&request_id=167583461916800180668936&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2allfirst_rank_ecpm_v1~rank_v31_ecpm-1-122635743-null-null.142%5Ev73%5Econtrol,201%5Ev4%5Eadd_ask,239%5Ev1%5Einsert_chatgpt&utm_term=pl%20trainer%20%E6%98%AF%E5%A6%82%E4%BD%95%E8%AE%AD%E7%BB%83%E7%9A%84&spm=1018.2226.3001.4187
版权声明:本文标题:深度学习如何恢复训练?中断的训练如何接着之前保存的 ckpt 参数继续训练?Pytorch-Lightning Trainer 内容由热心网友自发贡献,该文观点仅代表作者本人, 转载请联系作者并注明出处:https://www.elefans.com/dianzi/1727892513a1136456.html, 本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容,一经查实,本站将立刻删除。
发表评论