import json
import os
import shutil
import random
# Root of the COCO 2017 download; the script reads "annotations/" and the
# "train2017"/"val2017" image folders beneath it.
COCO_ROOT = "./coco2017"
# Root of the generated YOLO-style dataset (images/ and labels/ sub-trees).
OUTPUT_ROOT = "./robot_dataset"
# Class name -> COCO category id. Insertion order matters: coco2yolo enumerates
# the values to assign YOLO class indices, so it must match the `names` list
# written to the dataset YAML.
# NOTE: COCO's id for "bench" is 15; id 13 is "stop sign".
TARGET_CATS = {"person": 1, "chair": 62, "dining table": 67, "bench": 15, "bottle": 44}
# Upper bounds on how many images are sampled for each split.
TRAIN_NUM = 7000
VAL_NUM = 1000

# Create the output directory layout up front; exist_ok makes re-runs harmless.
for _subdir in ("images/train", "images/val", "labels/train", "labels/val"):
    os.makedirs(os.path.join(OUTPUT_ROOT, _subdir), exist_ok=True)
def coco2yolo(anno_file, img_dir, output_img_dir, output_label_dir, target_cats, max_num, seed=None):
    """Convert COCO annotations to YOLO label files for a sampled image subset.

    Only annotations whose category id appears in ``target_cats`` are kept.
    Images that end up with at least one usable box are candidates; up to
    ``max_num`` of them are sampled at random, copied to ``output_img_dir``
    and given a matching ``<image stem>.txt`` file in ``output_label_dir``.
    Each label line is ``class x_center y_center width height`` with the box
    values normalised by the image size.

    Args:
        anno_file: Path to the COCO annotation JSON file.
        img_dir: Directory holding the COCO images referenced by the JSON.
        output_img_dir: Directory the selected images are copied into.
        output_label_dir: Directory the YOLO label files are written into.
        target_cats: Mapping ``{class name: COCO category id}``. The YOLO
            class index of a category is its position in this mapping.
        max_num: Maximum number of images to select.
        seed: Optional seed for a reproducible selection. When ``None`` the
            module-level ``random`` state is used.

    Raises:
        FileNotFoundError: If the annotation file or a selected image is
            missing.
    """
    with open(anno_file, "r", encoding="utf-8") as f:
        coco_data = json.load(f)

    img_id2info = {
        img["id"]: (img["file_name"], img["width"], img["height"])
        for img in coco_data["images"]
    }
    coco_id2yolo_idx = {coco_id: idx for idx, coco_id in enumerate(target_cats.values())}

    img_anno = {}
    for ann in coco_data["annotations"]:
        yolo_idx = coco_id2yolo_idx.get(ann["category_id"])
        if yolo_idx is None:
            continue
        # Crowd annotations outline a group of objects, not a single instance,
        # so they would be misleading as ordinary detection boxes.
        if ann.get("iscrowd", 0):
            continue
        img_id = ann["image_id"]
        info = img_id2info.get(img_id)
        if info is None:
            # Annotation points at an image that is not listed in the file.
            continue
        _, img_w, img_h = info
        if img_w <= 0 or img_h <= 0:
            continue

        # COCO boxes are [x_min, y_min, width, height] in pixels. Clip them to
        # the image so every normalised value stays within [0, 1].
        x, y, w, h = ann["bbox"]
        x_min = max(0.0, x)
        y_min = max(0.0, y)
        x_max = min(float(img_w), x + w)
        y_max = min(float(img_h), y + h)
        if x_max <= x_min or y_max <= y_min:
            # Degenerate (zero-area or fully outside) box.
            continue

        x_center = (x_min + x_max) / 2 / img_w
        y_center = (y_min + y_max) / 2 / img_h
        w_norm = (x_max - x_min) / img_w
        h_norm = (y_max - y_min) / img_h
        # Register the image only once it has a usable box, so no empty
        # entries can be sampled.
        img_anno.setdefault(img_id, []).append(
            f"{yolo_idx} {x_center:.6f} {y_center:.6f} {w_norm:.6f} {h_norm:.6f}"
        )

    candidates = list(img_anno)
    rng = random if seed is None else random.Random(seed)
    selected_img_ids = rng.sample(candidates, min(max_num, len(candidates)))
    print(f"筛选出{len(selected_img_ids)}张样本")

    # Do not rely on the caller having created the output directories.
    os.makedirs(output_img_dir, exist_ok=True)
    os.makedirs(output_label_dir, exist_ok=True)
    for img_id in selected_img_ids:
        img_name = img_id2info[img_id][0]
        shutil.copy(os.path.join(img_dir, img_name), os.path.join(output_img_dir, img_name))
        # splitext swaps only the real extension, whatever it is.
        label_name = os.path.splitext(img_name)[0] + ".txt"
        with open(os.path.join(output_label_dir, label_name), "w", encoding="utf-8") as f:
            f.write("\n".join(img_anno[img_id]))
# Run the COCO -> YOLO conversion once per split. Each entry pairs the progress
# message with the keyword arguments handed to coco2yolo.
conversion_jobs = [
    (
        "开始处理训练集...",
        {
            "anno_file": os.path.join(COCO_ROOT, "annotations/instances_train2017.json"),
            "img_dir": os.path.join(COCO_ROOT, "train2017"),
            "output_img_dir": os.path.join(OUTPUT_ROOT, "images/train"),
            "output_label_dir": os.path.join(OUTPUT_ROOT, "labels/train"),
            "target_cats": TARGET_CATS,
            "max_num": TRAIN_NUM,
        },
    ),
    (
        "开始处理验证集...",
        {
            "anno_file": os.path.join(COCO_ROOT, "annotations/instances_val2017.json"),
            "img_dir": os.path.join(COCO_ROOT, "val2017"),
            "output_img_dir": os.path.join(OUTPUT_ROOT, "images/val"),
            "output_label_dir": os.path.join(OUTPUT_ROOT, "labels/val"),
            "target_cats": TARGET_CATS,
            "max_num": VAL_NUM,
        },
    ),
]
for banner, job in conversion_jobs:
    print(banner)
    coco2yolo(**job)

# The reported total is the sum of the requested limits, not a count of the
# files actually written.
print(f"机器人场景数据集整理完成,保存至{OUTPUT_ROOT},共{TRAIN_NUM+VAL_NUM}张样本")
from ultralytics.data.utils import check_dataset
import os
import random
# Write the dataset description YAML next to the generated images/labels.
robot_yaml = "./robot_dataset/robot_dataset.yaml"
dataset_root = os.path.abspath("./robot_dataset")
yaml_text = f"""
path: {dataset_root} # 数据集绝对路径
train: images/train
val: images/val
nc: 5 # 类别数
names: ['person', 'chair', 'dining table', 'bench', 'bottle'] # 类别名称,与 YOLO 索引一致
"""
with open(robot_yaml, "w", encoding="utf-8") as f:
    f.write(yaml_text)

# Hand the YAML to Ultralytics' check_dataset; the success message is reached
# only if that call returns without raising.
check_dataset(robot_yaml)
print("数据集格式验证通过!")

# Spot check: print the contents of one randomly chosen training label file.
label_dir = "./robot_dataset/labels/train"
label_files = os.listdir(label_dir)
random_label = random.choice(label_files)
label_path = os.path.join(label_dir, random_label)
with open(label_path, "r", encoding="utf-8") as f:
    anno_content = f.read()
print(f"随机抽查标注文件{random_label}内容:\n{anno_content}")