跳到主要内容
智能车竞赛惯导校准与视觉避障实战思路分享 | 极客日志
Python AI 算法
智能车竞赛惯导校准与视觉避障实战思路分享 智能车竞赛中网络延迟影响显著,建议使用有线连接及专用路由器。上位机辅助可视化障碍物位置可提升操控效率。扫码环节利用深度相机清晰度优势并优化节点逻辑降低 CPU 占用。P 点返回精度通过逆透视变换结合固定地图元素或 YOLO 识别进行坐标校准。STM32 源码需调整舵机转角多项式系数以保证左右对称,提高串口通信频率至 50Hz 以匹配 EKF 计算。数据标注采用模型预贴加人工复核流程,配合去重增强脚本提升训练集质量。
宁静 发布于 2026/4/7 更新于 2026/4/25 4 浏览前言
在智能车竞赛中,我们团队取得了优异成绩。作为技术负责人,我想分享一下备赛过程中的一些技术思路。为了保持竞争力,部分核心算法细节将不公开。
本文记录了备赛过程中的全流程经验,包括网络优化、上位机辅助处理、扫码策略、P 点返回校准、STM32 源码修改及数据处理脚本。
网络问题
参赛初期常遇到严重的网络延迟问题——上位机延迟。第二年备赛时我们非常重视此问题。
初期使用家用路由器,在校赛期间表现良好。赛后升级了高性能无线路由器。调试时建议携带专用路由器,实验室环境下可消除延时。赛场上建议使用较少占用的信道(如 165 信道),但需注意现场干扰情况。
图 1 高性能无线路由器
连接方式上,上位机和终端最好都使用网线连接路由器,避免使用板载无线网卡。
调试阶段若开启中继模式可能导致信道不可调,进而引发轮次增加后的延时。建议放弃云端 API 调用,转用本地部署方案。
现场网络环境复杂,部分队伍因网络问题未能晋级。图生文环节对网络依赖较高,云端 API 质量虽好但易受现场影响,本地部署模型效果有限。
上位机辅助处理
在上位机视角中,桶和 P 点底部会有红线标识。这是通过上位机的 bridge_client.py 单独运行 Python 脚本接收 YOLO 结果,并使用 tkinter 库绘制实现的。画出障碍物位置可以帮助操作者快速确定障碍物位置。
import tkinter as tk
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy
from std_msgs.msg import String, Int32
from sensor_msgs.msg import Image
from nav_msgs.msg import Odometry
from origincar_msg.msg import Sign
from cv_bridge import CvBridge
import cv2
import numpy as np
import keyboard
from threading import Thread
class LLM2Origincar :
def __init__ (self, host, port ):
self .ros = None
.host = host
.port = port
.roadblock_list = []
.end_list = []
.init_ros()
.init_topic()
.init_thread()
.keep()
( ):
.yolo_sub = Topic( .ros, , , latch= )
.yolo_sub.subscribe( .yolo_sub_callback)
( ):
.roadblock_list.clear()
.end_list.clear()
target msg[ ]:
target[ ] == :
rect = target[ ][ ][ ]
.roadblock_list.append({
: rect[ ],
: rect[ ],
: rect[ ] + rect[ ],
})
target[ ] == :
rect = target[ ][ ][ ]
.end_list.append({
: rect[ ],
: rect[ ],
: rect[ ],
: rect[ ] + rect[ ],
: target[ ][ ][ ],
})
( ):
:
:
canvas.delete( )
canvas.create_line( , , , , fill= , width= )
canvas.create_line( , , , , fill= , width= )
.roadblock_list:
obst .roadblock_list:
b = (obst[ ] * )
canvas.create_line(
(obst[ ] * ),
b,
((obst[ ] + obst[ ]) * ),
b,
fill= ,
width=
)
.end_list:
end .end_list:
x1 = (end[ ] * )
y1 = (end[ ] * )
x2 = ((end[ ] + end[ ]) * )
y2 = (end[ ] * )
canvas.create_line(x1, y2, x2, y2, fill= , width= )
canvas.create_text( ((x1+x2)/ ), (y1- ) (y1- ) > , text= . (end[ ]), fill= )
self
self
self
self
self
self
self
self
def
init_topic
self
self
self
'/hobot_dnn_detection'
'ai_msgs/msg/PerceptionTargets'
True
self
self
def
yolo_sub_callback
self, msg
self
self
for
in
'targets'
if
'type'
'roadblock'
'rois'
0
'rect'
self
'x'
'x_offset'
'w'
'width'
'b'
'y_offset'
'height'
elif
'type'
'end'
'rois'
0
'rect'
self
'x'
'x_offset'
'y'
'y_offset'
'w'
'width'
'b'
'y_offset'
'height'
'c'
'rois'
0
'confidence'
def
keep
self
try
while
True
"all"
141
0
141
680
"red"
1
689
0
689
680
"red"
1
if
self
for
in
self
int
'b'
1.42
int
'x'
1.41
int
'x'
'w'
1.41
"red"
2
if
self
for
in
self
int
'x'
1.41
int
'y'
1.41
int
'x'
'w'
1.41
int
'b'
1.42
"blue"
1
int
2
20
if
20
0
else
0
"conf:{:.2f}"
format
'c'
'cyan'
除了画出障碍物作为辅助,还增加了按键来辅助任务切换和调用 API。
def keyboard_thread ():
while True :
sleep(0.05 )
if keyboard.is_pressed('b' ) or keyboard.is_pressed('B' ):
self .sign4return_pub.publish(self .sign4return_data)
sleep(0.5 )
if keyboard.is_pressed('r' ) or keyboard.is_pressed('R' ):
self .sign4return_data['data' ] = 5
self .sign4return_pub.publish(self .sign4return_data)
self .sign4return_data['data' ] = 0
sleep(0.5 )
if keyboard.is_pressed('p' ) or keyboard.is_pressed('P' ):
self .sign4return_data['data' ] = 6
self .sign4return_pub.publish(self .sign4return_data)
self .sign4return_data['data' ] = 0
sleep(0.5 )
if keyboard.is_pressed('j' ) or keyboard.is_pressed('J' ):
self .llm_data['data' ] = 1
self .llm_pub.publish(self .llm_data)
sleep(1 )
半场扫码 小车前部的 USB 相机拍摄的照片清晰度较低。深度相机具有优势,不仅能获取深度信息,照片也比 USB 相机清晰得多。
图 3 深度相机扫码
图 4 USB 相机扫码
扫码节点不应一直开启,特别耗 CPU。扫码条件是:任务状态为任务一,且小车过了半场(全局坐标的 x 超过 2m)。
import rclpy
from rclpy.node import Node
import cv2
import numpy as np
from sensor_msgs.msg import Image
from std_msgs.msg import String, Int32
from nav_msgs.msg import Odometry
from origincar_msg.msg import Sign
from cv_bridge import CvBridge
TASK1 = 1
TASK2_WAITFOR_CMD = 2
TASK2 = 3
TASK3 = 4
TASK_STOP = 5
class QrCodeDetection (Node ):
def __init__ (self ):
super ().__init__('QRcodeSub' )
self .Sign4ReturnSub = self .create_subscription(Int32, 'sign4return' , self .sign4return_callback, 10 )
self .ImageSub = self .create_subscription(Image, '/aurora/rgb/image_raw' , self .image_callback, 10 )
self .OdomSub = self .create_subscription(Odometry, '/odom_combined' , self .Odom_callback, 10 )
self .qrcode_publisher = self .create_publisher(String, "/qrcode_information" , 10 )
self .info_result = String()
self .sign_publisher = self .create_publisher(Sign, '/sign_switch' , 10 )
self .sign_msg = Sign()
self .detector = cv2.wechat_qrcode_WeChatQRCode(
"/userdata/WorkSpace/codes/src/qrcode/qrcode/model/detect.prototxt" ,
"/userdata/WorkSpace/codes/src/qrcode/qrcode/model/detect.caffemodel" ,
"/userdata/WorkSpace/codes/src/qrcode/qrcode/model/sr.prototxt" ,
"/userdata/WorkSpace/codes/src/qrcode/qrcode/model/sr.caffemodel"
)
self .bridge = CvBridge()
self .node_run = False
self .task = TASK1
def image_callback (self, msg ):
if self .node_run and (self .task == TASK1 or self .task == TASK2):
cv2_image = self .bridge.imgmsg_to_cv2(msg, desired_encoding='mono8' )[155 :,:]
res = self .detector.detectAndDecode(cv2_image)[0 ]
if res:
self .node_run = False
for r in res:
self .info_result.data = str (r)
self .qrcode_publisher.publish(self .info_result)
self .get_logger().info("{}" .format (self .info_result.data))
if self .info_result.data == "AntiClockWise" :
self .sign_msg.sign_data = 4
elif self .info_result.data == "ClockWise" :
self .sign_msg.sign_data = 3
else :
try :
data = int (r)
if data % 2 :
self .sign_msg.sign_data = 3
else :
self .sign_msg.sign_data = 4
except :
pass
self .sign_publisher.publish(self .sign_msg)
self .info_result.data = "None"
self .sign_msg.sign_data = 0
else :
return
def sign4return_callback (self, msg ):
if msg.data == 0 or msg.data == -1 :
self .task = TASK1
self .node_run = False
if msg.data == 5 :
self .task = TASK2
elif msg.data == 6 :
self .task = TASK3
def Odom_callback (self, msg ):
if self .task == TASK1 and msg.pose.pose.position.x > 2 :
self .node_run = True
if __name__ == '__main__' :
rclpy.init(args=None )
qrCodeDetection = QrCodeDetection()
while rclpy.ok():
rclpy.spin(qrCodeDetection)
qrCodeDetection.destroy_node()
rclpy.shutdown()
我们没有把整张照片丢给扫码模型,而是裁掉了一部分,例如去掉红线上面的区域,这样图像变小很多。
图 5 裁掉一部分图片
准确返回 P 点 准确返回 P 点的思路有 3 个,其中一个是另一队分享的。他们的思路是任务二停车对准 P 点,然后退出遥操作,利用 YOLO 识别 P 点回去,但这很看操作水平。
思路 1——使用地图的固定元素来校准 这个思路要重置里程计,每次都在通道重置,相当于把原点设在了这里,计算出来的相对坐标可直接用作全局坐标。
如果小车每次从任务二出来都停在同一个位置,然后重置里程计,一定会有一个终点可以让小车回去。正式比赛时,很难做到每次都停得这么准确,所以得想办法校准,用地图的固定元素。
通道出来一定有可以识别到的线。线在地图的位置绝对是固定的,接下来的操作就是在线上了。
有了线,先求线的相对位置(相对于小车的)。求线的相对位置可以参考逆透视变换相关文章。
先把车固定在一个位置,用小橙的 USB 相机拍一张照片,然后给终点让小橙跑过去,多试几次,每次都要放回原来的位置跑过去,小橙正好回去的点就是关键坐标。
图 6 固定小橙的位置,终点是 (1.9m, -1.5m)
有了这个点,接下来就是用逆透视变换求前面 2 根线的相对位置(必须是最近的 2 根)。
这样,我们就有了 3 个点的坐标,分别是 A,B,P。有了这 3 个点,我们再任意摆放车,再用逆透视变换计算视角下最近的 2 根线的相对位置。这样就又有了 2 个点,分别是 A',B',一共有了 5 个点,现在就是用这 5 个点来计算 P'。
问题简化一下:在原来坐标系下,我知道 A,B,P 3 个点。原坐标系经过变换之后,我可以知道 A',B',接下来我想求 P'。
要求 P',那就得先知道前后 2 个坐标系是怎么变换的,也就是前后 2 个坐标系之间的旋转矩阵 R 和平移变量 t。
利用已知的两个对应点对 (A, A') 和 (B, B') 来求解 R 和 t。
对于点 A 和 A':
A′ = RA + t (1)
对于点 B 和 B':
B′ = RB + t (2)
将方程 (1) 和 (2) 相减:
A′ − B′ = R(A − B)
令:
ΔAB = A − B
ΔA′B′ = A′ − B′
因此,旋转矩阵 R 可以这样求解:
R = ΔA′B′ ΔAB^(-1)
一旦得到 R,可以通过任一对点求解 t。比如,用 A 和 A':
t = A′ − RA
最后,对于 P,其在新坐标系中的坐标 P′ 为:
P′ = RP + t
这就是校准的所有步骤了,用 2 根线的坐标去校准 P 点的坐标。
大家把 C 换成 P 来看就可以了。
最后计算出来这个 P 点是非常准确的,用 python 写的代码(有些参数是固定的,大家可以先离线计算):
def end_point (x1, y1, x2, y2, x3, y3, x1_, y1_, x2_, y2_ ):
delta_x = x1 - x2
delta_y = y1 - y2
delta_x_ = x1_ - x2_
delta_y_ = y1_ - y2_
den = delta_x ** 2 + delta_y ** 2
a = (delta_x * delta_x_ + delta_y * delta_y_) / den
b = (delta_x * delta_y_ - delta_y * delta_x_) / den
tx = x1_ - a * x1 + b * y1
ty = y1_ - b * x1 - a * y1
x3_ = a * x3 - b * y3 + tx
y3_ = b * x3 + a * y3 + ty
print (f"(x1, y1): ({x1, y1} ), (x2, y2): ({x2, y2} ), (x3, y3): ({x3, y3} ) delta x: {delta_x} , delta y: {delta_y} , den: {den} " )
return x3_, y3_
print (f"end': {end_point(ptx1, pty1, ptx2, pty2, 1.9 , -1.5 , ptx3, pty3, ptx4, pty4)} " )
思路 2——不重置里程计,使用 YOLO 识别 P 点结果来校正终点 我们比赛时候使用的就是这个思路,上一个思路要在通道处停一下,这个是不需要停下来,直接冲出去就行了。这个思路比上一个简单很多,能回去的概率非常大。
这个思路是,使用 YOLO 识别 P 点,然后还是用逆透视变换计算 P 点相对坐标,再通过小车的坐标计算这个 P 点的全局坐标。
H = np.array([
[-4.66389128e-04 , -2.26288030e-04 , -4.92300831e-02 ],
[7.59821540e-04 , 5.20569143e-05 , -2.33074608e-01 ],
[-6.59643252e-04 , -7.15022786e-03 , 1.00000000e+00 ],
])
def pixel2global (self, pixel_x, pixel_y ):
pixel = np.array([pixel_x, pixel_y, 1 ], dtype=np.float32)
local = np.dot(H, pixel)
local /= local[2 ]
local[0 ] += 0.25
car_cos = np.cos(self .current_pos[2 ])
car_sin = np.sin(self .current_pos[2 ])
global_x = self .current_pos[0 ] + car_cos * local[0 ] - car_sin * local[1 ]
global_y = self .current_pos[1 ] + car_sin * local[0 ] + car_cos * local[1 ]
return global_x, global_y
将 P 点中心的像素坐标丢到这个函数里面,输出的就是校正之后的终点了(车头朝向为 x+,车左边为 y+)。因为全程不重置里程计,所以未校正的终点设为出发的原点 (0.5m, 0.2m)。校正之前,小车的终点就是这个。
这种思路对 YOLO 的要求比较高,所以必须采集非常多的数据,我们在实验室采了 7k,数据增强之后就有了差不多 2w2,训练出来的效果很好。最后在比赛现场采了 1k4,增强到 7k,一共 2w9 张,加急训练了 22 小时。其中数据集有超过一半的是 P 点的。
采集数据的时候,这种只有一个小角或被桶挡住了或比较远看起来很扁的也贴上(凡是小车可以看到 P 点,哪怕是一点点也要标上,而且尽量采集多一些),这样训练出来也是可以识别到的,回去更轻松。不过要注意一下任务三出来的位置,可能会误识成 P 点,所以在任务三出来(包括白色网格那里)也要采集一点,不需要太多。
修改 STM32 源码 stm32 源码我基本上都看过了,也找到了很多可以修改的地方。
最重要的是舵机转角。我找到了舵机转角的限制,但是转角对不对称跟有没有限制是没关系的,如果去掉限制,反而有可能会把前轮给跑坏。
所以,问题应该是出在了计算上面,计算舵机转角的地方也就只有这个多项式了。
右前轮的转向角度的限幅是(-0.49,0.32)(0.32 是左转最大的角度,-0.49 是右转最大角度),说明右前轮的角度在 -0.49 和 0.32 的时候,小车的舵机量应该是相同的。但是,直接把这 2 个值带入原来的多项式中,得到的舵机量是这样的:
这就说明小车左转转角小,右转转角大。用上位机跑的时候,确实是左转小,右转大。
调整多项式的二次项系数之后,大概让左右转的舵机量相同。
经过这样的调整之后,小车左转和右转是差不多的角度了。在上位机跑起来,左右转是对称的。
为了提高惯导的计算频率,我把串口发送的频率提高到了 50Hz,串口波特率提高到了 921600,而且也把其他没用到的外设(比如 CAN,蓝牙)给关了,只留下串口 1 和串口 3。
这里为什么用串口 1 呢?因为我们小橙底板的串口 3 有问题,会出现断连,干脆直接换成串口 1 了,也就是烧录口。(不知道为什么,把串口 3 关掉之后,oled 会卡住,只能打开它了)
在 X5 上面,找到 /root/dev_ws/origincar/origincar_base/launch/base_serial.launch.py,波特率改成 921600,再编译一下就可以正常通信了(不用管 clear_flag)。
因为提高串口发送的频率到 50Hz,所以,小车上 EKF 的计算频率也要设为 50Hz。
我们实际用下来,发现小车的 odom_combined 挺准的,一圈下来 x 和 y 偏得都不算太大。我们也改了一下 ekf.yaml,大家可以参考一下我们改的地方。
原来的/imu/data_raw 是原始数据,/imu/data 是经过滤波之后的。
除了删掉一些外设,我们还把电位器决定车型号的部分也删掉了,固定车型号为 Ackerman;oled 刷屏显示也把跟 Ackerman 无关的给删掉了。
补充 忘记把处理数据的代码放上来了,就在这里补充一下吧。
我们贴标签不是傻乎乎地全部自己来贴。我们先用以前训练过的模型贴一遍,然后再人工检查一遍。这样子操作,一个人不用一天时间就可以贴差不多 6k。
除此之外,我们还写了删除无效图片和无效标签的脚本(图片没有对应同名的标签或者标签没有对应同名的图片)、数据增强的脚本(没有旋转)和将数据分批次让队友来帮忙的脚本。
附
这份是让模型贴标签的: import argparse
import os
import shutil
import time
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
import cv2
from models.experimental import attempt_load
from utils.datasets import LoadImages
from utils.utils import non_max_suppression, scale_coords, xyxy2xywh
from utils.torch_utils import select_device, time_synchronized
def auto_annotate (source, weights, output, img_size=640 , conf_thres=0.25 , iou_thres=0.45 , view_img=False ):
device = select_device(device)
half = device.type != 'cpu'
model = attempt_load(weights, map_location=device)
imgsz = img_size
if half:
model.half()
names = model.module.names if hasattr (model, 'module' ) else model.names
dataset = LoadImages(source, img_size=imgsz)
t0 = time.time()
img = torch.zeros((1 , 3 , imgsz, imgsz), device=device)
_ = model(img.half() if half else img)
if device.type != 'cpu' else None
for path, img, im0s, _ in dataset:
img = torch.from_numpy(img).to(device)
img = img.half() if half else img.float ()
img /= 255.0
if img.ndimension() == 3 :
img = img.unsqueeze(0 )
t1 = time_synchronized()
pred = model(img, augment=False )[0 ]
pred = non_max_suppression(pred, conf_thres, iou_thres, classes=None , agnostic=False )
t2 = time_synchronized()
p, im0 = path, im0s.copy()
txt_path = str (Path(output) / Path(p).stem) + ('.txt' )
open (txt_path, 'w' ).close()
whwh gn = torch.tensor(im0.shape)[[1 , 0 , 1 , 0 ]]
if pred is not None :
for i, det in enumerate (pred):
if det is not None and len (det):
det[:, :4 ] = scale_coords(img.shape[2 :], det[:, :4 ], im0.shape).round ()
with open (txt_path, 'w' ) as f:
if det is not None and len (det):
for *xyxy, conf, cls in reversed (det):
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1 , 4 )) / gn).view(-1 ).tolist()
line = "%d %.6f %.6f %.6f %.6f" % (cls, *xywh)
f.write(line + "\n" )
else :
f.write("" )
print (f'{Path(p).name} done. ({t2 - t1:.3 f} s)' )
if view_img:
cv2.imshow(Path(p).name, im0)
if cv2.waitKey(1 ) == ord ('q' ):
raise StopIteration
print (f'Done. ({time.time() - t0:.3 f} s)' )
if __name__ == '__main__' :
parser = argparse.ArgumentParser()
parser.add_argument('--source' , type =str , default='dataset_process/new1/images' , help ='输入图像文件夹路径' )
parser.add_argument('--weights' , type =str , default='runs/2025.7.28/weights/last.pt' , help ='模型权重路径' )
parser.add_argument('--output' , type =str , default='dataset_process/new1/labels' , help ='输出标签路径' )
parser.add_argument('--img-size' , type =int , default=640 , help ='推理尺寸 (像素)' )
parser.add_argument('--conf-thres' , type =float , default=0.25 , help ='目标置信度阈值' )
parser.add_argument('--iou-thres' , type =float , default=0.45 , help ='NMS 的 IOU 阈值' )
parser.add_argument('--device' , help ='cuda 设备,如 0 或 0,1,2,3 或 cpu' )
parser.add_argument('--view-img' , action='store_true' , help ='显示结果' )
opt = parser.parse_args()
print (opt)
with torch.no_grad():
auto_annotate(
source=opt.source,
weights=opt.weights,
output=opt.output,
img_size=opt.img_size,
conf_thres=opt.conf_thres,
iou_thres=opt.iou_thres,
device=opt.device,
view_img=opt.view_img
)
这份是删除无效数据的: import os
from pathlib import Path
def remove_invalid_images_labels (image_dir, label_dir ):
deleted_images = 0
deleted_labels = 0
for image_file in os.listdir(image_dir):
if image_file.lower().endswith(('.jpg' , '.png' , '.jpeg' )):
image_path = os.path.join(image_dir, image_file)
label_path = os.path.join(label_dir, Path(image_file).stem + '.txt' )
if not os.path.exists(label_path):
os.remove(image_path)
deleted_images += 1
print (f"删除图片(无标签): {image_file} " )
else :
with open (label_path, 'r' ) as f:
content = f.read().strip()
if not content:
os.remove(image_path)
os.remove(label_path)
deleted_images += 1
deleted_labels += 1
print (f"删除无效数据:{image_file} 和对应标签" )
print (f"\n操作完成!共删除:{deleted_images} 张图片,{deleted_labels} 个标签" )
if __name__ == '__main__' :
image_dir = os.path.join(os.path.dirname(__file__), "new1/images/" )
label_dir = os.path.join(os.path.dirname(__file__), "new1/labels/" )
confirm = input ("是否继续?(y/n): " ).lower()
if confirm == 'y' :
remove_invalid_images_labels(image_dir, label_dir)
else :
print ("操作已取消" )
这份是数据增强的: import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from pathlib import Path
import shutil
from PIL import Image
import random
from multiprocessing import Pool
import os
class YOLOAugment :
def __init__ (self, output_dir ):
self .output_dir = output_dir
Path(f"{output_dir} /images" ).mkdir(parents=True , exist_ok=True )
Path(f"{output_dir} /labels" ).mkdir(parents=True , exist_ok=True )
self .img_augment = T.Compose([
T.ColorJitter(brightness=0.3 , contrast=0.3 , saturation=0.2 ),
T.GaussianBlur(kernel_size=(3 , 7 ))
])
def apply_augment (self, img_path, label_path, aug_id ):
img = Image.open (img_path).convert('RGB' )
with open (label_path) as f:
bboxes = [list (map (float , line.strip().split())) for line in f]
img_tensor = TF.to_tensor(img)
bboxes_tensor = torch.tensor(bboxes)
img_tensor = self .img_augment(img_tensor)
stem = Path(img_path).stem
self ._save_results(img_tensor, bboxes_tensor, stem, aug_id)
return img, bboxes
def _save_results (self, img_tensor, bboxes, stem, aug_id ):
aug_img = TF.to_pil_image(img_tensor)
aug_img.save(f"{self.output_dir} /images/{stem} _aug{aug_id} .jpg" )
with open (f"{self.output_dir} /labels/{stem} _aug{aug_id} .txt" , 'w' ) as f:
for bbox in bboxes.numpy():
line = " " .join(map (str , bbox))
f.write(line + '\n' )
def process_file (args ):
img_path, label_path, output_dir, aug_per_image = args
augmenter = YOLOAugment(output_dir)
for i in range (1 , aug_per_image + 1 ):
augmenter.apply_augment(img_path, label_path, i)
shutil.copy(img_path, f"{output_dir} /images/{Path(img_path).name} " )
shutil.copy(label_path, f"{output_dir} /labels/{Path(label_path).name} " )
if __name__ == "__main__" :
root_path = os.path.dirname(__file__)
input_dir = os.path.join(root_path, "new1" )
output_dir = os.path.join(root_path, "new1_aug" )
aug_per_image = 3
num_workers = 4
tasks = []
for img_file in Path(f"{input_dir} /images" ).glob("*.*" ):
if img_file.suffix.lower() in ('.jpg' , '.png' , '.jpeg' ):
label_file = Path(f"{input_dir} /labels/{img_file.stem} .txt" )
if label_file.exists():
tasks.append((str (img_file), str (label_file), output_dir, aug_per_image))
print (f"开始增强 {len (tasks)} 张图像..." )
with Pool(processes=num_workers) as pool:
pool.map (process_file, tasks)
orig_count = len (tasks)
aug_count = orig_count * aug_per_image
print (f"处理完成!\n" f"- 原始图像保留:{orig_count} 张\n" f"- 增强图像生成:{aug_count} 张\n" f"- 总数据量:{orig_count + aug_count} 张" )
这份是让队友打工的: import os
import zipfile
import math
from pathlib import Path
def create_task_packs (images_dir, labels_dir, output_dir, tasks=3 , label_txt=False ):
image_files = sorted ([f for f in os.listdir(images_dir) if f.endswith(('.jpg' , '.png' ))])
label_files = sorted ([f for f in os.listdir(labels_dir) if f.endswith('.txt' )])
image_stems = {Path(f).stem for f in image_files}
label_stems = {Path(f).stem for f in label_files}
unmatched = image_stems.symmetric_difference(label_stems)
if unmatched:
print (f"⚠️ 警告:发现 {len (unmatched)} 个不匹配文件(示例:{list (unmatched)[:3 ]} )" )
print ("建议先运行数据校验脚本修复不一致问题!" )
return
total_pairs = len (image_files)
pairs_per_task = math.ceil(total_pairs / tasks)
print (f"数据集统计:" )
print (f"- 图片数量:{len (image_files)} " )
print (f"- 标注数量:{len (label_files)} " )
print (f"- 将分成 {tasks} 个任务包,每个约 {pairs_per_task} 对数据\n" )
os.makedirs(output_dir, exist_ok=True )
for task_num in range (1 , tasks + 1 ):
start_idx = (task_num - 1 ) * pairs_per_task
end_idx = min (start_idx + pairs_per_task, total_pairs)
task_images = image_files[start_idx:end_idx]
task_labels = [Path(f).stem + '.txt' for f in task_images]
zip_path = os.path.join(output_dir, f"task_{task_num} .zip" )
print (f"创建任务包 {task_num} :" )
print (f"- 包含图片:{len (task_images)} 张" )
print (f"- 包含标注:{len (task_labels)} 个" )
print (f"- 保存到:{zip_path} " )
with zipfile.ZipFile(zip_path, 'w' , zipfile.ZIP_DEFLATED) as zipf:
for img in task_images:
img_path = os.path.join(images_dir, img)
zipf.write(img_path, f"images/{img} " )
for label in task_labels:
label_path = os.path.join(labels_dir, label)
if os.path.exists(label_path):
zipf.write(label_path, f"labels/{label} " )
else :
print (f"⚠️ 缺失标注文件:{label} " )
if label_txt is not False :
label_info = Path(label_txt).open ("r" ).read()
zipf.writestr(f"labels/labels.txt" , label_info)
print ("-" * 50 )
print (f"\n🎉 任务包创建完成!共生成 {tasks} 个压缩包,保存在:{output_dir} " )
if __name__ == "__main__" :
root_path = os.path.dirname(__file__)
dataset_dir = os.path.join(root_path, "new1" )
output_dir = os.path.join(root_path, "package" )
label_txt = os.path.join(root_path, "labels.txt" )
num_tasks = 4
create_task_packs(
images_dir=os.path.join(dataset_dir, "images" ),
labels_dir=os.path.join(dataset_dir, "labels" ),
output_dir=output_dir,
tasks=num_tasks,
)
希望这份分享能帮助大家在接下来的比赛中取得好成绩。
相关免费在线工具 加密/解密文本 使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
RSA密钥对生成器 生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
Mermaid 预览与可视化编辑 基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
随机西班牙地址生成器 随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online
Gemini 图片去水印 基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online
curl 转代码 解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online