|
@@ -0,0 +1,48 @@
|
|
|
+#!/usr/bin/env python
|
|
|
+# -*- coding:utf-8 -*-
|
|
|
+
|
|
|
+# Created by zhenqin.
|
|
|
+# User: zhenqin
|
|
|
+# Date: 2024-04-25
|
|
|
+# Time: 下午1:58
|
|
|
+# Verdor: yiidata.com
|
|
|
+#
|
|
|
+
|
|
|
+# 针对有网络模型,但还没有训练保存 .pth 文件的情况
|
|
|
+import onnxruntime
|
|
|
+import netron
|
|
|
+import onnx
|
|
|
+import torch.onnx
|
|
|
+
|
|
|
+def printLabels():
|
|
|
+ # Load Label from Model's meta data
|
|
|
+ session = onnxruntime.InferenceSession("./yolov8s.onnx", providers=['CPUExecutionProvider'])
|
|
|
+ meta = session.get_modelmeta().custom_metadata_map # metadata
|
|
|
+ label_names = eval(meta['names'])
|
|
|
+ print(label_names)
|
|
|
+
|
|
|
+def addLabels():
|
|
|
+ f = "./yolov7.onnx"
|
|
|
+ new_f = "./yolov7-425-labels.onnx"
|
|
|
+ model_onnx = onnx.load(f) # load onnx model
|
|
|
+ onnx.checker.check_model(model_onnx) # check onnx model
|
|
|
+ labels = ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
|
|
|
+ "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
|
|
|
+ "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
|
|
|
+ "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
|
|
|
+ "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
|
|
|
+ "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
|
|
|
+ "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
|
|
|
+ "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
|
|
|
+ "hair drier", "toothbrush"]
|
|
|
+ index = 1
|
|
|
+ meta = model_onnx.metadata_props.add()
|
|
|
+ meta.key, meta.value = "labels", ",".join(labels)
|
|
|
+ print(meta.key + ":" + meta.value)
|
|
|
+ onnx.save(model_onnx, new_f)
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ # printLabels()
|
|
|
+ # netron.start("./yolov8s.onnx") # 输出网络结构
|
|
|
+ netron.start("./yolov7-425-labels.onnx")
|
|
|
+ # addLabels()
|