One-Click Training/Prediction for the RTDETR Model (Running train.sh and detect.sh)

Introduction

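This article grew out of a client's request for one-click training and testing. I took the RTDETR model integrated into Ultralytics' YOLOv8 and wrapped it in a step-by-step, one-click training/prediction workflow, which is also well suited to beginners and to anyone who would rather not convert data formats by hand. The workflow takes images and XML annotation files as input; running train.sh and detect.sh carries out model training and prediction. To make this possible, the project includes XML-to-RTDETR txt conversion, automatic train/validation splitting, and automatic environment switching. Below I explain how to use it and include the modified source code.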

Source code: I have uploaded it to my personal resources; please download it from there.

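I. Configuration Parameters

This file holds the RTDETR data-conversion settings and the model parameters. I modified it to serve as the configuration for the one-click training and testing scripts. It covers converting image and XML data into the format used for training: you only need to provide the XML and image folders, and the conversion is done for you. The details are as follows: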

# Paths to the img and xml files; they may be the same folder. Images are selected according to the xml files.
img_path: C:/Users/Administrator/Desktop/rtdetr/example_template/data
xml_path: C:/Users/Administrator/Desktop/rtdetr/example_template/data
# Train/val/test ratios; their sum must not exceed 1. The test ratio is usually left unset (treated as 0).
train_rate: 0.8
val_rate: 0.2
test_rate:
path: C:/Users/Administrator/Desktop/rtdetr/example_template/rtdert_data  # required: folder where the converted dataset is stored
train: images/train  # do not change
val: images/val  # do not change
test:
# Classes
names:
  0: person
  1: bicycle
  2: car
  3: motorcycle
  4: airplane
  5: bus
  6: train
  7: truck
  8: boat
  9: traffic light
  10: fire hydrant
  11: stop sign
  12: parking meter
  13: bench
  14: bird
  15: cat
  16: dog
  17: horse
  18: sheep
  19: cow
  20: elephant
  21: bear
  22: zebra
  23: giraffe
  24: backpack
  25: umbrella
  26: handbag
  27: tie
  28: suitcase
  29: frisbee
  30: skis
  31: snowboard
  32: sports ball
  33: kite
  34: baseball bat
  35: baseball glove
  36: skateboard
  37: surfboard
  38: tennis racket
  39: bottle
  40: wine glass
  41: cup
  42: fork
  43: knife
  44: spoon
  45: bowl
  46: banana
  47: apple
  48: sandwich
  49: orange
  50: broccoli
  51: carrot
  52: hot dog
  53: pizza
  54: donut
  55: cake
  56: chair
  57: couch
  58: potted plant
  59: bed
  60: dining table
  61: toilet
  62: tv
  63: laptop
  64: mouse
  65: remote
  66: keyboard
  67: cell phone
  68: microwave
  69: oven
  70: toaster
  71: sink
  72: refrigerator
  73: book
  74: clock
  75: vase
  76: scissors
  77: teddy bear
  78: hair drier
  79: toothbrush
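To make the effect of train_rate, val_rate, and test_rate concrete, here is a minimal sketch of how the ratios turn into sample counts. The function name split_counts and the sample size of 100 are assumptions for illustration only. The fallback values (0.7, 0.3, 0) follow the defaults used in the conversion code, though the actual script may round or validate slightly differently.

```python
def split_counts(n_samples, train_rate=None, val_rate=None, test_rate=None):
    """Return (train, val, test) sample counts for the given rates.

    Illustrative helper, not part of the original project. An unset rate
    falls back to 0.7 (train), 0.3 (val), or 0 (test).
    """
    train_rate = train_rate if train_rate else 0.7
    val_rate = val_rate if val_rate else 0.3
    test_rate = test_rate if test_rate else 0.0
    if train_rate + val_rate + test_rate > 1:
        raise ValueError("train_rate + val_rate + test_rate must not exceed 1")
    # Cumulative boundaries are truncated to integers, then turned into counts
    train_end = int(n_samples * train_rate)
    val_end = int(n_samples * train_rate + n_samples * val_rate)
    test_end = int(n_samples * train_rate + n_samples * val_rate
                   + n_samples * test_rate)
    return train_end, val_end - train_end, test_end - val_end


# Hypothetical dataset of 100 annotated images with the rates from the config
print(split_counts(100, train_rate=0.8, val_rate=0.2, test_rate=None))
# → (80, 20, 0)
```

With test_rate left empty, as in the configuration file, no samples go to the test split.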

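II. Data Format Conversion Code

The code in this file converts XML annotations into the format the RTDETR model needs. It is mostly straightforward logic and fairly basic code, so I will not walk through it. The code is as follows: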

import os
import random
import shutil
import sys
import xml.etree.ElementTree as ET
from pathlib import Path
from xml.dom.minidom import parseString

import cv2
import yaml
from lxml.etree import Element, SubElement, tostring, ElementTree
from tqdm import tqdm

img_format = ['.jpg', '.png', '.bmp']


def build_dir(root):
    # Create the directory if it does not exist yet
    if not os.path.exists(root):
        os.makedirs(root)
    return root


def del_dir(root):
    # Remove the directory and everything under it
    if os.path.exists(root):
        shutil.rmtree(root)
    return root


############################################ Generate XML ##########################
def product_xml(name_img, boxes, codes, img=None, wh=None):
    '''
    :param img: an image that has already been read
    :param name_img: image file name, e.g. 'xxx.jpg'
    :param boxes: list of boxes [x1, y1, x2, y2]
    :param codes: list of class names, one per box
    :param wh: [width, height], used when img is not given
    :return: (tree, xml file name)
    '''
    if img is not None:
        height, width = img.shape[:2]  # OpenCV images are (height, width, channels)
    else:
        assert wh is not None
        width, height = wh[0], wh[1]
    node_root = Element('annotation')
    node_folder = SubElement(node_root, 'folder')
    node_folder.text = 'VOC2007'
    node_filename = SubElement(node_root, 'filename')
    node_filename.text = name_img  # image file name
    node_size = SubElement(node_root, 'size')
    node_width = SubElement(node_size, 'width')
    node_width.text = str(width)
    node_height = SubElement(node_size, 'height')
    node_height.text = str(height)
    node_depth = SubElement(node_size, 'depth')
    node_depth.text = '3'
    for i, code in enumerate(codes):
        box = [boxes[i][0], boxes[i][1], boxes[i][2], boxes[i][3]]
        node_object = SubElement(node_root, 'object')
        node_name = SubElement(node_object, 'name')
        node_name.text = code
        node_difficult = SubElement(node_object, 'difficult')
        node_difficult.text = '0'
        node_bndbox = SubElement(node_object, 'bndbox')
        node_xmin = SubElement(node_bndbox, 'xmin')
        node_xmin.text = str(int(box[0]))
        node_ymin = SubElement(node_bndbox, 'ymin')
        node_ymin.text = str(int(box[1]))
        node_xmax = SubElement(node_bndbox, 'xmax')
        node_xmax.text = str(int(box[2]))
        node_ymax = SubElement(node_bndbox, 'ymax')
        node_ymax.text = str(int(box[3]))
    xml = tostring(node_root, pretty_print=True)  # pretty-print with line breaks
    dom = parseString(xml)
    name = name_img[:-4] + '.xml'
    tree = ElementTree(node_root)
    print('name:{},dom:{}'.format(name, dom))
    return tree, name


def product_xml_demo():
    '''
    Generate an xml file for an image from box and category information
    '''
    img_root = r'C:\Users\Administrator\Desktop\123\1.jpg'
    write_img_name = 'hhhaaa.jpg'
    bboxes_lst = [[22, 32, 46, 89]]
    cat_lst = ['cat']
    img = cv2.imread(img_root)
    tree, xml_name = product_xml(write_img_name, bboxes_lst, cat_lst, img=img)
    tree.write(os.path.join('./', xml_name))


############################################ xml to YOLO txt ##########################
def read_xml(xml_root):
    '''
    :param xml_root: path to an .xml file
    :return: dict('cat':['cat1',...],'bboxes':[[x1,y1,x2,y2],...],'whd':[[w,h,d]])
    '''
    dict_info = {'cat': [], 'bboxes': [], 'box_wh': [], 'whd': []}
    if os.path.splitext(xml_root)[-1] == '.xml':
        tree = ET.parse(xml_root)  # parse the xml file
        root = tree.getroot()  # get the root node
        whd = root.find('size')
        whd = [whd.find('width').text, whd.find('height').text, whd.find('depth').text]
        for obj in root.findall('object'):  # every "object" node under the root
            cat = str(obj.find('name').text)  # value of the name child node (string)
            bbox = obj.find('bndbox')
            x1, y1, x2, y2 = [int(bbox.find('xmin').text),
                              int(bbox.find('ymin').text),
                              int(bbox.find('xmax').text),
                              int(bbox.find('ymax').text)]
            b_w = x2 - x1 + 1
            b_h = y2 - y1 + 1
            dict_info['cat'].append(cat)
            dict_info['bboxes'].append([x1, y1, x2, y2])
            dict_info['box_wh'].append([b_w, b_h])
        dict_info['whd'].append(whd)
    else:
        print('[inexistence]:{} suffix is not xml '.format(xml_root))
    return dict_info


def write_txt(text_lst, out_txt=None):
    '''
    Write each list item to the text file as one line
    '''
    out_dir = out_txt if out_txt is not None else 'classes.txt'
    # Open for writing; the file is created if it does not exist
    with open(out_dir, 'w', encoding='utf-8') as file_write_obj:
        for text in text_lst:
            file_write_obj.write(str(text))
            file_write_obj.write('\n')


def xml2yolotxt(xml_root, img_root=None, save_txt=None, labels_name_lst=None):
    '''
    :param xml_root: path to the xml file
    :param img_root: path to the image; optional, used to get the image height and width
    :param save_txt: path of the txt file to write
    :param labels_name_lst: list of training labels matched against the classes in the xml,
           e.g. ['pedes', 'elec', 'car', 'truck', 'bus', 'tricycle'],
           where pedes is 0, elec is 1, car is 2, and so on
    :return:
    '''
    if labels_name_lst is None:
        raise ValueError("lack labels list ")
    if save_txt is None:
        raise ValueError("lack saving root for txt file ")
    xml_info = read_xml(xml_root)
    if img_root is not None:  # take W and H from the image
        img = cv2.imread(img_root)
        H, W = img.shape[:2]
    else:
        whd = xml_info['whd'][0]
        W, H = float(whd[0]), float(whd[1])
    boxes_lst = xml_info['bboxes']
    labels_lst = xml_info['cat']
    yolotxt_lst = []
    for i, b in enumerate(boxes_lst):
        label = labels_lst[i]
        if label in labels_name_lst:
            label_idx = list(labels_name_lst).index(label)
            bw, bh = b[2] - b[0], b[3] - b[1]
            x, y = b[0] + bw / 2, b[1] + bh / 2
            x, y, w, h = x / W, y / H, bw / W, bh / H
            yolotxt = str(label_idx) + ' ' + str(x) + ' ' + str(y) + ' ' + str(w) + ' ' + str(h)
            yolotxt_lst.append(yolotxt)
    if len(yolotxt_lst) > 0:
        write_txt(yolotxt_lst, save_txt)


def convert_split(xml_names, xml_path, img_path, img_names, img_names_no_suffix,
                  img_out_path, label_out_path, labels_name_lst, problem_xmls):
    '''
    Convert one split (train/val/test): write a txt label for each xml and copy its image
    '''
    for xml_name in tqdm(xml_names):
        xml_root = os.path.join(xml_path, xml_name)
        if xml_name[:-4] in img_names_no_suffix:
            img_idx = img_names_no_suffix.index(xml_name[:-4])
            img_name = img_names[img_idx]
            img_root = os.path.join(img_path, img_name)
            save_txt = os.path.join(label_out_path, xml_name[:-3] + 'txt')
            try:
                xml2yolotxt(xml_root, img_root=img_root, save_txt=save_txt,
                            labels_name_lst=labels_name_lst)
            except Exception:
                # Record the problem file and move on to the next one
                problem_xmls.append(xml_root)
                continue
            shutil.copy(img_root, os.path.join(img_out_path, img_name))


def convert_data_train(xml_path, img_path, out_file_path, labels_name_lst, **kwargs):
    '''
    xml_path: path to the xml folder
    img_path: path to the image folder
    out_file_path: output folder used for YOLO model training
    labels_name_lst: label list; only these labels are converted and trained
    kwargs: other parameters
    '''
    print('\n convert data...')
    img_suffix = kwargs.get('img_suffix') if kwargs.get('img_suffix') else 4
    img_names = [name for name in os.listdir(img_path) if name[-4:] in img_format]
    img_names_no_suffix = [name[:-img_suffix] for name in img_names]
    xml_names = [name for name in os.listdir(xml_path) if name[-3:] == 'xml']
    random.shuffle(xml_names)
    N = len(xml_names)
    train_N = N * kwargs.get('train_rate') if kwargs.get('train_rate') else 0.7 * N
    val_N = N * kwargs.get('val_rate') if kwargs.get('val_rate') else 0.3 * N
    test_N = N * kwargs.get('test_rate') if kwargs.get('test_rate') else 0
    if (train_N / N + val_N / N + test_N / N) > 1:
        raise ValueError(
            "rate of datasets error,sum>1, train_rate:{}\tval_rate:{}\ttest_rate{}".format(
                train_N / N, val_N / N, test_N / N))
    # Build the training folders
    images_path = os.path.join(out_file_path, 'images')
    labels_path = os.path.join(out_file_path, 'labels')
    del_dir(images_path)
    del_dir(labels_path)
    build_dir(images_path)
    build_dir(labels_path)
    # (split name, start offset, sample count) for each split
    split_plan = [('train', 0, train_N),
                  ('val', train_N, val_N),
                  ('test', train_N + val_N, test_N)]
    problem_xmls = []
    for split, start, count in split_plan:
        img_out_path = build_dir(os.path.join(images_path, split))
        label_out_path = build_dir(os.path.join(labels_path, split))
        convert_split(xml_names[int(start):int(start + count)], xml_path, img_path,
                      img_names, img_names_no_suffix, img_out_path, label_out_path,
                      labels_name_lst, problem_xmls)
        print('\nfinishing convert of {0} data, {0}_rate:\t{1}\t {0} count:\t{2} \n'.format(
            split, count / N, int(count)))
    print('\n problem xml:{}\n'.format(len(problem_xmls)))
    for problem_path in problem_xmls:
        print(problem_path)


def product_yolo_dataset(yaml_path):
    with open(yaml_path, 'rb') as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)
    img_path = cfg['img_path']
    xml_path = cfg['xml_path']
    out_file_path = cfg['path']
    labels_name_lst = [v for k, v in cfg['names'].items()]
    kwargs = {"train_rate": cfg['train_rate'],
              "val_rate": cfg['val_rate'],
              "test_rate": cfg['test_rate']}
    convert_data_train(xml_path, img_path, out_file_path, labels_name_lst, **kwargs)
    return cfg


def yolo_dataset_demo():
    '''
    Convert xml data to YOLO format
    '''
    yaml_path = 'coco128_auto.yaml'
    product_yolo_dataset(yaml_path)


def read_yaml(yaml_path):
    with open(yaml_path, 'rb') as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)
    return cfg


def del_runsfile():
    FILE = Path(__file__).resolve()
    ROOT = FILE.parents[0]  # project root directory
    if str(ROOT) not in sys.path:
        sys.path.append(str(ROOT))  # add ROOT to PATH
    ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
    del_dir(ROOT / 'runs/detect/train')


if __name__ == '__main__':
    yolo_dataset_demo()
    del_runsfile()  # remove the old runs folder

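Note: this code needs only the image files and their matching XML files to split the data into train, val, and test sets according to the configured ratios.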
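The core of the conversion is turning a pixel-space box [xmin, ymin, xmax, ymax] into a normalized label line of the form "class x_center y_center width height". The standalone sketch below isolates that arithmetic. The helper names voc_box_to_yolo and to_label_line, the three-class label list, and the sample box are illustrative assumptions, not part of the original project.

```python
def voc_box_to_yolo(box, img_w, img_h):
    """Convert a pixel box [xmin, ymin, xmax, ymax] to normalized
    (x_center, y_center, width, height), each relative to the image size."""
    xmin, ymin, xmax, ymax = box
    bw, bh = xmax - xmin, ymax - ymin
    x_center = xmin + bw / 2
    y_center = ymin + bh / 2
    return x_center / img_w, y_center / img_h, bw / img_w, bh / img_h


def to_label_line(class_name, box, img_w, img_h, labels):
    """Build one label line: class index followed by the normalized box."""
    idx = labels.index(class_name)
    values = voc_box_to_yolo(box, img_w, img_h)
    return " ".join(str(v) for v in (idx, *values))


# Hypothetical example: a centered 320x240 box in a 640x480 image
labels = ["person", "bicycle", "car"]
print(to_label_line("car", [160, 120, 480, 360], 640, 480, labels))
# → 2 0.5 0.5 0.5 0.5
```

The class index comes from the position of the class name in the label list, which is why the order of names in the configuration file matters.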

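III. Contents of the One-Click Training/Prediction sh Scripts

1. The training sh file (train.sh)

Training is driven by an sh file; the following command is all it takes to start training.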

sh train.sh 

This file switches the virtual environment and automatically launches model training. Its contents are:

# train.sh
train_weight=/home/ubuntu/Project/tj/auto_project/RTDETR/model_rtdetr/rtdetr-l.pt
echo -e "\n"train time $(date "+%Y-%m-%d")"\n"
# Switch the virtual environment
__conda_setup="$('/home/ubuntu/miniconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)"
if [ $? -eq 0 ]; then
    eval "$__conda_setup"
else
    if [ -f "/home/ubuntu/miniconda3/etc/profile.d/conda.sh" ]; then
        . "/home/ubuntu/miniconda3/etc/profile.d/conda.sh"
    else
        export PATH="/home/ubuntu/miniconda3/bin:$PATH"
    fi
fi
unset __conda_setup
conda activate yolov8
cur_dir=$(cd `dirname $0`;pwd)  # get the current path
echo -e "\ncur_dir:"${cur_dir}"\n"
yaml_dir=$cur_dir/coco128_auto.yaml
echo -e "\nyaml_dir:"${yaml_dir}"\n"
#save_dir=$cur_dir/runs/train
#echo -e "\nsave_dir:"$save_dir"\n"
#
#
#if [ -d ${save_dir} ];then
#    echo "save_dir exists"
#    else
#    echo "save_dir does not exist --> creating it"
#    mkdir -p $save_dir
#fi
cd ${cur_dir}
ls
echo -e "\n\n\n\t\t\t start train ... \n\n\n"
# Convert the xml data to txt format
python auto_tools.py
yolo train model=$train_weight data=$yaml_dir epochs=300 imgsz=640 batch=24 amp=False name=train/exp

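2. Notes on train.sh

1. The script begins with the path to the pretrained weights, an important setting that determines which RTDETR model is used. The default is the l model:
train_weight=/home/ubuntu/Project/tj/auto_project/RTDETR/model_rtdetr/rtdetr-l.pt

2. The last line is the command that runs the model. The default command is:
yolo train model=$train_weight data=$yaml_dir epochs=300 imgsz=640 batch=24 amp=False name=train/exp

3. Adding parameters
To choose which GPUs to use, add the device parameter in a form such as device=0,1 or device=0.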
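The training command is a list of key=value pairs, so adding a parameter such as device just means appending one more pair. The sketch below assembles such a command string. The function build_train_command is a hypothetical helper, not part of the project; the default values are the ones that appear in train.sh, and the short file names stand in for the full paths.

```python
def build_train_command(model, data, **overrides):
    """Assemble a `yolo train` command line from key=value arguments.

    Illustrative helper only; defaults mirror the values used in train.sh.
    """
    args = {"epochs": 300, "imgsz": 640, "batch": 24,
            "amp": False, "name": "train/exp"}
    args.update(overrides)  # extra or changed parameters, e.g. device
    parts = ["yolo", "train", f"model={model}", f"data={data}"]
    parts += [f"{key}={value}" for key, value in args.items()]
    return " ".join(parts)


# Select GPUs 0 and 1 by adding device=0,1
print(build_train_command("rtdetr-l.pt", "coco128_auto.yaml", device="0,1"))
# → yolo train model=rtdetr-l.pt data=coco128_auto.yaml epochs=300 imgsz=640 batch=24 amp=False name=train/exp device=0,1
```

In practice you would simply edit the last line of train.sh by hand; the sketch only shows how the pieces fit together.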

3. The prediction sh file (detect.sh)

Prediction is also driven by an sh file; the following command is all it takes to run prediction.

sh detect.sh 

This file switches the virtual environment and automatically launches model prediction. Its contents are:

# detect.sh
echo -e "\n"detect time $(date "+%Y-%m-%d")"\n"
# Switch the virtual environment
__conda_setup="$('/home/ubuntu/miniconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)"
if [ $? -eq 0 ]; then
    eval "$__conda_setup"
else
    if [ -f "/home/ubuntu/miniconda3/etc/profile.d/conda.sh" ]; then
        . "/home/ubuntu/miniconda3/etc/profile.d/conda.sh"
    else
        export PATH="/home/ubuntu/miniconda3/bin:$PATH"
    fi
fi
unset __conda_setup
conda activate yolov8
cur_dir=$(cd `dirname $0`;pwd)  # get the current path
echo -e "\ncur_dir:"${cur_dir}"\n"
yaml_dir=$cur_dir/coco128_auto.yaml
echo -e "\nyaml_dir:"${yaml_dir}"\n"
save_dir=$cur_dir/runs/detect
echo -e "\nsave_dir:"$save_dir"\n"
if [ -d ${save_dir} ];then
    echo "save_dir exists"
else
    echo "save_dir does not exist --> creating it"
    mkdir -p $save_dir
fi
cd ${cur_dir}
ls
echo -e "\n\n\n\t\t\t start detect ... \n\n\n"
python predect.py --conf_thres 0.25

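4. Notes on detect.sh

1. The last line is the command that runs the model. The default command is:
python predect.py --conf_thres 0.25

2. To set the weights and the folder where result images are saved, add arguments in the following form:
--weights /home/ubuntu/runs/detect/train/exp/weights/best.pt
--save_dir /home/ubuntu/runs/detect/predect/exp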
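The source of predect.py is not shown here, so the following is only a guess at how a script could accept the three flags mentioned above, written with Python's standard argparse module. The function build_parser, the help texts, and the None defaults are assumptions; only the flag names and the 0.25 threshold come from the commands in this section.

```python
import argparse


def build_parser():
    """Hypothetical stand-in for the command-line interface of predect.py."""
    parser = argparse.ArgumentParser(description="RTDETR prediction (sketch)")
    parser.add_argument("--conf_thres", type=float, default=0.25,
                        help="confidence threshold for kept detections")
    parser.add_argument("--weights", type=str, default=None,
                        help="path to the trained weights file")
    parser.add_argument("--save_dir", type=str, default=None,
                        help="folder where prediction images are saved")
    return parser


# Parse the same arguments that would be appended to the detect.sh command
args = build_parser().parse_args([
    "--conf_thres", "0.25",
    "--weights", "/home/ubuntu/runs/detect/train/exp/weights/best.pt",
    "--save_dir", "/home/ubuntu/runs/detect/predect/exp",
])
print(args.conf_thres, args.weights, args.save_dir)
```

Note that the flags use two ASCII hyphens; a dash character pasted from a web page would not be recognized by argparse.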

IV. Training and Prediction Results

1. Training results

[Figure: training results]

2. Prediction results

[Figure: prediction results]

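Summary

This article has one goal: foolproof training and prediction. The sh scripts handle three tasks:
(1) automatic switching of the virtual environment;
(2) automatic data format conversion: given image files and their matching XML files, the data is converted into the format the RTDETR model uses for training and prediction;
(3) automatic training and prediction, started simply by running sh train.sh or sh detect.sh.

Overall scripts:

The overall file layout is shown in the figure below:

[Figure: overall file layout]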
