import os
import cv2
import numpy as np
def fuse_rgb_thermal(rgb_dir, thermal_dir, output_dir):
    """Fuse paired colour and thermal images into 4-channel ``.npy`` arrays.

    For every ``.jpg``/``.png`` file in *rgb_dir* (extension matched
    case-insensitively), the file with the same name is read from
    *thermal_dir* as a single-channel grayscale image and appended to the
    colour image as a fourth channel. The result is written to *output_dir*
    as ``<stem>.npy``, where ``<stem>`` is the input name without its final
    extension.

    Args:
        rgb_dir: Directory containing the colour images.
        thermal_dir: Directory containing the thermal images; each one must
            have the same file name as its colour counterpart.
        output_dir: Destination directory; created if it does not exist.

    Pairs that cannot be read, or whose heights/widths differ, are reported
    on stdout and skipped. If two inputs share a stem (``a.jpg`` and
    ``a.png``), the one processed later overwrites the earlier output.
    """
    os.makedirs(output_dir, exist_ok=True)

    for rgb_file in sorted(os.listdir(rgb_dir)):
        if not rgb_file.lower().endswith(('.jpg', '.png')):
            continue

        rgb = cv2.imread(os.path.join(rgb_dir, rgb_file))
        thermal = cv2.imread(
            os.path.join(thermal_dir, rgb_file), cv2.IMREAD_GRAYSCALE
        )
        if rgb is None or thermal is None:
            print(f"Skip {rgb_file}")
            continue

        # Concatenating along the channel axis requires identical height and
        # width; skip the pair instead of aborting the whole run.
        if rgb.shape[:2] != thermal.shape[:2]:
            print(
                f"Skip {rgb_file}: size mismatch "
                f"(rgb {rgb.shape[:2]}, thermal {thermal.shape[:2]})"
            )
            continue

        fused = np.concatenate([rgb, thermal[:, :, np.newaxis]], axis=2)

        # Strip whatever extension the input has (.jpg, .png, .JPG, ...).
        # A plain replace('.jpg', '.npy') misses the other cases, and np.save
        # would then append '.npy' after the image extension.
        stem = os.path.splitext(rgb_file)[0]
        np.save(os.path.join(output_dir, stem + '.npy'), fused)
if __name__ == '__main__':
    # Convert the training split: colour + thermal inputs -> fused arrays.
    train_split = {
        'rgb_dir': 'raw_data/train/rgb',
        'thermal_dir': 'raw_data/train/thermal',
        'output_dir': 'datasets/images/train',
    }
    fuse_rgb_thermal(**train_split)
# Detection model definition: backbone + head, one class.
# Layer entries are [from, repeats, module, args]; the trailing number on a
# line is that layer's index, which later `from` fields refer to.
nc: 1  # number of classes

# Compound scaling constants, keyed by scale letter:
#   <scale>: [depth, width, max_channels]
# This must be a mapping; a bare list cannot be looked up by scale letter.
scales:
  s: [0.33, 0.50, 1024]

# NOTE(review): this file does not set an input-channel count. Confirm how the
# installed ultralytics version is told the number of input channels if the
# training data is not 3-channel.

backbone:
  - [-1, 1, Conv, [64, 3, 2]]      # 0
  - [-1, 1, Conv, [128, 3, 2]]     # 1
  - [-1, 3, C2f, [128, True]]      # 2
  - [-1, 1, Conv, [256, 3, 2]]     # 3
  - [-1, 6, C2f, [256, True]]      # 4  (concatenated at layer 14)
  - [-1, 1, Conv, [512, 3, 2]]     # 5
  - [-1, 6, C2f, [512, True]]      # 6  (concatenated at layer 11)
  - [-1, 1, Conv, [1024, 3, 2]]    # 7
  - [-1, 3, C2f, [1024, True]]     # 8
  - [-1, 1, SPPF, [1024, 5]]       # 9  (concatenated at layer 20)

head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]   # 10
  - [[-1, 6], 1, Concat, [1]]                    # 11
  - [-1, 3, C2f, [512]]                          # 12 (concatenated at layer 17)
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]   # 13
  - [[-1, 4], 1, Concat, [1]]                    # 14
  - [-1, 3, C2f, [256]]                          # 15 (Detect input)
  - [-1, 1, Conv, [256, 3, 2]]                   # 16
  - [[-1, 12], 1, Concat, [1]]                   # 17
  - [-1, 3, C2f, [512]]                          # 18 (Detect input)
  - [-1, 1, Conv, [512, 3, 2]]                   # 19
  - [[-1, 9], 1, Concat, [1]]                    # 20
  - [-1, 3, C2f, [1024]]                         # 21 (Detect input)
  - [[15, 18, 21], 1, Detect, [nc]]              # 22
from ultralytics import YOLO
import torch
import numpy as np
from pathlib import Path
from ultralytics.data.dataset import YOLODataset
def load_image_npy(self, i, *args, **kwargs):
    """Load sample *i*, reading fused ``.npy`` arrays directly.

    Intended as a replacement for ``YOLODataset.load_image``. Entries of
    ``self.im_files`` whose path ends in ``.npy`` are read with ``np.load``;
    every other entry is delegated to the saved stock loader,
    ``self._orig_load_image``.

    Args:
        i: Index into ``self.im_files``.
        *args, **kwargs: Extra arguments (for example ``rect_mode``) passed
            through unchanged to the stock loader. They are not used for
            ``.npy`` files.

    Returns:
        A 3-tuple ``(im, (h0, w0), (h, w))``: the image as a
        height x width x channels NumPy array in its stored dtype, followed
        by its original and current height/width. ``.npy`` arrays are
        returned without resizing, so the two size pairs are equal.
    """
    f = self.im_files[i]
    if not str(f).endswith('.npy'):
        # Forward any extra arguments so the stock loader behaves as usual.
        return self._orig_load_image(i, *args, **kwargs)

    # Keep the array as a channel-last NumPy array and report sizes as
    # (height, width) pairs: that is the return contract of the loader this
    # function stands in for. Returning a channel-first float tensor with
    # bare ints breaks callers that index the size tuples.
    im = np.load(f)
    h0, w0 = im.shape[:2]

    # NOTE(review): the stock loader may also resize to the dataset image
    # size and record loaded samples for augmentation, depending on the
    # installed ultralytics version. Neither is replicated here; confirm
    # whether it is needed.
    return im, (h0, w0), (h0, w0)
def get_label_file_npy(self, img_path):
    """Return the label-file path that sits next to *img_path*.

    The result is *img_path* with its final suffix swapped for ``.txt``,
    returned as a string.
    """
    image_path = Path(img_path)
    label_path = image_path.with_suffix('.txt')
    return str(label_path)
# Install the .npy-aware hooks on YOLODataset.
#
# Keep the stock loader reachable so load_image_npy can delegate to it.
# Capture it only once: if this code runs a second time in the same process,
# re-capturing would store the already-patched loader and make the fallback
# path call itself without end.
if not hasattr(YOLODataset, '_orig_load_image'):
    YOLODataset._orig_load_image = YOLODataset.load_image
YOLODataset.load_image = load_image_npy
# NOTE(review): confirm that the installed ultralytics version actually looks
# up `get_label_file` on the dataset; if it does not, this assignment has no
# effect on where labels are read from.
YOLODataset.get_label_file = get_label_file_npy
if __name__ == '__main__':
    # Training settings, collected in one place and passed to train().
    train_settings = {
        'data': 'rgbt_drone.yaml',
        'epochs': 100,
        'imgsz': 512,
        'batch': 16,
        'name': 'drone_rgbdet',
        'project': 'runs',
        'device': 0,
        'cache': False,
        'workers': 4,
    }

    # Build the detector from the model definition file, then train it.
    detector = YOLO('models/yolov8s-rgbt.yaml')
    detector.train(**train_settings)