onnx_labels.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. #!/usr/bin/env python
  2. # -*- coding:utf-8 -*-
  3. # Created by zhenqin.
  4. # User: zhenqin
  5. # Date: 2024-04-25
  6. # Time: 下午1:58
  7. # Vendor: yiidata.com
  8. #
  9. # 针对有网络模型,但还没有训练保存 .pth 文件的情况
  10. import onnxruntime
  11. import netron
  12. import onnx
  13. import torch.onnx
  14. def printLabels():
  15. # Load Label from Model's meta data
  16. session = onnxruntime.InferenceSession("./yolov8s.onnx", providers=['CPUExecutionProvider'])
  17. meta = session.get_modelmeta().custom_metadata_map # metadata
  18. label_names = eval(meta['names'])
  19. print(label_names)
  20. def addLabels():
  21. f = "./yolov7.onnx"
  22. new_f = "./yolov7-425-labels.onnx"
  23. model_onnx = onnx.load(f) # load onnx model
  24. onnx.checker.check_model(model_onnx) # check onnx model
  25. labels = ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
  26. "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
  27. "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
  28. "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
  29. "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
  30. "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
  31. "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
  32. "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
  33. "hair drier", "toothbrush"]
  34. index = 1
  35. meta = model_onnx.metadata_props.add()
  36. meta.key, meta.value = "labels", ",".join(labels)
  37. print(meta.key + ":" + meta.value)
  38. onnx.save(model_onnx, new_f)
  39. if __name__ == '__main__':
  40. # printLabels()
  41. # netron.start("./yolov8s.onnx") # 输出网络结构
  42. netron.start("./yolov7-425-labels.onnx")
  43. # addLabels()