superpoint的onnx转换全流程

编程入门 行业动态 更新时间:2024-10-28 10:34:11

superpoint的onnx转换全<a href=https://www.elefans.com/category/jswz/34/1770115.html style=流程"/>

superpoint的onnx转换全流程

Tensorflow转ONNX - it610

转PB文件(Graph Freezing)

如上所述,常规TF的模型导出方法会将网络信息与权重信息分开存储在不同文件当中,这在部署时候不是很方便。官方提供了一种Freeze Graph的方式,用于将模型相关信息统统打包到一个*.pb文件当中。

官方提供了相关工具freeze_graph,一般安装完TensorFlow后会自动添加到用户PATH相应的bin目录下,如果没有找到的话可以去TensorFlow源码tensorflow/python/tools/free_graph.py这个位置去找一下,或者直接通过命令行导入module的方式调用。

举例如下,如果有多个输出节点,用逗号隔开:

# 1.直接调用
freeze_graph --input_graph=/home/mnist-tf/graph.proto \--input_checkpoint=/home/mnist-tf/ckpt/model.ckpt \--output_graph=/tmp/frozen_graph.pb \--output_node_names=fc2/add \--input_binary=True# 2. 通过调用module的方式
python -m tensorflow.python.tools.freeze_graph \--input_graph=my_checkpoint_dir/graphdef.pb \--input_binary=true \--output_node_names=output \--input_checkpoint=my_checkpoint_dir \--output_graph=tests/models/fc-layers/frozen.pb

其中有一个比较困扰的点在于需要准确的知道输出节点的节点名称,我的做法是通过tf.get_default_graph().as_graph_def().node来得到各节点信息,然后再从中查看具体的输出节点名。

print([tensor.name for tensor in tf.get_default_graph().as_graph_def().node])

onnx                  1.10.1
onnx-simplifier       0.3.6
onnxoptimizer         0.2.6
onnxruntime           1.8.1

完整的SSD-tf版本转换为onnx示例:
.ipynb

本文所示示例为superpoint


magic-point网络训练模拟数据input(h,w,c)=(120,160,1)
magic-point网络导出coco数据的GT,使用input(h,w,c)=(240,320,1)
super-point网络训练coco数据input(h,w,c)=(240,320,1)
super-point网络inference的时候使用input(h,w,c)=(480,640,1)


项目作者给出了网络三个阶段保存的模型
MagicPoint (synthetic) mp_synth-v11
MagicPoint (COCO) mp_synth-v11_ha1_trained
SuperPoint (COCO) sp_v6
我们通过导出onnx已经确定了他们的input和output
mp_synth-v11
input: image
output: pred  logits  prob  prob_nms
mp_synth-v11_ha1_trained
input: image
output: pred  logits  prob  prob_nms
sp_v6
input: image
output: descriptors  prob_nms  prob  descriptors_raw pred logits

参考:
1.模型checkpoint格式


1.5 如果你的模型是frozen_inference_graph.pb格式,就需要且不知道input output的name,
就比较麻烦,这里我们不考虑这种情况
.md


2.将checkpoint格式转换为saved model格式(pb)
.py


3.将saved model格式转换为onnx

记得选择适合自己网络的opset
可参考:
.4.0/docs/OperatorKernels.md
.md

python -m tf2onnx.convert --saved-model saved_models/model/ --output saved_models/model/saved_model.onnx --opset 11


4.得到onnx后发现,无法正常显示网络输入和网络输出


5.安装onnx
通过onnx打印网络所有的input和output

import onnxname = "X.onnx"
onnx_file = name
model = onnx.load(onnx_file)
for m in range(len(model.graph.input)):print(model.graph.input[m].name)print("\n")for m in range(len(model.graph.output)):print(model.graph.output[m].name)


6.下载onnx-simplifier


5.通过onnx-simplifier,指定input=(N,H,W,C) ps: data_format='channels_last'
导出simplifier后的onnx,可看到网络每层的shape

import onnx
from onnxsim import simplifyonnx_file = "mp_synth-v11.onnx"
sim_onnx_path = "mp_synth-v11__simplify.onnx"model = onnx.load(onnx_file)
onnx.checker.check_model(model)# print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))# 列表里的 shape 改成自己对应的
model_simp, check = simplify(model, input_shapes={"image":[1, 120, 160, 1]}) onnx.save(model_simp, sim_onnx_path)print("Simplify onnx done !")


参考:

其它一些参考

ONNX学习笔记 - 知乎

如何在Netron中看到每一层的shape_BokyLiu的博客-CSDN博客

【模型加速】PointPillars模型TensorRT加速实验(2)_昌山小屋的博客-CSDN博客

onnx2pytorch和onnx-simplifer新版介绍 - 知乎

torch.onnx — PyTorch 1.9.1 documentation

torch.onnx — PyTorch 1.9.1 documentation

GitHub - microsoft/onnxjs: ONNX.js: run ONNX models using JavaScript

GitHub - microsoft/onnxjs-demo: demos to show the capabilities of ONNX.js

superpoint导出pb

import os
os.environ['CUDA_VISIBLE_DEVICES'] = "7"
import yaml
import argparse
import logging
from pathlib import Pathlogging.basicConfig(format='[%(asctime)s %(levelname)s] %(message)s',datefmt='%m/%d/%Y %H:%M:%S',level=logging.INFO)
import tensorflow as tf  # noqa: E402from superpoint.models import get_model  # noqa: E402if __name__ == '__main__':with open("config.yml", 'r') as f:config = yaml.load(f)config['model']['data_format'] = 'channels_last'def mkdir_os(path):if not os.path.exists(path):os.makedirs(path)export_root_dir = "onnx/"mkdir_os(export_root_dir)export_dir = "onnx/pb_model_onlytwo_NoneH640W480C1"checkpoint_path = "model"'''data_shapesuperpoint/models/base_model.py# Prediction network with feed_dictif self.data_shape is None:self.data_shape = {i: spec['shape'] for i, spec in self.input_spec.items()}self.pred_in = {i: tf.placeholder(spec['type'], shape=self.data_shape[i], name=i)for i, spec in self.input_spec.items()}self._pred_graph(self.pred_in)'''with get_model(config['model']['name'])(data_shape={'image': [None, 640, 480, 1]},**config['model']) as net:net.load(str(checkpoint_path))# tf.saved_model.simple_save(#         net.sess,#         str(export_dir),#         inputs=net.pred_in,#         outputs=net.pred_out)# name:logits# name:prob# name:descriptors_raw# name:descriptors# name:prob_nms# name:predtf.saved_model.simple_save(net.sess,str(export_dir),inputs=net.pred_in,outputs={"prob_nms": net.pred_out["prob_nms"],"descriptors": net.pred_out["descriptors"],})

此时Batch还是1

后面onnx流程中就被固定住了

所以在:

pd转onnx的时候:

Dynamic Input Reshape Incorrect · Issue #1640 · onnx/tensorflow-onnx · GitHub

Dynamic Input Reshape Incorrect · Issue #8591 · microsoft/onnxruntime · GitHub

更多推荐

superpoint的onnx转换全流程

本文发布于:2024-02-11 16:21:28,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1681973.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:流程   superpoint   onnx

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!