Quellcode durchsuchen

增加训练代码

zhzhenqin vor 5 Monaten
Ursprung
Commit
286edfc742
4 geänderte Dateien mit 26 neuen und 4 gelöschten Zeilen
  1. 4 1
      docs/YOLO模型训练.md
  2. 2 1
      docs/YOLO理论研究.md
  3. 16 0
      simple_yolo.py
  4. 4 2
      yolo_model_nms_export.py

+ 4 - 1
docs/YOLO模型训练.md

@@ -1,7 +1,7 @@
 ## YOLO 采用命令训练数据集
 
 ```shell
-yolo train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
+yolo train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01 --device='0,1'
 
 yolo task=detect mode=train model=yolov8x.yaml data=mydata.yaml epochs=10 batch=16
 
@@ -16,6 +16,9 @@ yolo task=segment mode=predict model=yolov8x-seg.pt source='/kaggle/input/person
 - data: 选择生成的数据集配置文件
 - epochs:指的就是训练过程中整个数据集将被迭代多少次,显卡不行你就调小点。
 - batch:一次看完多少张图片才进行权重更新,梯度下降的mini-batch,显卡不行你就调小点。
+- device: cpu or '0' or '0,1', 采用 cpu or gpu,以及 gpu 编号
+- imgsz: 输入图片大小,显卡不行你就调小点。
+- name: 模型保存的名称
 
 实际运行:
 ```shell

+ 2 - 1
docs/YOLO理论研究.md

@@ -1 +1,2 @@
-https://blog.csdn.net/weixin_45303602/article/details/129175854
+https://blog.csdn.net/weixin_45303602/article/details/129175854
+https://blog.csdn.net/qq_32892383/article/details/136505299

+ 16 - 0
simple_yolo.py

@@ -0,0 +1,16 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+
+from ultralytics import YOLO
+
+model = YOLO("/Volumes/Media/WorkDoc/Beizhi/CODE/hyd-yolo/yolov8/yolov8n.pt")
+
+# yolo train data= model=/home/jxft/datasets/yolov8/yolov8n.pt epochs=100 lr0=0.01
+model.train(data="/Users/zhenqin/temp/yolo_demo/datasets/hyd-action.yaml", 
+            workers=0, 
+            batch=16, 
+            epochs=3)
+results = model.val()
+
+# Export the model to ONNX format
+success = model.export(format='onnx')

+ 4 - 2
yolo_model_nms_export.py

@@ -19,7 +19,8 @@ class YOLOv8:
         return self.detect_objects(image)
 
     def initialize_model(self, path):
-        self.session = onnxruntime.InferenceSession(path,providers=['CUDAExecutionProvider','CPUExecutionProvider'])
+        # 'CUDAExecutionProvider', 
+        self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'])
         # Get model info
         self.get_input_details()
         self.get_output_details()
@@ -185,7 +186,8 @@ class YOLOv8:
         return y
 
 if __name__ == "__main__":
+    model_path = "/Volumes/Media/WorkDoc/Beizhi/CODE/hyd-yolo/models/water_strean_model.pt"
     yolov8_detector = YOLOv8(model_path, conf_thres=0.7, iou_thres=0.7)
-    image = cv2.imread()
+    image = cv2.imread("/Users/zhenqin/temp/yolo_demo/datasets/hyd-action/images/Frame2000128.png")
     boxes, scores, class_ids = yolov8_detector(image)
     print(boxes, scores, class_ids)