人工智能在飞控算法中的应用日益广泛,传统控制算法面临建模难、抗干扰弱等问题。结合人工智能是创新方向之一,主要包括 AI 增强传统控制、强化学习端到端控制及感知 - 控制一体化等思路。
一、核心技术栈(纯 C++)
| 模块 | 工具/框架 | 作用 |
|---|---|---|
| 机器人框架 | ROS2 Humble(C++ API) | 传感器/执行器通信、节点管理、实时调度 |
| 强化学习框架 | LibTorch 2.1.0(CUDA 11.4) | TD3 算法实现、网络训练、模型导出 |
| 模型加速 | TensorRT 8.5(C++ API) | 模型量化、推理引擎构建、GPU/DLA 加速 |
| 仿真环境 | Gazebo 11 + ros2_control(C++接口) | 无人机动力学仿真、传感器模拟、避障场景 |
| 数据处理 | Eigen 3.4、OpenCV 4.5、PCL 1.10 | 状态向量构建、点云/图像预处理 |
| 编译构建 | CMake 3.20+、ament_cmake | 跨平台编译、依赖管理、优化编译选项 |
二、环境准备(Orin NX 专属)
1. 系统与依赖安装
(1)基础依赖
# ROS2 Humble 核心依赖(已安装可跳过)
sudo apt install ros-humble-ros2-control ros-humble-ros2-controllers ros-humble-gazebo-ros2-control
# 数据处理与编译依赖
sudo apt install libeigen3-dev libopencv-dev libpcl-dev cmake gcc-9 g++-9
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-9 50
sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-9 50
(2)LibTorch(C++ PyTorch)安装
Orin NX 为 ARM64 架构。注意:download.pytorch.org 官方并不提供 ARM64+CUDA 的 LibTorch 预编译包,Jetson 平台通常需使用 NVIDIA 官方发布的 Jetson 版 PyTorch(其中含 libtorch)或从源码编译;以下命令仅为流程示意,请按实际 JetPack 版本获取对应的 LibTorch:
# 下载 LibTorch 2.1.0(CUDA 11.4,ARM64)
wget https://download.pytorch.org/libtorch/cu114/libtorch-cxx11-abi-shared-with-deps-2.1.0%2Bcu114.zip
unzip libtorch-cxx11-abi-shared-with-deps-2.1.0+cu114.zip -d /opt/
echo "export Torch_DIR=/opt/libtorch/share/cmake/Torch" >> ~/.bashrc
source ~/.bashrc
(3)TensorRT 依赖(JetPack 预装,验证即可)
# 验证 TensorRT 安装
dpkg -l | grep TensorRT
# 确保库路径正确
echo "export LD_LIBRARY_PATH=/usr/lib/aarch64-linux-gnu/:$LD_LIBRARY_PATH" >> ~/.bashrc
source ~/.bashrc
2. ROS2 功能包创建
cd ~/ros2_ws/src
ros2 pkg create drone_rl_cpp --build-type ament_cmake --dependencies \
rclcpp sensor_msgs geometry_msgs std_msgs gazebo_ros2_control ros2_control \
Eigen3 opencv4 pcl_common pcl_io
cd drone_rl_cpp
# 创建目录结构
mkdir -p include/drone_rl_cpp src config launch models src/networks src/env src/utils
三、核心模块设计(纯 C++ 实现)
1. 模块划分
drone_rl_cpp/
├── include/drone_rl_cpp/
│ ├── env/DroneEnv.hpp # ROS2 环境封装(传感器 + 动作 + 状态 + 奖励)
│ ├── networks/TD3Networks.hpp # TD3 网络(Actor/Critic)
│ ├── utils/ReplayBuffer.hpp # 经验回放缓冲区
│ ├── utils/TrtInfer.hpp # TensorRT 推理封装
│ └── TD3Agent.hpp # TD3 智能体(训练 + 推理)
├── src/
│ ├── env/DroneEnv.cpp # 环境实现
│ ├── networks/TD3Networks.cpp
│ ├── utils/ReplayBuffer.cpp
│ ├── utils/TrtInfer.cpp
│ ├── TD3Agent.cpp
│ ├── train_node.cpp # 训练节点(ROS2)
│ └── infer_node.cpp # 推理控制节点(ROS2)
├── config/ # 控制配置、模型参数
├── launch/ # Gazebo 仿真、硬件启动 launch 文件
└── models/ # 训练好的模型、TensorRT 引擎
2. 核心模块实现
(1)ROS2 环境封装(include/drone_rl_cpp/env/DroneEnv.hpp)
封装传感器订阅、动作发布、状态构建、奖励计算,替代 Python Gym:
#ifndef DRONE_ENV_HPP_
#define DRONE_ENV_HPP_
#include <rclcpp/rclcpp.hpp>
#include <sensor_msgs/msg/imu.hpp>
#include <sensor_msgs/msg/laser_scan.hpp>
#include <geometry_msgs/msg/pose_stamped.hpp>
#include <geometry_msgs/msg/twist_stamped.hpp>
#include <std_msgs/msg/float64_multi_array.hpp>
#include <Eigen/Dense>
#include <memory>
#include <string>
#include <vector>
#include <mutex>
#include <atomic>
namespace drone_rl_cpp {

// Eigen has no built-in Vector5d alias (only 2/3/4); define one for the
// 5-sector lidar distance vector.
using Vector5d = Eigen::Matrix<double, 5, 1>;

/// Result of one environment step — replaces the Python Gym (s', r, done, info) tuple.
struct StepResult {
  Eigen::VectorXd state;   ///< next state (STATE_DIM)
  double reward;           ///< scalar reward for this step
  bool done;               ///< episode-terminated flag
  std::string info;        ///< "collision" / "success" / "running"
};

/// ROS2 environment wrapper: subscribes to IMU/GPS/lidar, publishes motor
/// commands, and exposes a Gym-like reset()/step() interface for RL.
class DroneEnv : public rclcpp::Node {
public:
  using Ptr = std::shared_ptr<DroneEnv>;
  // State dimension (22) and action dimension (4 motor speeds).
  static constexpr int STATE_DIM = 22;
  static constexpr int ACTION_DIM = 4;
  // Action range: motor speed 500–2000 RPM.
  static constexpr double ACTION_LOW = 500.0;
  static constexpr double ACTION_HIGH = 2000.0;

  explicit DroneEnv(const std::string& node_name = "drone_env");
  ~DroneEnv() override = default;

  /// Reset the episode: stop motors, clear cached sensor data, return initial state.
  Eigen::VectorXd reset();
  /// Apply one action, wait one control period, return (s', r, done, info).
  StepResult step(const Eigen::VectorXd& action);
  /// True once at least one sensor message has been received.
  bool is_ready() const { return is_ready_.load(); }

private:
  void imu_callback(const sensor_msgs::msg::Imu::SharedPtr msg);
  void gps_pose_callback(const geometry_msgs::msg::PoseStamped::SharedPtr msg);
  void gps_twist_callback(const geometry_msgs::msg::TwistStamped::SharedPtr msg);
  void lidar_callback(const sensor_msgs::msg::LaserScan::SharedPtr msg);
  /// Assemble the 22-dim state vector from the cached sensor data.
  Eigen::VectorXd get_state();
  /// Tracking + obstacle + smoothness + terminal + collision reward terms.
  double compute_reward(const Eigen::VectorXd& state);
  /// Episode termination: collision, out-of-bounds, or final waypoint reached.
  bool check_done(const Eigen::VectorXd& state);

  rclcpp::Subscription<sensor_msgs::msg::Imu>::SharedPtr imu_sub_;
  rclcpp::Subscription<geometry_msgs::msg::PoseStamped>::SharedPtr gps_pose_sub_;
  rclcpp::Subscription<geometry_msgs::msg::TwistStamped>::SharedPtr gps_twist_sub_;
  rclcpp::Subscription<sensor_msgs::msg::LaserScan>::SharedPtr lidar_sub_;
  rclcpp::Publisher<std_msgs::msg::Float64MultiArray>::SharedPtr motor_pub_;

  std::mutex data_mutex_;        // guards the sensor caches below
  Eigen::VectorXd imu_data_;     // [wx, wy, wz, ax, ay, az]
  Eigen::Vector3d gps_pose_;     // position (m)
  Eigen::Vector3d gps_twist_;    // linear velocity (m/s)
  Vector5d lidar_data_;          // 5 sector minimum distances (m)
  std::atomic<bool> is_ready_;
  std::vector<Eigen::Vector3d> target_points_;  // waypoint trajectory
  int current_target_idx_;
  std::atomic<bool> collision_;
};

}  // namespace drone_rl_cpp
#endif  // DRONE_ENV_HPP_
(2)环境实现(src/env/DroneEnv.cpp)
#include "drone_rl_cpp/env/DroneEnv.hpp"
#include <algorithm>
#include <chrono>
#include <cmath>
#include <thread>
namespace drone_rl_cpp {

namespace {
// Control period between action publish and next state read (50 Hz).
constexpr auto kControlPeriod = std::chrono::milliseconds(20);
// Distance (m) at which the current waypoint counts as reached.
constexpr double kTargetReachedDist = 0.2;
// Lidar distance (m) below which we register a collision.
constexpr double kCollisionDist = 0.3;
// Lidar distance (m) considered safe for the obstacle reward term.
constexpr double kSafeDist = 1.0;
}  // namespace

DroneEnv::DroneEnv(const std::string& node_name)
    : Node(node_name),
      imu_data_(Eigen::VectorXd::Zero(6)),
      gps_pose_(Eigen::Vector3d::Zero()),
      gps_twist_(Eigen::Vector3d::Zero()),
      is_ready_(false),
      current_target_idx_(0),
      collision_(false) {
  lidar_data_.setConstant(10.0);  // assume 10 m of free space before the first scan
  // Square waypoint trajectory: (2,0,1)→(2,2,1)→(0,2,1)→(0,0,1).
  target_points_ = {
      Eigen::Vector3d(2.0, 0.0, 1.0),
      Eigen::Vector3d(2.0, 2.0, 1.0),
      Eigen::Vector3d(0.0, 2.0, 1.0),
      Eigen::Vector3d(0.0, 0.0, 1.0)};
  // Sensor subscriptions, queue depth 10 for low latency.
  // NOTE(review): topic names reconstructed — confirm against the Gazebo
  // plugin configuration in the launch files before use.
  imu_sub_ = this->create_subscription<sensor_msgs::msg::Imu>(
      "imu/data", 10,
      std::bind(&DroneEnv::imu_callback, this, std::placeholders::_1));
  gps_pose_sub_ = this->create_subscription<geometry_msgs::msg::PoseStamped>(
      "gps/pose", 10,
      std::bind(&DroneEnv::gps_pose_callback, this, std::placeholders::_1));
  gps_twist_sub_ = this->create_subscription<geometry_msgs::msg::TwistStamped>(
      "gps/twist", 10,
      std::bind(&DroneEnv::gps_twist_callback, this, std::placeholders::_1));
  lidar_sub_ = this->create_subscription<sensor_msgs::msg::LaserScan>(
      "scan", 10,
      std::bind(&DroneEnv::lidar_callback, this, std::placeholders::_1));
  motor_pub_ = this->create_publisher<std_msgs::msg::Float64MultiArray>("motor_cmd", 10);
  // Block until a sensor message arrives (10 s timeout).
  const auto start = this->now();
  while (rclcpp::ok() && !is_ready_ && (this->now() - start).seconds() < 10.0) {
    rclcpp::spin_some(this->get_node_base_interface());
    std::this_thread::sleep_for(std::chrono::milliseconds(10));
  }
  if (is_ready_) {
    RCLCPP_INFO(this->get_logger(), "DroneEnv ready: sensors online");
  } else {
    RCLCPP_ERROR(this->get_logger(), "DroneEnv timed out waiting for sensor data");
  }
}

Eigen::VectorXd DroneEnv::reset() {
  {
    // Scoped lock: get_state() below takes the same mutex.
    std::lock_guard<std::mutex> lock(data_mutex_);
    current_target_idx_ = 0;
    collision_ = false;
    imu_data_.setZero();
    gps_pose_.setZero();
    gps_twist_.setZero();
    lidar_data_.setConstant(10.0);
  }
  // Stop all motors and let the simulation settle for one control period.
  auto zero_cmd = std_msgs::msg::Float64MultiArray();
  zero_cmd.data = {0.0, 0.0, 0.0, 0.0};
  motor_pub_->publish(zero_cmd);
  std::this_thread::sleep_for(kControlPeriod);
  rclcpp::spin_some(this->get_node_base_interface());
  return get_state();
}

StepResult DroneEnv::step(const Eigen::VectorXd& action) {
  if (action.size() != ACTION_DIM) {
    RCLCPP_ERROR(this->get_logger(), "Action dim mismatch: expected %d, got %ld",
                 ACTION_DIM, static_cast<long>(action.size()));
    return {Eigen::VectorXd::Zero(STATE_DIM), 0.0, true, "invalid_action"};
  }
  // Clip to the motor-speed range and publish.
  const Eigen::VectorXd clipped_action = action.cwiseMax(ACTION_LOW).cwiseMin(ACTION_HIGH);
  auto motor_cmd = std_msgs::msg::Float64MultiArray();
  motor_cmd.data.resize(ACTION_DIM);
  for (int i = 0; i < ACTION_DIM; ++i) {
    motor_cmd.data[i] = clipped_action(i);
  }
  motor_pub_->publish(motor_cmd);
  // Wait one control period, then pump callbacks for fresh sensor data.
  std::this_thread::sleep_for(kControlPeriod);
  rclcpp::spin_some(this->get_node_base_interface());

  Eigen::VectorXd state = get_state();
  const double reward = compute_reward(state);
  const bool done = check_done(state);
  const std::string info = collision_ ? "collision" : (done ? "success" : "running");
  // Advance to the next waypoint once close enough to the current one.
  const Eigen::Vector3d current_target = target_points_[current_target_idx_];
  const double dist_to_target = (gps_pose_ - current_target).norm();
  if (dist_to_target <= kTargetReachedDist) {
    current_target_idx_ =
        (current_target_idx_ + 1) % static_cast<int>(target_points_.size());
    RCLCPP_INFO(this->get_logger(),
                "Waypoint reached, next idx=%d (prev target %.1f, %.1f, %.1f)",
                current_target_idx_, current_target.x(), current_target.y(),
                current_target.z());
  }
  return {state, reward, done, info};
}

void DroneEnv::imu_callback(const sensor_msgs::msg::Imu::SharedPtr msg) {
  std::lock_guard<std::mutex> lock(data_mutex_);
  imu_data_(0) = msg->angular_velocity.x;
  imu_data_(1) = msg->angular_velocity.y;
  imu_data_(2) = msg->angular_velocity.z;
  imu_data_(3) = msg->linear_acceleration.x;
  imu_data_(4) = msg->linear_acceleration.y;
  imu_data_(5) = msg->linear_acceleration.z;
  is_ready_ = true;
}

void DroneEnv::gps_pose_callback(const geometry_msgs::msg::PoseStamped::SharedPtr msg) {
  std::lock_guard<std::mutex> lock(data_mutex_);
  gps_pose_(0) = msg->pose.position.x;
  gps_pose_(1) = msg->pose.position.y;
  gps_pose_(2) = msg->pose.position.z;
  is_ready_ = true;
}

void DroneEnv::gps_twist_callback(const geometry_msgs::msg::TwistStamped::SharedPtr msg) {
  std::lock_guard<std::mutex> lock(data_mutex_);
  gps_twist_(0) = msg->twist.linear.x;
  gps_twist_(1) = msg->twist.linear.y;
  gps_twist_(2) = msg->twist.linear.z;
  is_ready_ = true;
}

void DroneEnv::lidar_callback(const sensor_msgs::msg::LaserScan::SharedPtr msg) {
  std::lock_guard<std::mutex> lock(data_mutex_);
  const auto& ranges = msg->ranges;
  const int n = static_cast<int>(ranges.size());
  if (n < 4) {
    return;  // malformed scan, keep previous values
  }
  // Minimum distance inside four quadrant sectors of the scan.
  auto sector_min = [&ranges](int from, int to) {
    return static_cast<double>(*std::min_element(ranges.begin() + from, ranges.begin() + to));
  };
  lidar_data_(0) = sector_min(0, n / 4);
  lidar_data_(1) = sector_min(n / 4, n / 2);
  lidar_data_(2) = sector_min(n / 2, 3 * n / 4);
  lidar_data_(3) = sector_min(3 * n / 4, n);
  // Overall minimum, ignoring invalid (<= 0) returns. The original used a
  // comparator `a < b && a > 0`, which is not a strict weak ordering (UB with
  // std::min_element); a plain filtered scan is correct.
  double overall_min = 10.0;
  for (const float r : ranges) {
    if (r > 0.0f && static_cast<double>(r) < overall_min) {
      overall_min = static_cast<double>(r);
    }
  }
  lidar_data_(4) = overall_min;
  collision_ = std::any_of(lidar_data_.data(), lidar_data_.data() + 5,
                           [](double d) { return d < kCollisionDist; });
  is_ready_ = true;
}

Eigen::VectorXd DroneEnv::get_state() {
  std::lock_guard<std::mutex> lock(data_mutex_);
  Eigen::VectorXd state = Eigen::VectorXd::Zero(STATE_DIM);
  const Eigen::Vector3d current_target = target_points_[current_target_idx_];
  // Layout: [0:3) pos, [3:6) vel, [6:12) imu, [12:15) target delta,
  //         [15:20) lidar, 20 horizontal error, 21 altitude error.
  state.segment(0, 3) = gps_pose_;
  state.segment(3, 3) = gps_twist_;
  state.segment(6, 6) = imu_data_;
  state.segment(12, 3) = current_target - gps_pose_;
  state.segment(15, 5) = lidar_data_;
  state(20) = (current_target.head<2>() - gps_pose_.head<2>()).norm();
  state(21) = std::abs(current_target.z() - gps_pose_.z());
  return state;
}

double DroneEnv::compute_reward(const Eigen::VectorXd& state) {
  const double err_xy = state(20);
  const double err_z = state(21);
  // NOTE(review): reward weights reconstructed — tune against the intended design.
  const double track_reward = -1.0 * (err_xy * err_xy + err_z * err_z);
  double obstacle_reward = 0.0;
  for (int i = 0; i < 5; ++i) {
    obstacle_reward += (state(15 + i) >= kSafeDist) ? 0.1 : -0.5;
  }
  // Penalize high angular rates (state 6..8) for smooth flight.
  const double smooth_reward = -0.01 * state.segment(6, 3).norm();
  const double terminal_reward =
      (std::sqrt(err_xy * err_xy + err_z * err_z) <= 0.1) ? 10.0 : 0.0;
  const double collision_penalty = collision_ ? -100.0 : 0.0;
  return track_reward + obstacle_reward + smooth_reward + terminal_reward +
         collision_penalty;
}

bool DroneEnv::check_done(const Eigen::VectorXd& state) {
  const double err_xy = state(20);
  const double err_z = state(21);
  if (collision_) {
    return true;
  }
  // Out-of-bounds guard: abort if the drone flies far from the course.
  if (state.segment(0, 3).norm() > 10.0) {
    return true;
  }
  // Success: last waypoint of the square reached within 0.1 m.
  if (current_target_idx_ == static_cast<int>(target_points_.size()) - 1 &&
      std::sqrt(err_xy * err_xy + err_z * err_z) <= 0.1) {
    RCLCPP_INFO(this->get_logger(), "All waypoints completed");
    return true;
  }
  return false;
}

}  // namespace drone_rl_cpp
(3)TD3 网络定义(include/drone_rl_cpp/networks/TD3Networks.hpp)
用 LibTorch 实现 Actor(策略网络)和 Critic(价值网络):
#ifndef TD3_NETWORKS_HPP_
#define TD3_NETWORKS_HPP_
#include <torch/torch.h>
#include <Eigen/Dense>
#include <memory>
#include <utility>
namespace drone_rl_cpp {

/// Actor network: state (22-dim) → continuous action (4-dim), scaled to
/// [action_low, action_high] by a tanh output layer.
class ActorNetwork : public torch::nn::Module {
public:
  ActorNetwork(int state_dim, int action_dim, double action_low, double action_high);
  torch::Tensor forward(torch::Tensor x);
  /// Eigen vector → CUDA float tensor of shape {1, dim} (inference helper).
  torch::Tensor eigen_to_tensor(const Eigen::VectorXd& x);
  /// Tensor → Eigen vector (inference helper).
  Eigen::VectorXd tensor_to_eigen(const torch::Tensor& x);
private:
  torch::nn::Linear fc1_{nullptr}, fc2_{nullptr}, fc3_{nullptr};
  double action_low_;   // lower bound of the scaled action range
  double action_high_;  // upper bound of the scaled action range
};

/// Critic network: (state, action) → scalar Q value.
class CriticNetwork : public torch::nn::Module {
public:
  CriticNetwork(int state_dim, int action_dim);
  torch::Tensor forward(torch::Tensor x, torch::Tensor a);
private:
  torch::nn::Linear fc1_{nullptr}, fc2_{nullptr}, fc3_{nullptr};
};

/// TD3 twin critics: two independent critics; taking min(Q1, Q2) for the
/// target limits Q-value overestimation.
class TwinCriticNetworks : public torch::nn::Module {
public:
  TwinCriticNetworks(int state_dim, int action_dim);
  /// Returns {Q1, Q2} for the given (state, action) batch.
  std::pair<torch::Tensor, torch::Tensor> forward(torch::Tensor x, torch::Tensor a);
  std::shared_ptr<CriticNetwork> critic1() { return critic1_; }
  std::shared_ptr<CriticNetwork> critic2() { return critic2_; }
private:
  std::shared_ptr<CriticNetwork> critic1_;
  std::shared_ptr<CriticNetwork> critic2_;
};

}  // namespace drone_rl_cpp
#endif  // TD3_NETWORKS_HPP_
(4)网络实现(src/networks/TD3Networks.cpp)
#include "drone_rl_cpp/networks/TD3Networks.hpp"
#include <cstring>
namespace drone_rl_cpp {

// ---------- Actor ----------
ActorNetwork::ActorNetwork(int state_dim, int action_dim, double action_low, double action_high)
    : action_low_(action_low), action_high_(action_high) {
  // Three-layer MLP: state_dim → 256 → 128 → action_dim.
  fc1_ = register_module("fc1", torch::nn::Linear(state_dim, 256));
  fc2_ = register_module("fc2", torch::nn::Linear(256, 128));
  fc3_ = register_module("fc3", torch::nn::Linear(128, action_dim));
  // Xavier-uniform weight init, small constant bias.
  torch::nn::init::xavier_uniform_(fc1_->weight);
  torch::nn::init::xavier_uniform_(fc2_->weight);
  torch::nn::init::xavier_uniform_(fc3_->weight);
  torch::nn::init::constant_(fc1_->bias, 0.01);
  torch::nn::init::constant_(fc2_->bias, 0.01);
  torch::nn::init::constant_(fc3_->bias, 0.01);
}

torch::Tensor ActorNetwork::forward(torch::Tensor x) {
  // ReLU hidden layers + tanh head: output in [-1, 1], then scaled to the
  // action range.
  x = torch::relu(fc1_->forward(x));
  x = torch::relu(fc2_->forward(x));
  x = torch::tanh(fc3_->forward(x));  // [-1, 1]
  return (x + 1.0) * (action_high_ - action_low_) / 2.0 + action_low_;
}

torch::Tensor ActorNetwork::eigen_to_tensor(const Eigen::VectorXd& x) {
  // Eigen stores double: from_blob must be told kFloat64, then convert.
  // (Reading double memory as kFloat32 reinterprets the bytes — a bug in
  // the original.) clone() detaches from the caller's memory.
  return torch::from_blob(const_cast<double*>(x.data()),
                          {1, static_cast<long>(x.size())}, torch::kFloat64)
      .clone()
      .to(torch::kFloat32)
      .to(torch::kCUDA);
}

Eigen::VectorXd ActorNetwork::tensor_to_eigen(const torch::Tensor& x) {
  // Bring onto the CPU as contiguous double data, then copy out.
  const torch::Tensor cpu_tensor =
      x.detach().to(torch::kCPU).to(torch::kFloat64).contiguous().view(-1);
  Eigen::VectorXd eigen_vec(cpu_tensor.size(0));
  std::memcpy(eigen_vec.data(), cpu_tensor.data_ptr<double>(),
              static_cast<size_t>(cpu_tensor.size(0)) * sizeof(double));
  return eigen_vec;
}

// ---------- Critic ----------
CriticNetwork::CriticNetwork(int state_dim, int action_dim) {
  // Three-layer MLP on the concatenated (state, action): → 256 → 128 → 1.
  fc1_ = register_module("fc1", torch::nn::Linear(state_dim + action_dim, 256));
  fc2_ = register_module("fc2", torch::nn::Linear(256, 128));
  fc3_ = register_module("fc3", torch::nn::Linear(128, 1));
  torch::nn::init::xavier_uniform_(fc1_->weight);
  torch::nn::init::xavier_uniform_(fc2_->weight);
  torch::nn::init::xavier_uniform_(fc3_->weight);
  torch::nn::init::constant_(fc1_->bias, 0.01);
  torch::nn::init::constant_(fc2_->bias, 0.01);
  torch::nn::init::constant_(fc3_->bias, 0.01);
}

torch::Tensor CriticNetwork::forward(torch::Tensor x, torch::Tensor a) {
  torch::Tensor cat = torch::cat({x, a}, /*dim=*/1);
  cat = torch::relu(fc1_->forward(cat));
  cat = torch::relu(fc2_->forward(cat));
  return fc3_->forward(cat);
}

// ---------- Twin critics ----------
TwinCriticNetworks::TwinCriticNetworks(int state_dim, int action_dim) {
  critic1_ = std::make_shared<CriticNetwork>(state_dim, action_dim);
  critic2_ = std::make_shared<CriticNetwork>(state_dim, action_dim);
  register_module("critic1", critic1_);
  register_module("critic2", critic2_);
}

std::pair<torch::Tensor, torch::Tensor> TwinCriticNetworks::forward(torch::Tensor x,
                                                                    torch::Tensor a) {
  return {critic1_->forward(x, a), critic2_->forward(x, a)};
}

}  // namespace drone_rl_cpp
(5)经验回放缓冲区(include/drone_rl_cpp/utils/ReplayBuffer.hpp)
#ifndef REPLAY_BUFFER_HPP_
#define REPLAY_BUFFER_HPP_
#include <Eigen/Dense>
#include <vector>
#include <mutex>
#include <random>
namespace drone_rl_cpp {
struct Transition {
Eigen::VectorXd state;
Eigen::VectorXd action;
double reward;
Eigen::VectorXd next_state;
bool done;
};
class ReplayBuffer {
public:
ReplayBuffer(int capacity, int state_dim, int action_dim);
~ReplayBuffer() = default;
// 添加经验
void push(const Transition& transition);
// 采样批次经验(返回 Tensor,用于训练)
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> sample_batch(int batch_size);
// 缓冲区大小
int size() const { return static_cast<int>(buffer_.size()); }
// 缓冲区是否已满
bool { () >= capacity_; }
:
capacity_;
state_dim_;
action_dim_;
std::vector<Transition> buffer_;
std::mutex buffer_mutex_;
std::mt19937 rng_;
write_idx_;
};
}
(6)缓冲区实现(src/utils/ReplayBuffer.cpp)
#include "drone_rl_cpp/utils/ReplayBuffer.hpp"
#include <torch/torch.h>
#include <stdexcept>
#include <string>
#include <vector>
namespace drone_rl_cpp {

ReplayBuffer::ReplayBuffer(int capacity, int state_dim, int action_dim)
    : capacity_(capacity), state_dim_(state_dim), action_dim_(action_dim),
      // Original seeded with `std::random_device{}` (the object, not a draw);
      // the distribution must be seeded with `std::random_device{}()`.
      rng_(std::random_device{}()), write_idx_(0) {
  buffer_.reserve(capacity_);
}

void ReplayBuffer::push(const Transition& transition) {
  std::lock_guard<std::mutex> lock(buffer_mutex_);
  if (buffer_.size() < static_cast<size_t>(capacity_)) {
    buffer_.emplace_back(transition);
  } else {
    buffer_[write_idx_] = transition;  // ring-buffer overwrite of the oldest slot
  }
  write_idx_ = (write_idx_ + 1) % capacity_;
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
ReplayBuffer::sample_batch(int batch_size) {
  std::lock_guard<std::mutex> lock(buffer_mutex_);
  const int current_size = static_cast<int>(buffer_.size());
  if (current_size < batch_size) {
    throw std::runtime_error("Replay buffer size (" + std::to_string(current_size) +
                             ") < batch size (" + std::to_string(batch_size) + ")");
  }
  std::uniform_int_distribution<int> index_dist(0, current_size - 1);
  std::vector<int> indices(batch_size);
  for (int i = 0; i < batch_size; ++i) {
    indices[i] = index_dist(rng_);
  }
  // Build the batches on the CPU first (one accessor write per scalar), then
  // move each to the GPU with a single transfer — far cheaper than per-element
  // writes into CUDA tensors.
  torch::Tensor states = torch::empty({batch_size, state_dim_}, torch::kFloat32);
  torch::Tensor actions = torch::empty({batch_size, action_dim_}, torch::kFloat32);
  torch::Tensor rewards = torch::empty({batch_size, 1}, torch::kFloat32);
  torch::Tensor next_states = torch::empty({batch_size, state_dim_}, torch::kFloat32);
  torch::Tensor dones = torch::empty({batch_size, 1}, torch::kFloat32);
  auto s_acc = states.accessor<float, 2>();
  auto a_acc = actions.accessor<float, 2>();
  auto r_acc = rewards.accessor<float, 2>();
  auto ns_acc = next_states.accessor<float, 2>();
  auto d_acc = dones.accessor<float, 2>();
  for (int i = 0; i < batch_size; ++i) {
    const Transition& t = buffer_[indices[i]];
    for (int j = 0; j < state_dim_; ++j) {
      s_acc[i][j] = static_cast<float>(t.state(j));
      ns_acc[i][j] = static_cast<float>(t.next_state(j));
    }
    for (int j = 0; j < action_dim_; ++j) {
      a_acc[i][j] = static_cast<float>(t.action(j));
    }
    r_acc[i][0] = static_cast<float>(t.reward);
    d_acc[i][0] = t.done ? 1.0f : 0.0f;
  }
  return {states.to(torch::kCUDA), actions.to(torch::kCUDA), rewards.to(torch::kCUDA),
          next_states.to(torch::kCUDA), dones.to(torch::kCUDA)};
}

}  // namespace drone_rl_cpp
(7)TD3 智能体(include/drone_rl_cpp/TD3Agent.hpp)
整合网络、缓冲区、训练逻辑:
#ifndef TD3_AGENT_HPP_
#define TD3_AGENT_HPP_
#include "drone_rl_cpp/networks/TD3Networks.hpp"
#include "drone_rl_cpp/utils/ReplayBuffer.hpp"
#include <torch/torch.h>
#include <Eigen/Dense>
#include <memory>
#include <random>
#include <string>
#include <vector>
namespace drone_rl_cpp {

/// TD3 agent: owns the actor/critic networks, their target copies, the
/// optimizers, and the replay buffer; implements action selection with
/// exploration noise, the TD3 update rule, and model (de)serialization.
class TD3Agent {
public:
  TD3Agent(int state_dim, int action_dim, double action_low, double action_high,
           int buffer_capacity = 1000000, int batch_size = 256, double gamma = 0.99,
           double tau = 0.005, double lr_actor = 3e-4, double lr_critic = 3e-4,
           double policy_noise = 0.1, double noise_clip = 0.2, int policy_freq = 2);
  ~TD3Agent() = default;
  /// Select an action for the given state; Gaussian exploration noise is
  /// added (and clipped) when is_training is true.
  Eigen::VectorXd select_action(const Eigen::VectorXd& state, bool is_training = true);
  /// Run one TD3 update step (critic every call, actor/targets every
  /// policy_freq calls); returns the critic loss.
  double train();
  /// Save actor/critics/optimizers under the given directory.
  void save(const std::string& path);
  /// Load actor/critics/optimizers and re-sync the target networks.
  void load(const std::string& path);
  /// Access the replay buffer (the training node pushes transitions here).
  ReplayBuffer& replay_buffer() { return replay_buffer_; }
private:
  // Members are declared in constructor-initialization order so the
  // mem-initializer list matches (avoids -Wreorder and subtle init bugs).
  int state_dim_;
  int action_dim_;
  double action_low_;
  double action_high_;
  int batch_size_;
  double gamma_;        // discount factor
  double tau_;          // soft-update rate for target networks
  double policy_noise_; // stddev of exploration / target-smoothing noise
  double noise_clip_;   // clip bound for the noise
  int policy_freq_;     // delayed policy update period
  int update_count_;
  std::mt19937 rng_;
  std::normal_distribution<double> noise_dist_;
  ReplayBuffer replay_buffer_;
  std::shared_ptr<ActorNetwork> actor_;
  std::shared_ptr<ActorNetwork> target_actor_;
  std::shared_ptr<TwinCriticNetworks> critics_;
  std::shared_ptr<TwinCriticNetworks> target_critics_;
  torch::optim::Adam actor_optimizer_;    // depends on actor_, declared after it
  torch::optim::Adam critics_optimizer_;  // depends on critics_, declared after it
};

}  // namespace drone_rl_cpp
#endif  // TD3_AGENT_HPP_
(8)智能体实现(src/TD3Agent.cpp)
#include "drone_rl_cpp/TD3Agent.hpp"
#include <torch/torch.h>
#include <rclcpp/rclcpp.hpp>
#include <algorithm>
#include <vector>
namespace drone_rl_cpp {

namespace {
// Copy every source parameter into the corresponding target parameter
// (hard update; used at construction and after load()).
void hard_copy(const std::vector<torch::Tensor>& target,
               const std::vector<torch::Tensor>& source) {
  for (size_t i = 0; i < target.size(); ++i) {
    target[i].copy_(source[i]);
  }
}
}  // namespace

TD3Agent::TD3Agent(int state_dim, int action_dim, double action_low, double action_high,
                   int buffer_capacity, int batch_size, double gamma, double tau,
                   double lr_actor, double lr_critic, double policy_noise,
                   double noise_clip, int policy_freq)
    : state_dim_(state_dim), action_dim_(action_dim),
      action_low_(action_low), action_high_(action_high),
      batch_size_(batch_size), gamma_(gamma), tau_(tau),
      policy_noise_(policy_noise), noise_clip_(noise_clip),
      policy_freq_(policy_freq), update_count_(0),
      rng_(std::random_device{}()),
      noise_dist_(0.0, policy_noise),
      replay_buffer_(buffer_capacity, state_dim, action_dim),
      actor_(std::make_shared<ActorNetwork>(state_dim, action_dim, action_low, action_high)),
      target_actor_(std::make_shared<ActorNetwork>(state_dim, action_dim, action_low, action_high)),
      critics_(std::make_shared<TwinCriticNetworks>(state_dim, action_dim)),
      target_critics_(std::make_shared<TwinCriticNetworks>(state_dim, action_dim)),
      actor_optimizer_(actor_->parameters(), torch::optim::AdamOptions(lr_actor)),
      critics_optimizer_(critics_->parameters(), torch::optim::AdamOptions(lr_critic)) {
  actor_->to(torch::kCUDA);
  target_actor_->to(torch::kCUDA);
  critics_->to(torch::kCUDA);
  target_critics_->to(torch::kCUDA);
  // Hard-copy the online weights into the targets, then freeze the targets:
  // they are only ever updated via the soft update in train().
  torch::NoGradGuard no_grad;
  hard_copy(target_actor_->parameters(), actor_->parameters());
  hard_copy(target_critics_->parameters(), critics_->parameters());
  for (auto& param : target_actor_->parameters()) {
    param.set_requires_grad(false);
  }
  for (auto& param : target_critics_->parameters()) {
    param.set_requires_grad(false);
  }
}

Eigen::VectorXd TD3Agent::select_action(const Eigen::VectorXd& state, bool is_training) {
  torch::NoGradGuard no_grad;
  const torch::Tensor state_tensor = actor_->eigen_to_tensor(state);
  const torch::Tensor action_tensor = actor_->forward(state_tensor);
  Eigen::VectorXd action = actor_->tensor_to_eigen(action_tensor);
  if (is_training) {
    // Clipped Gaussian exploration noise per action dimension.
    Eigen::VectorXd noise(action_dim_);
    for (int i = 0; i < action_dim_; ++i) {
      noise(i) = std::clamp(noise_dist_(rng_), -noise_clip_, noise_clip_);
    }
    action += noise;
  }
  // Re-clip to the valid motor range (the original computed the clip but
  // discarded the result — Eigen cwiseMax/cwiseMin do not mutate in place).
  return action.cwiseMax(action_low_).cwiseMin(action_high_);
}

double TD3Agent::train() {
  auto [states, actions, rewards, next_states, dones] =
      replay_buffer_.sample_batch(batch_size_);
  // ---- Critic update ----
  // Target Q computed without gradients (target networks are frozen anyway,
  // but NoGradGuard avoids building an unnecessary graph).
  torch::Tensor target_q_values;
  {
    torch::NoGradGuard no_grad;
    torch::Tensor target_actions = target_actor_->forward(next_states);
    // Target policy smoothing: clipped noise on the target action.
    const torch::Tensor noise =
        (torch::randn_like(target_actions) * policy_noise_).clamp(-noise_clip_, noise_clip_);
    target_actions = (target_actions + noise).clamp(action_low_, action_high_);
    auto [target_q1, target_q2] = target_critics_->forward(next_states, target_actions);
    const torch::Tensor target_q = torch::min(target_q1, target_q2);  // clipped double-Q
    target_q_values = rewards + (1.0 - dones) * gamma_ * target_q;
  }
  critics_optimizer_.zero_grad();
  auto [q1, q2] = critics_->forward(states, actions);
  const torch::Tensor critic_loss =
      torch::mse_loss(q1, target_q_values) + torch::mse_loss(q2, target_q_values);
  critic_loss.backward();
  critics_optimizer_.step();
  // ---- Delayed policy update + soft target update ----
  if (update_count_ % policy_freq_ == 0) {
    actor_optimizer_.zero_grad();
    const torch::Tensor actor_actions = actor_->forward(states);
    const torch::Tensor actor_loss =
        -critics_->critic1()->forward(states, actor_actions).mean();
    actor_loss.backward();
    actor_optimizer_.step();
    // Soft update: target ← tau * online + (1 - tau) * target.
    // (C++17 has no std::zip — iterate the parameter vectors by index.)
    torch::NoGradGuard no_grad;
    auto soft_update = [this](const std::vector<torch::Tensor>& target,
                              const std::vector<torch::Tensor>& source) {
      for (size_t i = 0; i < target.size(); ++i) {
        target[i].mul_(1.0 - tau_).add_(source[i], tau_);
      }
    };
    soft_update(target_actor_->parameters(), actor_->parameters());
    soft_update(target_critics_->parameters(), critics_->parameters());
  }
  update_count_++;
  return critic_loss.item<double>();
}

void TD3Agent::save(const std::string& path) {
  torch::save(actor_, path + "/actor.pt");
  torch::save(critics_, path + "/critics.pt");
  torch::save(actor_optimizer_, path + "/actor_optim.pt");
  torch::save(critics_optimizer_, path + "/critics_optim.pt");
  RCLCPP_INFO(rclcpp::get_logger("td3_agent"), "Model saved to %s", path.c_str());
}

void TD3Agent::load(const std::string& path) {
  torch::load(actor_, path + "/actor.pt");
  torch::load(critics_, path + "/critics.pt");
  torch::load(actor_optimizer_, path + "/actor_optim.pt");
  torch::load(critics_optimizer_, path + "/critics_optim.pt");
  actor_->to(torch::kCUDA);
  critics_->to(torch::kCUDA);
  target_actor_->to(torch::kCUDA);
  target_critics_->to(torch::kCUDA);
  // Re-sync the frozen targets with the freshly loaded weights.
  torch::NoGradGuard no_grad;
  hard_copy(target_actor_->parameters(), actor_->parameters());
  hard_copy(target_critics_->parameters(), critics_->parameters());
  RCLCPP_INFO(rclcpp::get_logger("td3_agent"), "Model loaded from %s", path.c_str());
}

}  // namespace drone_rl_cpp
(9)TensorRT 推理封装(include/drone_rl_cpp/utils/TrtInfer.hpp)
#ifndef TRT_INFER_HPP_
#define TRT_INFER_HPP_
// JetPack installs the TensorRT headers at the top include level, not under
// a tensorrt/ subdirectory.
#include <NvInfer.h>
#include <cuda_runtime_api.h>
#include <Eigen/Dense>
#include <string>
#include <memory>
#include <vector>
namespace drone_rl_cpp {

/// Wraps a serialized TensorRT engine: loads it once, owns the device
/// buffers, and runs synchronous single-sample inference.
class TrtInfer {
public:
  /// Deserializes the engine file; throws std::runtime_error on any failure.
  explicit TrtInfer(const std::string& engine_path);
  ~TrtInfer();
  TrtInfer(const TrtInfer&) = delete;             // owns raw CUDA buffers
  TrtInfer& operator=(const TrtInfer&) = delete;
  /// Run inference: 22-dim state in, 4-dim action out.
  Eigen::VectorXd infer(const Eigen::VectorXd& input);
private:
  // Since TensorRT 8, objects are released with plain `delete`;
  // the destroy() methods used previously are deprecated.
  struct TrtDeleter {
    template <typename T>
    void operator()(T* obj) const { delete obj; }
  };
  std::unique_ptr<nvinfer1::IRuntime, TrtDeleter> runtime_;
  std::unique_ptr<nvinfer1::ICudaEngine, TrtDeleter> engine_;
  std::unique_ptr<nvinfer1::IExecutionContext, TrtDeleter> context_;
  void* d_input_ = nullptr;   // device-side input buffer (FP32)
  void* d_output_ = nullptr;  // device-side output buffer (FP32)
  static constexpr int INPUT_DIM = 22;
  static constexpr int OUTPUT_DIM = 4;
  size_t input_size_ = 0;     // input buffer size in bytes
  size_t output_size_ = 0;    // output buffer size in bytes
};

}  // namespace drone_rl_cpp
#endif  // TRT_INFER_HPP_
(10)TensorRT 推理实现(src/utils/TrtInfer.cpp)
#include "drone_rl_cpp/utils/TrtInfer.hpp"
#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>
#include <rclcpp/rclcpp.hpp>
using namespace nvinfer1;
namespace drone_rl_cpp {

namespace {
// TensorRT has no concrete Logger type — the caller must implement ILogger.
// This one forwards warnings and errors to the ROS2 logger.
class TrtLogger : public ILogger {
  void log(Severity severity, const char* msg) noexcept override {
    if (severity <= Severity::kWARNING) {
      RCLCPP_WARN(rclcpp::get_logger("trt_infer"), "[TRT] %s", msg);
    }
  }
};
TrtLogger g_trt_logger;

// Turn CUDA errors into exceptions so failures surface at construction/call.
void cuda_check(cudaError_t err, const char* what) {
  if (err != cudaSuccess) {
    throw std::runtime_error(std::string(what) + ": " + cudaGetErrorString(err));
  }
}
}  // namespace

TrtInfer::TrtInfer(const std::string& engine_path) {
  // 1. Read the serialized engine file.
  std::ifstream engine_file(engine_path, std::ios::binary);
  if (!engine_file) {
    throw std::runtime_error("Failed to open TRT engine file: " + engine_path);
  }
  engine_file.seekg(0, std::ios::end);
  const size_t engine_size = engine_file.tellg();
  engine_file.seekg(0, std::ios::beg);
  std::vector<char> engine_data(engine_size);
  engine_file.read(engine_data.data(), engine_size);
  // 2. Create the TensorRT runtime.
  runtime_.reset(createInferRuntime(g_trt_logger));
  if (!runtime_) {
    throw std::runtime_error("Failed to create TRT runtime");
  }
  // 3. Deserialize the engine (TRT 8 uses the two-argument overload) and
  //    create the execution context.
  engine_.reset(runtime_->deserializeCudaEngine(engine_data.data(), engine_size));
  if (!engine_) {
    throw std::runtime_error("Failed to deserialize TRT engine");
  }
  context_.reset(engine_->createExecutionContext());
  if (!context_) {
    throw std::runtime_error("Failed to create TRT execution context");
  }
  // 4. Allocate the device I/O buffers (engine I/O is FP32).
  input_size_ = INPUT_DIM * sizeof(float);
  output_size_ = OUTPUT_DIM * sizeof(float);
  cuda_check(cudaMalloc(&d_input_, input_size_), "cudaMalloc(input)");
  cuda_check(cudaMalloc(&d_output_, output_size_), "cudaMalloc(output)");
  RCLCPP_INFO(rclcpp::get_logger("trt_infer"),
              "TRT engine loaded: input_dim=%d output_dim=%d", INPUT_DIM, OUTPUT_DIM);
}

TrtInfer::~TrtInfer() {
  cudaFree(d_input_);
  cudaFree(d_output_);
}

Eigen::VectorXd TrtInfer::infer(const Eigen::VectorXd& input) {
  if (input.size() != INPUT_DIM) {
    throw std::runtime_error("TrtInfer expects input dim " + std::to_string(INPUT_DIM) +
                             ", got " + std::to_string(input.size()));
  }
  // Convert double → float for the engine.
  float input_host[INPUT_DIM];
  for (int i = 0; i < INPUT_DIM; ++i) {
    input_host[i] = static_cast<float>(input(i));
  }
  cuda_check(cudaMemcpy(d_input_, input_host, input_size_, cudaMemcpyHostToDevice),
             "cudaMemcpy H2D");
  // NOTE(review): assumes binding 0 = input, 1 = output — true for the
  // single-input/single-output actor graph; verify against the engine.
  void* bindings[] = {d_input_, d_output_};
  context_->executeV2(bindings);
  float output_host[OUTPUT_DIM];
  cuda_check(cudaMemcpy(output_host, d_output_, output_size_, cudaMemcpyDeviceToHost),
             "cudaMemcpy D2H");
  Eigen::VectorXd output(OUTPUT_DIM);
  for (int i = 0; i < OUTPUT_DIM; ++i) {
    output(i) = static_cast<double>(output_host[i]);
  }
  return output;
}

}  // namespace drone_rl_cpp
3. ROS2 训练节点(src/train_node.cpp)
#include "drone_rl_cpp/env/DroneEnv.hpp"
#include "drone_rl_cpp/TD3Agent.hpp"
#include <rclcpp/rclcpp.hpp>
#include <chrono>
#include <random>
#include <string>
using namespace drone_rl_cpp;
using namespace std::chrono;

/// TD3 training node: random exploration for the first start_train_steps
/// steps, then policy actions with noise + one gradient step per env step.
int main(int argc, char* argv[]) {
  rclcpp::init(argc, argv);
  auto node = std::make_shared<rclcpp::Node>("td3_train_node");
  // Environment must have live sensor data before training can start.
  auto env = std::make_shared<DroneEnv>();
  if (!env->is_ready()) {
    RCLCPP_FATAL(node->get_logger(), "Environment not ready, exit");
    rclcpp::shutdown();
    return -1;
  }
  TD3Agent agent(
      DroneEnv::STATE_DIM, DroneEnv::ACTION_DIM, DroneEnv::ACTION_LOW, DroneEnv::ACTION_HIGH,
      1000000,  // replay buffer capacity
      256,      // batch size
      0.99,     // gamma
      0.005,    // tau
      3e-4,     // actor learning rate
      3e-4,     // critic learning rate
      0.1,      // policy noise
      0.2,      // noise clip
      2);       // delayed policy update frequency
  const int total_episodes = 1000;
  const int max_steps_per_episode = 500;
  const int start_train_steps = 10000;  // pure random exploration before training
  long total_steps = 0;
  // Local RNG for the warm-up random actions (the agent's RNG is private).
  std::mt19937 rng{std::random_device{}()};
  std::uniform_real_distribution<double> action_dist(DroneEnv::ACTION_LOW,
                                                     DroneEnv::ACTION_HIGH);
  RCLCPP_INFO(node->get_logger(), "Start TD3 training: %d episodes", total_episodes);
  for (int ep = 0; ep < total_episodes && rclcpp::ok(); ++ep) {
    Eigen::VectorXd state = env->reset();
    double ep_reward = 0.0;
    bool ep_done = false;
    std::string info = "running";
    int steps_in_episode = 0;  // declared outside the loop so the summary log can use it
    const auto ep_start = high_resolution_clock::now();
    for (int step = 0; step < max_steps_per_episode && !ep_done && rclcpp::ok(); ++step) {
      Eigen::VectorXd action(DroneEnv::ACTION_DIM);
      if (total_steps < start_train_steps) {
        // Warm-up: uniformly random motor commands to fill the buffer.
        for (int i = 0; i < DroneEnv::ACTION_DIM; ++i) {
          action(i) = action_dist(rng);
        }
      } else {
        action = agent.select_action(state, /*is_training=*/true);
      }
      const StepResult result = env->step(action);
      ep_reward += result.reward;
      ep_done = result.done;
      info = result.info;
      agent.replay_buffer().push({state, action, result.reward, result.state, result.done});
      if (total_steps >= start_train_steps) {
        const double loss = agent.train();
        if (step % 100 == 0) {
          RCLCPP_INFO(node->get_logger(), "ep=%d step=%d critic_loss=%.4f", ep, step, loss);
        }
      }
      state = result.state;
      ++total_steps;
      steps_in_episode = step + 1;
    }
    const double ep_duration =
        duration_cast<duration<double>>(high_resolution_clock::now() - ep_start).count();
    RCLCPP_INFO(node->get_logger(),
                "Episode %d/%d reward=%.2f steps=%d time=%.1fs info=%s total_steps=%ld",
                ep + 1, total_episodes, ep_reward, steps_in_episode, ep_duration,
                info.c_str(), total_steps);
    if ((ep + 1) % 50 == 0) {
      agent.save("./models/td3_ep" + std::to_string(ep + 1));
    }
  }
  agent.save("./models/td3_final");
  RCLCPP_INFO(node->get_logger(), "Training finished");
  rclcpp::shutdown();
  return 0;
}
4. ROS2 推理控制节点(src/infer_node.cpp)
#include "drone_rl_cpp/env/DroneEnv.hpp"
#include "drone_rl_cpp/utils/TrtInfer.hpp"
#include <rclcpp/rclcpp.hpp>
#include <chrono>
#include <cerrno>
#include <cstring>
#include <memory>
#include <string>
#include <sched.h>
using namespace drone_rl_cpp;
using namespace std::chrono;

/// TensorRT inference control node: runs the exported actor engine in a
/// closed loop against the environment and reports latency statistics.
int main(int argc, char* argv[]) {
  rclcpp::init(argc, argv);
  auto node = std::make_shared<rclcpp::Node>("td3_infer_node");
  // Real-time FIFO priority keeps control latency bounded on the Orin NX.
  struct sched_param param;
  param.sched_priority = 99;
  if (sched_setscheduler(0, SCHED_FIFO, &param) == -1) {
    RCLCPP_WARN(node->get_logger(), "Failed to set real-time priority: %s",
                strerror(errno));
  }
  const std::string engine_path = node->declare_parameter<std::string>(
      "engine_path", "./models/td3_final_fp16.engine");
  auto env = std::make_shared<DroneEnv>();
  if (!env->is_ready()) {
    RCLCPP_FATAL(node->get_logger(), "Environment not ready, exit");
    rclcpp::shutdown();
    return -1;
  }
  std::unique_ptr<TrtInfer> trt_infer;
  try {
    trt_infer = std::make_unique<TrtInfer>(engine_path);
  } catch (const std::exception& e) {
    RCLCPP_FATAL(node->get_logger(), "Failed to init TensorRT: %s", e.what());
    rclcpp::shutdown();
    return -1;
  }
  long total_steps = 0;
  double total_delay = 0.0;     // accumulated per-step latency (ms)
  const int stat_window = 100;  // report statistics every N steps
  RCLCPP_INFO(node->get_logger(), "Start TD3 TensorRT inference control loop");
  Eigen::VectorXd state = env->reset();
  while (rclcpp::ok()) {
    const auto start = high_resolution_clock::now();
    Eigen::VectorXd action = trt_infer->infer(state);
    // Defensive re-clip: the engine should already be range-limited by tanh.
    action = action.cwiseMax(DroneEnv::ACTION_LOW).cwiseMin(DroneEnv::ACTION_HIGH);
    const StepResult result = env->step(action);
    const auto end = high_resolution_clock::now();
    const double delay =
        duration_cast<duration<double, std::milli>>(end - start).count();
    total_delay += delay;
    ++total_steps;
    state = result.state;
    if (total_steps % stat_window == 0) {
      const double avg_delay = total_delay / stat_window;
      RCLCPP_INFO(node->get_logger(),
                  "steps=%ld avg_delay=%.2fms freq=%.1fHz reward=%.2f info=%s",
                  total_steps, avg_delay, 1000.0 / avg_delay, result.reward,
                  result.info.c_str());
      total_delay = 0.0;
    }
    if (result.done) {
      RCLCPP_INFO(node->get_logger(), "Episode finished, resetting environment");
      state = env->reset();
    }
  }
  rclcpp::shutdown();
  return 0;
}
四、模型转换(LibTorch→ONNX→TensorRT)
1. LibTorch 模型导出为 ONNX(C++ 代码)
创建 src/utils/export_onnx.cpp:
#include "drone_rl_cpp/networks/TD3Networks.hpp"
#include <torch/torch.h>
#include <iostream>
#include <string>

/// Export the trained LibTorch actor to ONNX for TensorRT conversion.
/// Usage: export_onnx <libtorch_model_path> <output_onnx_path>
int main(int argc, char* argv[]) {
  if (argc != 3) {
    std::cerr << "Usage: " << argv[0] << " <libtorch_model_path> <output_onnx_path>" << std::endl;
    return -1;
  }
  const std::string model_path = argv[1];
  const std::string onnx_path = argv[2];
  // Rebuild the actor with the same architecture used during training.
  const int state_dim = 22;
  const int action_dim = 4;
  const double action_low = 500.0;
  const double action_high = 2000.0;
  auto actor = std::make_shared<drone_rl_cpp::ActorNetwork>(state_dim, action_dim,
                                                            action_low, action_high);
  actor->to(torch::kCUDA);
  // Load the trained weights.
  torch::load(actor, model_path + "/actor.pt");
  actor->eval();  // inference mode
  // Dummy input fixing the exported graph's input shape (batch_size = 1).
  torch::Tensor dummy_input = torch::randn({1, state_dim}, torch::kFloat32).to(torch::kCUDA);
  // NOTE(review): torch::onnx::export_to_onnx is NOT part of the stable
  // public LibTorch C++ API — confirm your LibTorch build provides it.
  // The officially supported route is exporting via Python's
  // torch.onnx.export on a traced/scripted module; keep this call only if
  // your build ships the C++ exporter.
  torch::onnx::export_to_onnx(*actor, dummy_input, onnx_path, torch::onnx::ExportConfig(),
                              {torch::onnx::OperatorExportTypes::ONNX_ATEN_FALLBACK});
  std::cout << "ONNX model exported to " << onnx_path << std::endl;
  return 0;
}
2. ONNX 转换为 TensorRT 引擎(C++ 代码)
创建 src/utils/convert_trt.cpp:
// JetPack installs TensorRT headers at the top include level.
#include <NvInfer.h>
#include <NvOnnxParser.h>
#include <cuda_runtime_api.h>
#include <fstream>
#include <iostream>
#include <string>
using namespace nvinfer1;
using namespace nvonnxparser;

namespace {
// TensorRT requires a user-supplied ILogger implementation — there is no
// concrete Logger type in the library.
class Logger : public ILogger {
 public:
  void log(Severity severity, const char* msg) noexcept override {
    if (severity <= Severity::kWARNING) {
      std::cerr << "[TRT] " << msg << std::endl;
    }
  }
};
}  // namespace

/// Convert an ONNX model to a serialized FP16 TensorRT engine.
/// Usage: convert_trt <onnx_path> <output_engine_path>
int main(int argc, char* argv[]) {
  if (argc != 3) {
    std::cerr << "Usage: " << argv[0] << " <onnx_path> <output_engine_path>" << std::endl;
    return -1;
  }
  const std::string onnx_path = argv[1];
  const std::string engine_path = argv[2];
  Logger logger;
  // 1. Builder + explicit-batch network.
  IBuilder* builder = createInferBuilder(logger);
  INetworkDefinition* network = builder->createNetworkV2(
      1U << static_cast<int>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));
  // 2. Parse the ONNX model.
  IParser* parser = createParser(*network, logger);
  if (!parser->parseFromFile(onnx_path.c_str(),
                             static_cast<int>(ILogger::Severity::kWARNING))) {
    std::cerr << "Failed to parse ONNX model: " << onnx_path << std::endl;
    return -1;
  }
  // 3. Build configuration: 1 GiB workspace, FP16 (fast on Orin NX).
  IBuilderConfig* config = builder->createBuilderConfig();
  config->setMemoryPoolLimit(MemoryPoolType::kWORKSPACE, 1ULL << 30);
  config->setFlag(BuilderFlag::kFP16);
  // 4. Build directly to a serialized engine (TRT 8.x recommended path;
  //    avoids constructing an ICudaEngine just to serialize it).
  IHostMemory* serialized_engine = builder->buildSerializedNetwork(*network, *config);
  if (!serialized_engine) {
    std::cerr << "Failed to build TensorRT engine" << std::endl;
    return -1;
  }
  std::ofstream engine_file(engine_path, std::ios::binary);
  engine_file.write(static_cast<const char*>(serialized_engine->data()),
                    serialized_engine->size());
  // TRT >= 8: release objects with delete (destroy() is deprecated).
  delete serialized_engine;
  delete config;
  delete parser;
  delete network;
  delete builder;
  std::cout << "TensorRT engine written to " << engine_path << std::endl;
  return 0;
}
五、CMakeLists.txt 配置
cmake_minimum_required(VERSION 3.20)
project(drone_rl_cpp)
if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
add_compile_options(-Wall -Wextra -Wpedantic -O3 -std=c++17)
endif()
# 查找依赖包
find_package(ament_cmake REQUIRED)
find_package(rclcpp REQUIRED)
find_package(sensor_msgs REQUIRED)
find_package(geometry_msgs REQUIRED)
find_package(std_msgs REQUIRED)
find_package(gazebo_ros2_control REQUIRED)
find_package(ros2_control REQUIRED)
find_package(Eigen3 REQUIRED)
find_package(OpenCV REQUIRED)
find_package(PCL REQUIRED COMPONENTS common io)
#LibTorch 依赖(自动查找)
find_package(Torch REQUIRED)
message(STATUS "LibTorch found: ${Torch_FOUND}, Version: ${Torch_VERSION}")
#TensorRT 和 CUDA 依赖
find_package(CUDAToolkit REQUIRED)
find_library(NVINFER_LIB nvinfer HINTS /usr/lib/aarch64-linux-gnu/)
find_library(NVONNXPARSER_LIB nvonnxparser HINTS /usr/lib/aarch64-linux-gnu/)
message(STATUS "TensorRT libs: ${NVINFER_LIB}, ${NVONNXPARSER_LIB}")
# 包含目录
include_directories(
include
${EIGEN3_INCLUDE_DIRS}
${OpenCV_INCLUDE_DIRS}
以上为基于强化学习的无人机端到端飞行控制算法开发流程示例。


