# Dataset description for the digit-box detector.
# Presumably this is the `digit_box.yaml` that the training/validation
# scripts pass as `data=` — TODO confirm the file name.

# Dataset root directory (absolute path).
path: /home/user/digit_box_dataset
# Image folders for each split; presumably resolved relative to `path` — TODO confirm.
train: images/train
val: images/val
test: images/test
# Number of classes; matches the ten entries in `names` below.
nc: 10
# Class labels, one per digit. Presumably the list position is the class id
# used in the label files — verify against the annotations.
names: ['0','1','2','3','4','5','6','7','8','9']
from ultralytics import YOLO
# Fine-tune the pretrained YOLOv8-s checkpoint on the dataset described by
# digit_box.yaml, then report the best checkpoint path and headline metrics.
model = YOLO("yolov8s.pt")
train_results = model.train(
    data="digit_box.yaml",
    epochs=100,
    imgsz=640,
    batch=16,
    device=0,
    # Optimizer schedule (see the Ultralytics training-settings docs for the
    # exact meaning of lr0 / lrf / warmup_epochs).
    lr0=0.01,
    lrf=0.01,
    weight_decay=0.0005,
    warmup_epochs=3,
    # Loss gains for the box-regression and classification terms.
    box=7.5,
    cls=0.5,
    save=True,
    # Run directory is digit_box_train/digit_box_model; exist_ok=True reuses
    # it instead of creating a suffixed directory on re-runs.
    project="digit_box_train",
    name="digit_box_model",
    exist_ok=True,
)

# The YOLO wrapper has no `best` attribute; the best-checkpoint path lives on
# the trainer that `model.train()` leaves attached to the model.
trainer = model.trainer
print("最佳模型路径:", trainer.best)

# `train_results` holds validation metrics only, so the training losses are
# taken from the trainer: `tloss` is the running mean of the loss items for
# the last epoch, and `label_loss_items` names them "train/box_loss", etc.
train_losses = trainer.label_loss_items(trainer.tloss, prefix="train")
print("训练集框损失:", train_losses["train/box_loss"])

# The returned metrics object exposes its values through `results_dict`;
# it has no `.metrics` attribute.
print("验证集 mAP@0.5(框精度核心指标):", train_results.results_dict["metrics/mAP50(B)"])
# Load the best checkpoint produced by training and evaluate it on the
# validation split declared in digit_box.yaml.
model = YOLO("digit_box_train/digit_box_model/weights/best.pt")
val_metrics = model.val(
    data="digit_box.yaml",
    imgsz=640,
    device=0,
    # NOTE(review): in Ultralytics `val`, `iou` appears to be the NMS IoU
    # threshold, not the IoU at which mAP is computed — confirm this is the
    # intended setting.
    iou=0.5,
)

box_metrics = val_metrics.box

# `map50` is mAP at IoU 0.5; `map` is the mAP averaged over IoU 0.5:0.95,
# so each value is printed under its matching label.
print(f"框平均精度 mAP@0.5:{box_metrics.map50:.4f}")
print(f"框精确率(框定位准不准):{box_metrics.mp:.4f}")
print(f"框召回率(框有没有漏检):{box_metrics.mr:.4f}")
# The box metrics object has no `mc` attribute (detection has no separate
# class-accuracy figure); report the stricter mAP@0.5:0.95 instead.
print(f"框平均精度 mAP@0.5:0.95:{box_metrics.map:.4f}")

# `ap50` holds one entry per class listed in `ap_class_index` (classes that
# actually occur in the evaluated data), so pair the two instead of indexing
# `ap50` by class id — that misaligns or overflows when a class is absent.
for cls_id, ap50 in zip(box_metrics.ap_class_index, box_metrics.ap50):
    print(f"数字{model.names[int(cls_id)]}的框精度:{ap50:.4f}")
import cv2
from ultralytics import YOLO
# Run the trained detector on one image, read the digits left to right,
# and save/show an annotated copy of the image.
MODEL_PATH = "digit_box_train/digit_box_model/weights/best.pt"
TEST_IMG_PATH = "test_digit.jpg"
CONF_THRESH = 0.5  # minimum confidence for a detection to be kept
IOU_THRESH = 0.5   # IoU threshold passed to the predictor

model = YOLO(MODEL_PATH)
results = model(
    TEST_IMG_PATH,
    imgsz=640,
    conf=CONF_THRESH,
    iou=IOU_THRESH,
    device=0,
)

# Collect one record per detected box.
digit_boxes = []
for r in results:
    for box in r.boxes:
        # Convert the tensor to plain Python numbers once, then truncate to
        # integer pixel coordinates for drawing.
        x1, y1, x2, y2 = (int(v) for v in box.xyxy[0].tolist())
        cls_id = int(box.cls[0])
        digit_boxes.append({
            "digit": model.names[cls_id],
            "confidence": round(float(box.conf[0]), 3),
            "bbox": (x1, y1, x2, y2),
            "center": ((x1+x2)/2, (y1+y2)/2),
        })

# Order detections by horizontal centre so the digits read left to right.
digit_boxes = sorted(digit_boxes, key=lambda x: x["center"][0])
detected_digits = [d["digit"] for d in digit_boxes]

img = cv2.imread(TEST_IMG_PATH)
if img is None:
    # cv2.imread signals failure by returning None rather than raising.
    raise FileNotFoundError(f"Could not read image: {TEST_IMG_PATH}")

font = cv2.FONT_HERSHEY_SIMPLEX
for d in digit_boxes:
    x1, y1, x2, y2 = d["bbox"]
    cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
    label = f"{d['digit']} ({d['confidence']})"
    # putText's origin is the text's bottom-left corner. When the box is
    # too close to the top edge for the label to fit above it, draw the
    # label just inside the box instead of off-image.
    (_, text_h), _ = cv2.getTextSize(label, font, 0.8, 2)
    label_y = y1 - 10 if y1 - 10 >= text_h else y1 + text_h + 10
    cv2.putText(img, label, (x1, label_y), font, 0.8, (0, 255, 0), 2)

print("检测到的逐位数字:", detected_digits)
print("每个 digit 框详情:", digit_boxes)
cv2.imwrite("digit_box_detect_result.jpg", img)
# Interactive preview: blocks until a key is pressed (requires a display).
cv2.imshow("Digit Box Detection", img)
cv2.waitKey(0)
cv2.destroyAllWindows()
# Anchor table: three rows of six integers each — presumably three
# (width, height) pairs per row, one row per detection scale, ordered from
# small to large boxes. TODO confirm against the model definition that reads it.
# NOTE(review): the scripts in this file load `yolov8s.pt`; YOLOv8 detection
# heads are, as far as I know, anchor-free — confirm whether this section is
# consumed anywhere.
anchors:
- [6,8, 10,13, 16,23]
- [23,33, 30,61, 62,45]
- [59,119, 116,90, 156,198]