pytorch保存和读取模型,定义save

编程入门 行业动态 更新时间:2024-10-25 22:31:58

pytorch保存和读取<a href=https://www.elefans.com/category/jswz/34/1771358.html style=模型,定义save"/>

pytorch保存和读取模型,定义save

保存模型:

# 定义函数,保存最新和最佳模型
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):"""Save the training model"""torch.save(state, filename)if is_best:shutil.copyfile(filename, 'model_best.pth.tar')
# 调用时:
save_checkpoint({'state_dict': model.state_dict(),'best_prec1': best_prec1,}, is_best, filename=os.path.join(args.save_dir, 'model.th'))

读取模型:

# 调用保存的最佳模型的准确率输出resume = 'model_best.pth.tar'checkpoint = torch.load(resume)best_acc1 = checkpoint['best_prec1']print('best acc:{0}'.format(best_acc1))

更多推荐

pytorch保存和读取模型,定义save

本文发布于:2024-02-25 15:45:16,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1699478.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:模型   定义   pytorch   save

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!