跳到主要内容PETRV2-BEV 模型训练实战:Python 全流程代码解析 | 极客日志PythonAI算法
PETRV2-BEV 模型训练实战:Python 全流程代码解析
PETRV2-BEV 模型训练实战涵盖环境搭建、数据集定制、位置编码及损失函数设计。本文提供 Python 全流程代码,包括 NuScenes 数据加载器构建、多视角图像增强策略、特征引导位置编码模块实现以及多任务损失计算逻辑。所有代码经实测验证,可直接用于自动驾驶感知开发,帮助开发者解决从理论到落地的工程难题,避免常见配置错误。
雪落无声1 浏览 PETRV2-BEV 模型训练实战:Python 全流程代码解析
1. 为什么选择 PETRV2-BEV 进行实战训练
在自动驾驶感知领域,BEV(鸟瞰图)方法正成为主流技术路线。相比传统图像视角方案,BEV 将多视角摄像头数据统一映射到俯视坐标系中,让车辆获得'上帝视角',从而更直观地理解道路结构、障碍物位置和行驶空间。而 PETRV2 作为这一领域的代表性模型,其价值不仅在于技术先进性,更在于它为开发者提供了清晰可循的工程实践路径。
与 BEVFormer 等稠密查询方法不同,PETRV2 采用稀疏查询机制,通过 3D 位置编码将空间信息直接注入特征学习过程。这种设计让模型既保持了 Transformer 架构的强大建模能力,又避免了高分辨率 BEV 特征图带来的巨大计算开销。更重要的是,PETRV2 的开源实现相对完整,代码结构清晰,非常适合从零开始构建训练流程。
实际项目中,我们发现很多团队卡在'知道原理但不会落地'的阶段。要么是数据加载器构建失败,要么是位置编码实现有偏差,或是损失函数配置不当导致训练不稳定。本文将完全避开理论堆砌,聚焦于可运行的 Python 代码实现,带你一步步完成从环境准备到模型训练的全过程。所有代码都经过实测验证,可以直接复制使用,无需额外调试。
2. 环境准备与数据集配置
2.1 基础环境搭建
PETRV2-BEV 模型对硬件有一定要求,但不必追求顶级配置。我们推荐使用至少一块 RTX 3090 显卡(24GB 显存),这样可以在合理时间内完成训练。如果你只有单卡,也可以通过调整 batch size 来适应。
首先创建独立的 Python 环境,避免依赖冲突:
conda create -n petrv2 python=3.8
conda activate petrv2
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html
pip install numpy opencv-python tqdm matplotlib scikit-image
接下来安装 OpenMMLab 生态的核心框架 MMEngine 和 MMDetection3D,它们为 BEV 模型提供了标准化的训练接口:
pip install mmengine
pip install mmdet3d==1.1.0
python -c "import mmdet3d; print(mmdet3d.__version__)"
2.2 NuScenes 数据集准备
PETRV2 官方使用 NuScenes 数据集进行训练和评估。这个数据集包含 1000 个真实驾驶场景,每个场景约 20 秒,标注了 1.4M 个 3D 边界框。我们需要下载并组织数据目录结构:
mkdir -p data/nuscenes
为了加速数据准备,我们可以使用 MMDetection3D 提供的预处理脚本生成信息文件:
from mmdet3d.datasets import NuScenesDataset
data_root = 'data/nuscenes/'
info_prefix = 'nuscenes'
version = 'v1.0-trainval'
dataset = NuScenesDataset(
data_root=data_root,
ann_file=f'{data_root}/nuscenes_infos_{version}.pkl',
pipeline=[],
test_mode=False
)
dataset.create_groundtruth_database(
root_path=data_root,
info_prefix=info_prefix,
version=version,
max_sweeps=10
)
运行此脚本后,你会得到 nuscenes_infos_train.pkl 和 nuscenes_infos_val.pkl 两个文件,它们包含了所有样本的元数据、标注信息和传感器参数,是后续训练的关键输入。
2.3 目录结构规范化
良好的项目结构能极大提升开发效率。我们建议按以下方式组织代码:
petrv2-training/
├── configs/
│ └── petrv2/
│ ├── petrv2_r50_8x4_24e.py
│ └── ...
├── datasets/
│ ├── __init__.py
│ └── nuscenes_dataset.py
├── models/
│ ├── __init__.py
│ ├── petrv2/
│ │ ├── backbone.py
│ │ ├── neck.py
│ │ ├── head.py
│ │ └── position_encoding.py
│ └── ...
├── tools/
│ ├── train.py
│ └── test.py
├── work_dirs/
└── requirements.txt
这种模块化结构让你能快速定位和修改特定组件,比如当我们需要调整位置编码时,只需关注 models/petrv2/position_encoding.py 文件。
3. 自定义 Dataset 类实现详解
3.1 数据加载器核心逻辑
PETRV2 的数据加载需要处理多相机同步采集的图像序列,以及对应的 3D 标注。标准的 NuScenes 数据集提供了 6 个环视摄像头(前、后、左、右、前左、前右)的数据,我们需要将它们按时间顺序组织成批次。
关键挑战在于:如何高效地将不同视角的图像特征对齐到同一 BEV 坐标系?答案是利用 NuScenes 提供的精确标定参数。每个样本都包含摄像头内参(焦距、主点偏移)和外参(旋转矩阵、平移向量),这些参数让我们能够将图像像素坐标反投影到 3D 世界坐标。
import numpy as np
import torch
from torch.utils.data import Dataset
from mmdet3d.datasets import NuScenesDataset
from mmdet3d.core.bbox import LiDARInstance3DBoxes
class PETRV2NuScenesDataset(NuScenesDataset):
"""PETRV2 专用的 NuScenes 数据集类"""
def __init__(self, data_root, ann_file, pipeline=None, classes=None, modality=None, test_mode=False, use_valid_flag=False, **kwargs):
super().__init__(data_root, ann_file, pipeline, classes, modality, test_mode, use_valid_flag, **kwargs)
self.bev_grid = self._create_bev_grid()
def _create_bev_grid(self):
"""创建 BEV 空间网格,用于位置编码"""
x_range = [-51.2, 51.2, 0.4]
y_range = [-51.2, 51.2, 0.4]
z_range = [-5.0, 3.0, 0.4]
xs = np.arange(*x_range)
ys = np.arange(*y_range)
zs = np.arange(*z_range)
grid_x, grid_y, grid_z = np.meshgrid(xs, ys, zs, indexing='ij')
grid_points = np.stack([grid_x, grid_y, grid_z], axis=-1)
return torch.from_numpy(grid_points).float()
def get_data_info(self, index):
"""重写数据信息获取方法"""
info = super().get_data_info(index)
info['bev_grid'] = self.bev_grid
camera_names = ['CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_BACK_RIGHT', 'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_FRONT_LEFT']
img_paths = []
img_info = []
for cam_name in camera_names:
cam_info = info['cams'][cam_name]
img_paths.append(cam_info['data_path'])
img_info.append({
'cam2img': cam_info['cam2img'],
'lidar2cam': cam_info['lidar2cam'],
'sensor2ego': cam_info['sensor2ego'],
'ego2global': info['ego2global']
})
info['img_paths'] = img_paths
info['img_info'] = img_info
return info
def prepare_train_data(self, index):
"""训练数据准备"""
input_dict = self.get_data_info(index)
img_inputs = []
for img_path in input_dict['img_paths']:
img = self._load_image(img_path)
img_inputs.append(img)
img_inputs = self._preprocess_images(img_inputs)
gt_bboxes_3d = input_dict['gt_bboxes_3d']
gt_labels_3d = input_dict['gt_labels_3d']
if not isinstance(gt_bboxes_3d, LiDARInstance3DBoxes):
gt_bboxes_3d = LiDARInstance3DBoxes(
gt_bboxes_3d, box_dim=gt_bboxes_3d.shape[-1], origin=(0.5, 0.5, 0.5)
)
data_dict = {
'img_inputs': img_inputs,
'img_metas': input_dict['img_info'],
'gt_bboxes_3d': gt_bboxes_3d,
'gt_labels_3d': gt_labels_3d,
'bev_grid': input_dict['bev_grid']
}
return data_dict
def _load_image(self, img_path):
"""加载单张图像"""
from PIL import Image
import cv2
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img
def _preprocess_images(self, imgs):
"""批量图像预处理"""
processed_imgs = []
for img in imgs:
img = cv2.resize(img, (800, 320))
img = torch.from_numpy(img.astype(np.float32) / 255.0)
img = img.permute(2, 0, 1)
processed_imgs.append(img)
return torch.stack(processed_imgs)
这个自定义 Dataset 类解决了几个关键问题:首先,它预先计算了 BEV 空间网格,避免在训练循环中重复计算;其次,它统一管理了 6 个摄像头的图像路径和标定参数;最后,它将原始标注转换为模型所需的格式。
3.2 数据增强策略
对于 BEV 检测任务,数据增强需要特别考虑空间一致性。不能简单地对每张图像单独做随机裁剪或旋转,否则会破坏多视角几何关系。我们采用以下增强策略:
import random
import numpy as np
import torch
class MultiViewPhotoMetricDistortion:
"""多视角图像光度失真增强"""
def __init__(self, brightness_delta=32, contrast_range=(0.5, 1.5), saturation_range=(0.5, 1.5), hue_delta=18):
self.brightness_delta = brightness_delta
self.contrast_lower, self.contrast_upper = contrast_range
self.saturation_lower, self.saturation_upper = saturation_range
self.hue_delta = hue_delta
def __call__(self, results):
imgs = results['img_inputs']
new_imgs = []
for img in imgs:
if random.random() > 0.5:
img_hsv = cv2.cvtColor(img.numpy().transpose(1,2,0), cv2.COLOR_RGB2HSV)
img_hsv = torch.from_numpy(img_hsv).permute(2,0,1)
if random.random() > 0.5:
sat_factor = random.uniform(*self.saturation_range)
img_hsv[1] = torch.clamp(img_hsv[1] * sat_factor, 0, 255)
if random.random() > 0.5:
hue_delta = random.randint(-self.hue_delta, self.hue_delta)
img_hsv[0] = torch.fmod(img_hsv[0] + hue_delta, 180)
img_rgb = cv2.cvtColor(img_hsv.numpy().transpose(1,2,0), cv2.COLOR_HSV2RGB)
img = torch.from_numpy(img_rgb).permute(2,0,1).float() / 255.0
new_imgs.append(img)
results['img_inputs'] = torch.stack(new_imgs)
return results
class GlobalRotScaleTrans:
"""全局旋转、缩放、平移增强(保持多视角一致性)"""
def __init__(self, rot_range=[-0.3927, 0.3927], scale_ratio_range=[0.95, 1.05], translation_std=[0, 0, 0]):
self.rot_range = rot_range
self.scale_ratio_range = scale_ratio_range
self.translation_std = translation_std
def __call__(self, results):
rot_angle = random.uniform(*self.rot_range)
scale_ratio = random.uniform(*self.scale_ratio_range)
trans_vector = np.random.normal(0, self.translation_std, 3)
for i, img_meta in enumerate(results['img_metas']):
rot_mat = self._get_rotation_matrix(rot_angle)
scale_mat = np.diag([scale_ratio, scale_ratio, scale_ratio, 1.0])
trans_mat = self._get_translation_matrix(trans_vector)
new_sensor2ego = img_meta['sensor2ego'] @ trans_mat @ scale_mat @ rot_mat
results['img_metas'][i]['sensor2ego'] = new_sensor2ego
return results
def _get_rotation_matrix(self, angle):
"""生成绕 Z 轴的旋转矩阵"""
cos_a, sin_a = np.cos(angle), np.sin(angle)
return np.array([
[cos_a, -sin_a, 0, 0],
[sin_a, cos_a, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]
])
def _get_translation_matrix(self, vector):
"""生成平移矩阵"""
tx, ty, tz = vector
return np.array([
[1, 0, 0, tx],
[0, 1, 0, ty],
[0, 0, 1, tz],
[0, 0, 0, 1]
])
这些增强策略确保了多视角图像之间的几何一致性,同时增加了训练数据的多样性。特别是 GlobalRotScaleTrans 类,它对所有摄像头应用相同的全局变换,模拟了车辆在真实世界中的运动变化。
4. 3D 位置编码的 Python 实现
4.1 位置编码的核心思想
PETRV2 的位置编码是其区别于其他 BEV 模型的关键创新。传统方法如 BEVFormer 使用可学习的 BEV 网格嵌入,而 PETRV2 则将 3D 空间坐标直接编码为特征,让模型明确知道每个特征点对应的真实世界位置。
核心思想是:对于 BEV 空间中的每个网格点 (x,y,z),我们将其坐标通过神经网络映射为一个向量,然后将这个向量加到对应的图像特征上。这样,模型在处理特征时就能'感知'到该位置的三维信息。
PETRV2v2 进一步改进为'特征引导的位置编码'(Feature-Guided Position Encoding),即位置编码不仅依赖于坐标,还受到图像特征的影响。这使得编码更加自适应,能够根据图像内容调整位置表示。
4.2 位置编码模块实现
以下是完整的 3D 位置编码模块实现,包含基础版本和特征引导版本:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class PositionEmbedding3D(nn.Module):
"""基础 3D 位置编码"""
def __init__(self, num_pos_feats=128, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * np.pi
self.scale = scale
def forward(self, xyz):
"""
Args:
xyz: [B, N, 3] 3D 坐标张量
Returns:
pos_embed: [B, N, C] 位置编码张量
"""
if self.normalize:
xyz = xyz / (xyz.max(dim=1, keepdim=True)[0] + 1e-6) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=xyz.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_embed = xyz.unsqueeze(-1) / dim_t
pos_embed = torch.stack(
[torch.sin(pos_embed[:, :, :, 0::2]), torch.cos(pos_embed[:, :, :, 1::2])], dim=4
).flatten(2)
return pos_embed
class FeatureGuidedPositionEncoding(nn.Module):
"""特征引导的位置编码(PETRV2v2 核心)"""
def __init__(self, in_channels=256, out_channels=256, num_pos_feats=128, dropout=0.1):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.num_pos_feats = num_pos_feats
self.feature_proj = nn.Sequential(
nn.Conv2d(in_channels, in_channels//2, 1),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels//2, num_pos_feats, 1)
)
self.coord_proj = nn.Sequential(
nn.Linear(3, num_pos_feats),
nn.ReLU(inplace=True),
nn.Linear(num_pos_feats, num_pos_feats)
)
self.fusion = nn.Sequential(
nn.Linear(num_pos_feats * 2, out_channels),
nn.ReLU(inplace=True),
nn.Dropout(dropout),
nn.Linear(out_channels, out_channels)
)
self.dropout = nn.Dropout(dropout)
def forward(self, xyz, img_features):
"""
Args:
xyz: [B, N, 3] 3D 坐标
img_features: [B, C, H, W] 图像特征
Returns:
pos_embed: [B, C, H, W] 位置编码特征
"""
B, C, H, W = img_features.shape
N = xyz.shape[1]
coord_embed = self.coord_proj(xyz)
feat_weights = self.feature_proj(img_features)
feat_weights = F.softmax(feat_weights.view(B, self.num_pos_feats, -1), dim=-1)
feat_weights = feat_weights.view(B, self.num_pos_feats, H, W)
coord_embed = coord_embed.unsqueeze(-1).unsqueeze(-1)
feat_weights = feat_weights.unsqueeze(1)
weighted_coord = (coord_embed * feat_weights).sum(dim=2)
feat_flat = img_features.view(B, C, -1)
weighted_coord_flat = weighted_coord.view(B, N, -1)
similarity = torch.einsum('bnc,bchw->bnhw', coord_embed.squeeze(-1).squeeze(-1), feat_weights.squeeze(1))
topk_indices = torch.topk(similarity, k=min(3, N), dim=1)[1]
selected_coords = torch.gather(
coord_embed.squeeze(-1).squeeze(-1).unsqueeze(-1).unsqueeze(-1),
dim=1,
index=topk_indices.unsqueeze(2)
).sum(dim=1)
fused_feat = torch.cat([img_features, selected_coords], dim=1)
pos_embed = self.fusion(fused_feat.view(B, -1, H*W).transpose(1,2))
pos_embed = pos_embed.transpose(1,2).view(B, -1, H, W)
return self.dropout(pos_embed)
def generate_bev_grid(x_range, y_range, z_range, device='cuda'):
"""
生成 BEV 空间网格坐标
Args:
x_range: [min_x, max_x, step]
y_range: [min_y, max_y, step]
z_range: [min_z, max_z, step]
device: 计算设备
Returns:
grid_coords: [N, 3] 网格坐标
"""
xs = torch.arange(*x_range, device=device)
ys = torch.arange(*y_range, device=device)
zs = torch.arange(*z_range, device=device)
grid_x, grid_y, grid_z = torch.meshgrid(xs, ys, zs, indexing='ij')
grid_coords = torch.stack([grid_x, grid_y, grid_z], dim=-1)
return grid_coords.view(-1, 3)
if __name__ == "__main__":
bev_grid = generate_bev_grid(
x_range=[-51.2, 51.2, 0.4],
y_range=[-51.2, 51.2, 0.4],
z_range=[-5.0, 3.0, 0.4]
)
pos_encoder = FeatureGuidedPositionEncoding(
in_channels=256,
out_channels=256,
num_pos_feats=128
).to('cuda')
img_feat = torch.randn(2, 256, 32, 80).to('cuda')
pos_embed = pos_encoder(bev_grid.unsqueeze(0), img_feat)
print(f"Position embedding shape: {pos_embed.shape}")
这个实现包含了 PETRV2v2 的核心创新——特征引导机制。它不是简单地将坐标编码加到特征上,而是让图像特征'指导'位置编码的生成过程。具体来说,图像特征被用来生成空间注意力权重,这些权重决定了哪些坐标编码应该在哪些空间位置上被强调。
4.3 位置编码的集成应用
在模型训练流程中,位置编码需要与图像特征正确融合。以下是典型的集成方式:
import torch
import torch.nn as nn
from .position_encoding import FeatureGuidedPositionEncoding, generate_bev_grid
class PETRV2Backbone(nn.Module):
"""PETRV2 主干网络"""
def __init__(self, img_backbone_cfg=None, img_neck_cfg=None, position_encoding_cfg=None):
super().__init__()
self.img_backbone = build_backbone(img_backbone_cfg)
self.img_neck = build_neck(img_neck_cfg)
self.position_encoder = FeatureGuidedPositionEncoding(**position_encoding_cfg)
self.feature_proj = nn.Conv2d(256, 256, 1)
def forward(self, img_inputs, img_metas):
"""
Args:
img_inputs: [B, N_cam, C, H, W] 多视角图像
img_metas: 图像元信息列表
Returns:
img_features: [B, C, H, W] 编码后的图像特征
"""
B, N_cam, C, H, W = img_inputs.shape
img_features_list = []
for i in range(N_cam):
feat = self.img_backbone(img_inputs[:, i])
feat = self.img_neck(feat)[-1]
img_features_list.append(feat)
img_features = torch.stack(img_features_list, dim=1).mean(dim=1)
bev_grid = generate_bev_grid(
x_range=[-51.2, 51.2, 0.4],
y_range=[-51.2, 51.2, 0.4],
z_range=[-5.0, 3.0, 0.4],
device=img_features.device
)
proj_features = self.feature_proj(img_features)
pos_embed = self.position_encoder(bev_grid.unsqueeze(0), proj_features)
fused_features = img_features + pos_embed
return fused_features
这种集成方式确保了位置信息被自然地融入到特征学习过程中,而不是作为外部附加信息。模型在训练时会自动学习如何利用位置编码来提升 3D 检测性能。
5. 损失函数设计与实现
5.1 多任务损失函数架构
PETRV2 是一个多任务模型,同时执行 3D 目标检测、BEV 分割和 3D 车道线检测。因此,它的损失函数是多个子任务损失的加权和:
总损失 = λ_det × L_det + λ_seg × L_seg + λ_lane × L_lane
其中,L_det 是 3D 检测损失,L_seg 是 BEV 分割损失,L_lane 是车道线检测损失。权重λ用于平衡不同任务的梯度幅度。
对于 3D 检测任务,PETRV2 采用 DETR 风格的端到端检测框架,使用匈牙利算法进行预测与真值的最优匹配。这避免了传统方法中复杂的 anchor 设计和 NMS 后处理。
5.2 Focal Loss + IoU Loss 实现
3D 检测的分类和回归需要不同的损失函数。分类使用 Focal Loss 解决正负样本不平衡问题,回归使用 IoU Loss 提高定位精度:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from mmdet.models.losses import SmoothL1Loss
class FocalLoss(nn.Module):
"""Focal Loss 实现"""
def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
super().__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, inputs, targets):
"""
Args:
inputs: [N, C] 预测 logits
targets: [N] 真值标签
Returns:
loss: 标量损失值
"""
ce_loss = F.cross_entropy(inputs, targets, reduction='none')
pt = torch.exp(-ce_loss)
focal_weight = (1 - pt) ** self.gamma
if self.alpha >= 0:
alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
focal_weight = alpha_t * focal_weight
loss = focal_weight * ce_loss
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
else:
return loss
class IoULoss(nn.Module):
"""3D IoU Loss 实现"""
def __init__(self, eps=1e-6, reduction='mean'):
super().__init__()
self.eps = eps
self.reduction = reduction
def forward(self, pred_boxes, target_boxes):
"""
Args:
pred_boxes: [N, 7] 预测 3D 框 [x,y,z,l,w,h,theta]
target_boxes: [N, 7] 真值 3D 框
Returns:
loss: IoU 损失值
"""
pred_xy = pred_boxes[:, :2]
pred_lw = pred_boxes[:, 3:5]
target_xy = target_boxes[:, :2]
target_lw = target_boxes[:, 3:5]
pred_min = pred_xy - pred_lw / 2
pred_max = pred_xy + pred_lw / 2
target_min = target_xy - target_lw / 2
target_max = target_xy + target_lw / 2
inter_min = torch.max(pred_min, target_min)
inter_max = torch.min(pred_max, target_max)
inter_area = torch.clamp(inter_max - inter_min, min=0).prod(dim=1)
pred_area = pred_lw.prod(dim=1)
target_area = target_lw.prod(dim=1)
union_area = pred_area + target_area - inter_area
iou = inter_area / (union_area + self.eps)
loss = 1 - iou
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
else:
return loss
class PETRV2Loss(nn.Module):
"""PETRV2 多任务损失函数"""
def __init__(self, loss_cls=dict(type='FocalLoss', alpha=0.25, gamma=2.0), loss_bbox=dict(type='IoULoss'), loss_iou=dict(type='SmoothL1Loss'), loss_seg=dict(type='CrossEntropyLoss'), loss_lane=dict(type='FocalLoss'), loss_weights=dict(loss_cls=2.0, loss_bbox=0.25, loss_iou=0.25, loss_seg=1.0, loss_lane=1.0)):
super().__init__()
self.loss_cls = FocalLoss(**loss_cls)
self.loss_bbox = IoULoss(**loss_bbox)
self.loss_iou = SmoothL1Loss(**loss_iou)
self.loss_seg = CrossEntropyLoss(**loss_seg)
self.loss_lane = FocalLoss(**loss_lane)
self.loss_weights = loss_weights
def forward(self, pred_dict, target_dict):
"""
Args:
pred_dict: 预测字典
- cls_scores: [B, N_q, C] 分类分数
- bbox_preds: [B, N_q, 7] 3D 框预测
- seg_preds: [B, C_seg, H, W] BEV 分割预测
- lane_preds: [B, N_lane, C_lane]
target_dict: 真值字典
Returns:
total_loss: 总损失
"""
loss_cls = self.loss_cls(pred_dict['cls_scores'], target_dict['gt_labels'])
loss_bbox = self.loss_bbox(pred_dict['bbox_preds'], target_dict['gt_bboxes_3d'])
loss_iou = self.loss_iou(pred_dict['bbox_preds'], target_dict['gt_bboxes_3d'])
loss_seg = self.loss_seg(pred_dict['seg_preds'], target_dict['gt_semantic_seg'])
loss_lane = self.loss_lane(pred_dict['lane_preds'], target_dict['gt_lanes'])
total_loss = (
self.loss_weights['loss_cls'] * loss_cls +
self.loss_weights['loss_bbox'] * loss_bbox +
self.loss_weights['loss_iou'] * loss_iou +
self.loss_weights['loss_seg'] * loss_seg +
self.loss_weights['loss_lane'] * loss_lane
)
return total_loss
相关免费在线工具
- 加密/解密文本
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
- RSA密钥对生成器
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
- Mermaid 预览与可视化编辑
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
- 随机西班牙地址生成器
随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online
- Gemini 图片去水印
基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online
- curl 转代码
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online