# Dataset description: image directories, class count and class names.
# NOTE(review): presumably this is the "traffic.yaml" that the training script
# passes as `data=` - confirm.
# NOTE(review): both paths are relative; confirm which base directory the
# loader resolves them against before moving this file.
train: ./traffic_violation_dataset/images/train
val: ./traffic_violation_dataset/images/val
# Number of classes; agrees with the 7 entries listed in `names` below.
nc: 7
# Class labels. NOTE(review): list position is presumably the class id used in
# the label files - confirm before reordering or inserting entries.
names: ['Crossing_Violation', 'Crosswalk_Violation', 'Helmet_Violation', 'Normal', 'Passenger_Violation', 'Pedestrian_Violation', 'Trafficlight_Violation']
from ultralytics import YOLO
import os
import time
# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
# Ultralytics publishes the YOLO11 small checkpoint as "yolo11s.pt" (no "v"),
# so "yolov11s.pt" cannot be auto-downloaded. A local file that still carries
# the old name is honoured so existing setups keep working.
_LEGACY_MODEL_NAME = "yolov11s.pt"
MODEL_NAME = (
    _LEGACY_MODEL_NAME if os.path.exists(_LEGACY_MODEL_NAME) else "yolo11s.pt"
)
CONFIG_FILE = "traffic.yaml"
EPOCHS = 100
BATCH_SIZE = 16
IMG_SIZE = 640
PROJECT_NAME = "traffic_violation_detection"
EXPERIMENT_NAME = "exp_traffic_yolo11"


def _select_device():
    """Return CUDA device index 0 if PyTorch can see a GPU, else ``"cpu"``."""
    # Imported locally so this block does not depend on the file's import
    # section. The previous check read a non-standard CUDA_AVAILABLE
    # environment variable, which is normally unset and forced CPU training.
    import torch

    return 0 if torch.cuda.is_available() else "cpu"


DEVICE = _select_device()


def main():
    """Train, validate and export the traffic-violation detector.

    Loads ``MODEL_NAME``, trains it on the dataset described by
    ``CONFIG_FILE``, prints validation metrics and exports the model to ONNX.

    Raises:
        FileNotFoundError: If ``CONFIG_FILE`` does not exist.
    """
    if not os.path.exists(CONFIG_FILE):
        raise FileNotFoundError(f"[ERROR] 找不到配置文件:{CONFIG_FILE}")

    print("🚀 加载 YOLOv11 模型...")
    model = YOLO(MODEL_NAME)

    print("🔥 开始训练交通违规检测模型...")
    start_time = time.time()
    results = model.train(
        # Run setup.
        data=CONFIG_FILE,
        epochs=EPOCHS,
        batch=BATCH_SIZE,
        imgsz=IMG_SIZE,
        device=DEVICE,
        project=PROJECT_NAME,
        name=EXPERIMENT_NAME,
        exist_ok=True,
        patience=30,
        save=True,
        save_period=10,
        cache=False,
        workers=4,
        # Optimizer.
        optimizer='AdamW',
        lr0=0.001,
        lrf=0.01,
        momentum=0.937,
        weight_decay=0.0005,
        # Data augmentation.
        hsv_h=0.015,
        hsv_s=0.7,
        hsv_v=0.4,
        degrees=10.0,
        translate=0.2,
        scale=0.5,
        shear=2.0,
        flipud=0.0,
        fliplr=0.5,
        mosaic=1.0,
        mixup=0.1,
        copy_paste=0.1,
    )
    training_time = (time.time() - start_time) / 60
    print(f"✅ 训练完成!总耗时:{training_time:.2f} 分钟")
    print(f"📍 模型保存路径:{results.save_dir}")

    print("🔍 开始验证...")
    metrics = model.val()
    box = metrics.box
    print(f"mAP@0.5: {box.map50:.4f}")
    print(f"mAP@0.5:0.95: {box.map:.4f}")
    # `p` / `r` are per-class arrays and cannot be formatted with ':.4f';
    # `mp` / `mr` are the scalar means over classes.
    print(f"Precision: {box.mp:.4f}")
    print(f"Recall: {box.mr:.4f}")

    print("\n📊 各类别 mAP@0.5:")
    # `ap50` holds one entry per class listed in `ap_class_index`, so the two
    # are walked together. (`maps` would give mAP@0.5:0.95, not mAP@0.5, and
    # iterating `model.names` directly yields ids rather than names.)
    for row, cls_id in enumerate(box.ap_class_index):
        print(f" {model.names[int(cls_id)]}: {box.ap50[row]:.4f}")

    print("📦 导出 ONNX 模型...")
    onnx_path = model.export(format="onnx", dynamic=True, simplify=True)
    print(f"ONNX 模型已保存至:{onnx_path}")


# The guard stops data-loader worker processes (workers=4) from re-running the
# whole training script when the module is re-imported under the "spawn"
# start method, e.g. on Windows.
if __name__ == "__main__":
    main()
from ultralytics import YOLO
import cv2
# Run the trained detector on a single image, show and save the annotated
# result, then list every detection on stdout.
model = YOLO("traffic_violation_detection/exp_traffic_yolo11/weights/best.pt")
image_path = "test.jpg"
results = model(image_path, conf=0.4, iou=0.5, imgsz=640)

for r in results:
    # Results.plot() already returns a BGR array, which is what cv2.imshow and
    # cv2.imwrite expect. The former RGB->BGR conversion swapped the red and
    # blue channels of the displayed and saved image.
    im = r.plot()
    cv2.imshow("Traffic Violation Detection", im)
    cv2.waitKey(0)  # block until a key is pressed
    cv2.imwrite("result_detected.jpg", im)
# Release the preview window once the user has dismissed it.
cv2.destroyAllWindows()

for r in results:
    for box in r.boxes:
        cls_id = int(box.cls[0])
        conf = float(box.conf[0])
        class_name = model.names[cls_id]
        # NOTE(review): the dataset also has a 'Normal' class, which this
        # message reports as a violation too - confirm whether it should be
        # filtered out or worded differently.
        print(f"检测到违规行为:{class_name} (置信度:{conf:.2f})")