您当前的位置: 首页 >  Python
  • 4浏览

    0关注

    477博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

Python onnx 模型打印显示所有节点及查看相互关系

高精度计算机视觉 发布时间:2021-10-22 01:21:32 ,浏览量:4

最近使用onnx时,想把所有的节点的信息和权重参数显示出来,找了一下没找到类似的源码,官方介绍的pythonAPI都是些什么加载,保存,转换之类的,没有详细介绍怎么使用onnx分析模型的,只好自己写一个。

其实很简单,我只列些最基本的,具体分析还得看个人的需要,

import onnx

model_in_file = 'yolov5s-sim.onnx'

if __name__ == "__main__":
    model = onnx.load(model_in_file)

    nodes = model.graph.node    
    nodnum = len(nodes) # 205

    for nid in range(nodnum):
        if (nodes[nid].output[0] == 'stride_32'):
            print('Found stride_32: index = ', nid)
        else:
            print(nodes[nid].output)

    inits = model.graph.initializer
    ininum = len(inits)  #124

    for iid in range(ininum):
        el = inits[iid]
        print('name:', el.name, ' dtype:', el.data_type, ' dim:', el.dims) 
        # el.raw_data for weights and biases

    print(model.graph.output) # display all the output nodes

print('Done')

比如,我这显示出来的模型中的节点是这样的,

[input: "data"
input: "model.0.conv.weight"
input: "model.0.conv.bias"
output: "122"
name: "Conv_0"
op_type: "Conv"
attribute {
  name: "dilations"
  ints: 1
  ints: 1
  type: INTS
}
attribute {
  name: "group"
  i: 1
  type: INT
}
attribute {
  name: "kernel_shape"
  ints: 6
  ints: 6
  type: INTS
}
attribute {
  name: "pads"
  ints: 2
  ints: 2
  ints: 2
  ints: 2
  type: INTS
}
attribute {
  name: "strides"
  ints: 2
  ints: 2
  type: INTS
}
, 
。。。。。。
, input: "325"
input: "model.24.m.2.weight"
input: "model.24.m.2.bias"
output: "376"
name: "Conv_234"
op_type: "Conv"
attribute {
  name: "dilations"
  ints: 1
  ints: 1
  type: INTS
}
attribute {
  name: "group"
  i: 1
  type: INT
}
attribute {
  name: "kernel_shape"
  ints: 1
  ints: 1
  type: INTS
}
attribute {
  name: "pads"
  ints: 0
  ints: 0
  ints: 0
  ints: 0
  type: INTS
}
attribute {
  name: "strides"
  ints: 1
  ints: 1
  type: INTS
}
, input: "376"
input: "398"
output: "399"
name: "Reshape_251"
op_type: "Reshape"
, input: "399"
output: "stride_32"
name: "Transpose_252"
op_type: "Transpose"
attribute {
  name: "perm"
  ints: 0
  ints: 1
  ints: 3
  ints: 4
  ints: 2
  type: INTS
}
]

可以看出,在onnx模型中,结点之间用逗号隔开,输出和输出都分别列出,比如我这里最后一个节点的信息是

>>> nodes[204]
input: "399"
output: "stride_32"
name: "Transpose_252"
op_type: "Transpose"
attribute {
  name: "perm"
  ints: 0
  ints: 1
  ints: 3
  ints: 4
  ints: 2
  type: INTS
}

>>> nodes[203]
input: "376"
input: "398"
output: "399"
name: "Reshape_251"
op_type: "Reshape"

其中node[204]input表示输入节点是399,也就是node[203]的输出;node[204]输出名称是stride_32,部署时就用这个名称来提取最终结果。像netron这样的工具,就是根据这些node之间的关系来绘制网络图的。

权重分析这里就不展开了。

关注
打赏
1659856567
查看更多评论
立即登录/注册

微信扫码登录

0.1397s