极客日志 · C++ · AI · Algorithms
Developing an End-to-End UAV Flight Control Algorithm Based on Reinforcement Learning
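Summary (AI-generated): This article presents an end-to-end UAV flight control development workflow based on reinforcement learning (RL), built on a pure C++ stack. ROS2 Humble handles sensor communication and node management, LibTorch implements the TD3 algorithm for network training, and TensorRT accelerates the deployed model. The article covers environment setup (Orin NX), the core module design (ROS2 environment wrapper, TD3 networks, experience replay buffer), the training and inference nodes, and the model conversion pipeline (LibTorch → ONNX → TensorRT). The approach targets two weaknesses of traditional control algorithms, difficult modeling and weak disturbance rejection, by learning a direct mapping from sensor inputs to control outputs, and is intended for autonomous UAV flight in complex environments.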
Author: 禅心
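There are several directions for combining AI with flight control:
- AI-enhanced traditional control: use AI to address the weaknesses of traditional control (difficult modeling, weak disturbance rejection) while keeping the stability of classical controllers such as PID and MPC;
- Reinforcement learning (RL) end-to-end control: no system model is needed; an agent is trained with RL to map sensor inputs directly to control outputs, which suits complex settings such as dynamic obstacle avoidance and multi-drone cooperation;
- Integrated perception and control: skip separate perception modules (e.g. object detection, obstacle segmentation), feed raw camera/LiDAR data directly into the AI model, and output control commands, reducing inter-module latency.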
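Approach 1 is the most widely used, typically with a neural network compensating a PID controller. Approach 2 has gained traction over the last couple of years; the University of Zurich has applied it with notable success and published the work in a top-tier journal.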

The rest of this article describes how to develop such an RL-based end-to-end UAV flight controller.
I. Core Technology Stack (Pure C++)
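| Module | Tool / Framework | Role |
|---|---|---|
| Robotics framework | ROS2 Humble (C++ API) | Sensor/actuator communication, node management, real-time scheduling |
| RL framework | LibTorch 2.1.0 (CUDA 11.4) | TD3 implementation, network training, model export |
| Model acceleration | TensorRT 8.5 (C++ API) | Model quantization, inference engine building, GPU/DLA acceleration |
| Simulation | Gazebo 11 + ros2_control (C++ interface) | UAV dynamics simulation, sensor simulation, obstacle-avoidance scenarios |
| Data processing | Eigen 3.4, OpenCV 4.5, PCL 1.10 | State vector construction, point cloud/image preprocessing |
| Build | CMake 3.20+, ament_cmake | Cross-platform builds, dependency management, compiler optimization flags |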
II. Environment Setup (Orin NX Specific)
1. System and Dependency Installation
(1) Base dependencies
sudo apt install ros-humble-ros2-control ros-humble-ros2-controllers ros-humble-gazebo-ros2-control
sudo apt install libeigen3-dev libopencv-dev libpcl-dev cmake gcc-9 g++-9
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-9 50
sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-9 50
(2) Installing LibTorch (the C++ distribution of PyTorch)
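The Orin NX is an ARM64 (aarch64) platform, so LibTorch must match both the CPU architecture and the CUDA version of the installed JetPack. Note that the prebuilt archives on download.pytorch.org are x86_64 builds; on the Orin NX the usual options are the LibTorch libraries and CMake files bundled with NVIDIA's PyTorch wheel for JetPack, or a build from source. With a suitable archive, installation follows this pattern: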
wget https://download.pytorch.org/libtorch/cu114/libtorch-cxx11-abi-shared-with-deps-2.1.0%2Bcu114.zip
unzip libtorch-cxx11-abi-shared-with-deps-2.1.0+cu114.zip -d /opt/
echo "export Torch_DIR=/opt/libtorch/share/cmake/Torch" >> ~/.bashrc
source ~/.bashrc
(3) TensorRT dependencies (preinstalled with JetPack; just verify)
dpkg -l | grep TensorRT
echo "export LD_LIBRARY_PATH=/usr/lib/aarch64-linux-gnu/:$LD_LIBRARY_PATH" >> ~/.bashrc
source ~/.bashrc
2. Creating the ROS2 package
cd ~/ros2_ws/src
ros2 pkg create drone_rl_cpp --build-type ament_cmake --dependencies \
rclcpp sensor_msgs geometry_msgs std_msgs gazebo_ros2_control ros2_control \
Eigen3 opencv4 pcl_common pcl_io
cd drone_rl_cpp
mkdir -p include/drone_rl_cpp/{env,networks,utils} src/{env,networks,utils} config launch models
III. Core Module Design (Pure C++ Implementation)
1. Module layout
drone_rl_cpp/
├── include/drone_rl_cpp/
│ ├── env/DroneEnv.hpp
│ ├── networks/TD3Networks.hpp
│ ├── utils/ReplayBuffer.hpp
│ ├── utils/TrtInfer.hpp
│ └── TD3Agent.hpp
├── src/
│ ├── env/DroneEnv.cpp
│ ├── networks/TD3Networks.cpp
│ ├── utils/ReplayBuffer.cpp
│ ├── utils/TrtInfer.cpp
│ ├── TD3Agent.cpp
│ ├── train_node.cpp
│ └── infer_node.cpp
├── config/
├── launch/
└── models/
2. Core module implementation
(1) ROS2 environment wrapper (include/drone_rl_cpp/env/DroneEnv.hpp)
Wraps sensor subscriptions, action publishing, state construction, and reward computation, taking the place of a Python Gym environment:
#ifndef DRONE_ENV_HPP_
#define DRONE_ENV_HPP_
#include <rclcpp/rclcpp.hpp>
#include <sensor_msgs/msg/imu.hpp>
#include <sensor_msgs/msg/laser_scan.hpp>
#include <geometry_msgs/msg/pose_stamped.hpp>
#include <geometry_msgs/msg/twist_stamped.hpp>
#include <std_msgs/msg/float64_multi_array.hpp>
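#include <Eigen/Dense>
#include <vector>
#include <mutex>
#include <atomic>
#include <memory>
#include <string>
#include <tuple>
// Eigen predefines fixed-size vectors only up to size 4, so declare the
// 5-element alias used for the LiDAR features below
namespace Eigen { using Vector5d = Matrix<double, 5, 1>; }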
namespace drone_rl_cpp {
class DroneEnv : public rclcpp::Node {
public:
using Ptr = std::shared_ptr<DroneEnv>;
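static constexpr int STATE_DIM = 22;          // pos(3) + vel(3) + IMU(6) + target offset(3) + LiDAR(5) + xy/z error(2)
static constexpr int ACTION_DIM = 4;          // one command per motor
static constexpr double ACTION_LOW = 500.0;   // motor command limits
static constexpr double ACTION_HIGH = 2000.0;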
DroneEnv(const std::string& node_name = "drone_env_node");
~DroneEnv() = default;
Eigen::VectorXd reset();
std::tuple<Eigen::VectorXd, double, bool, std::string> step(const Eigen::VectorXd& action);
bool is_ready() const { return is_ready_.load(); }
private:
void imu_callback(const sensor_msgs::msg::Imu::SharedPtr msg);
void gps_pose_callback(const geometry_msgs::msg::PoseStamped::SharedPtr msg);
void gps_twist_callback(const geometry_msgs::msg::TwistStamped::SharedPtr msg);
void lidar_callback(const sensor_msgs::msg::LaserScan::SharedPtr msg);
Eigen::VectorXd build_state();
double compute_reward(const Eigen::VectorXd& state);
bool check_done(const Eigen::VectorXd& state);
rclcpp::Subscription<sensor_msgs::msg::Imu>::SharedPtr imu_sub_;
rclcpp::Subscription<geometry_msgs::msg::PoseStamped>::SharedPtr gps_pose_sub_;
rclcpp::Subscription<geometry_msgs::msg::TwistStamped>::SharedPtr gps_twist_sub_;
rclcpp::Subscription<sensor_msgs::msg::LaserScan>::SharedPtr lidar_sub_;
rclcpp::Publisher<std_msgs::msg::Float64MultiArray>::SharedPtr motor_pub_;
std::mutex data_mutex_;
Eigen::VectorXd imu_data_;
Eigen::Vector3d gps_pose_;
Eigen::Vector3d gps_twist_;
Eigen::Vector5d lidar_data_;
std::atomic<bool> is_ready_;
std::vector<Eigen::Vector3d> target_points_;
int current_target_idx_;
std::atomic<bool> collision_;
};
}
#endif
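The header fixes STATE_DIM at 22, but the index layout is only implicit in build_state(). As a hedged, dependency-free sketch (plain std::array instead of Eigen; the state_layout names and this build_state signature are hypothetical), the offsets used by build_state() can be spelled out and unit-tested:

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>

// Index layout of the 22-dim observation built by DroneEnv::build_state().
// The names are illustrative; only the offsets mirror the original code.
namespace state_layout {
constexpr std::size_t kPos = 0;         // [0,3)   position (x, y, z)
constexpr std::size_t kVel = 3;         // [3,6)   linear velocity
constexpr std::size_t kImu = 6;         // [6,12)  gyro (3) + accelerometer (3)
constexpr std::size_t kTargetRel = 12;  // [12,15) target - position
constexpr std::size_t kLidar = 15;      // [15,20) 4 sector minima + global minimum
constexpr std::size_t kErrXY = 20;      // horizontal distance to target
constexpr std::size_t kErrZ = 21;       // absolute altitude error
constexpr std::size_t kDim = 22;
}  // namespace state_layout

using State = std::array<double, state_layout::kDim>;
using Vec3 = std::array<double, 3>;

// Plain-array version of build_state(), testable without ROS2 or Eigen.
State build_state(const Vec3& pos, const Vec3& vel, const std::array<double, 6>& imu,
                  const Vec3& target, const std::array<double, 5>& lidar) {
  using namespace state_layout;
  State s{};
  for (std::size_t i = 0; i < 3; ++i) {
    s[kPos + i] = pos[i];
    s[kVel + i] = vel[i];
    s[kTargetRel + i] = target[i] - pos[i];
  }
  for (std::size_t i = 0; i < 6; ++i) s[kImu + i] = imu[i];
  for (std::size_t i = 0; i < 5; ++i) s[kLidar + i] = lidar[i];
  s[kErrXY] = std::hypot(target[0] - pos[0], target[1] - pos[1]);
  s[kErrZ] = std::abs(target[2] - pos[2]);
  return s;
}
```

Keeping the offsets in one place makes it harder for the reward, the termination check, and the deployed network to silently disagree about which index means what.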
(2) Environment implementation (src/env/DroneEnv.cpp)
#include "drone_rl_cpp/env/DroneEnv.hpp"
#include <cmath>
#include <algorithm>
#include <chrono>
#include <thread>
namespace drone_rl_cpp {
DroneEnv::DroneEnv(const std::string& node_name) : Node(node_name),
imu_data_(Eigen::VectorXd::Zero(6)),
gps_pose_(Eigen::Vector3d::Zero()),
gps_twist_(Eigen::Vector3d::Zero()),
lidar_data_(Eigen::Vector5d::Ones()*10.0),
is_ready_(false),
current_target_idx_(0),
collision_(false) {
target_points_ = {
Eigen::Vector3d(2.0, 0.0, 1.0),
Eigen::Vector3d(2.0, 2.0, 1.0),
Eigen::Vector3d(0.0, 2.0, 1.0),
Eigen::Vector3d(0.0, 0.0, 1.0)
};
imu_sub_ = this->create_subscription<sensor_msgs::msg::Imu>("/drone/imu", 10,
std::bind(&DroneEnv::imu_callback, this, std::placeholders::_1));
gps_pose_sub_ = this->create_subscription<geometry_msgs::msg::PoseStamped>("/drone/gps/pose", 10,
std::bind(&DroneEnv::gps_pose_callback, this, std::placeholders::_1));
gps_twist_sub_ = this->create_subscription<geometry_msgs::msg::TwistStamped>("/drone/gps/twist", 10,
std::bind(&DroneEnv::gps_twist_callback, this, std::placeholders::_1));
lidar_sub_ = this->create_subscription<sensor_msgs::msg::LaserScan>("/drone/lidar", 10,
std::bind(&DroneEnv::lidar_callback, this, std::placeholders::_1));
motor_pub_ = this->create_publisher<std_msgs::msg::Float64MultiArray>("/drone/motor_vel_cmd", 10);
auto start = this->now();
while(rclcpp::ok() && !is_ready_ && (this->now() - start).seconds() < 1.0) {
rclcpp::spin_some(this->get_node_base_interface());
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
if(is_ready_) {
RCLCPP_INFO(this->get_logger(), "Drone environment is ready");
} else {
RCLCPP_ERROR(this->get_logger(), "Sensor data not ready, environment init failed");
}
}
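Eigen::VectorXd DroneEnv::reset() {
{
// Scope the lock: build_state() and the sensor callbacks (run by spin_some below)
// lock data_mutex_ as well, and std::mutex is not recursive
std::lock_guard<std::mutex> lock(data_mutex_);
current_target_idx_ = 0;
collision_ = false;
imu_data_.setZero();
gps_pose_.setZero();
gps_twist_.setZero();
lidar_data_.setConstant(10.0);
}
// Publish minimum motor commands. This only clears cached data; the simulated
// vehicle itself must be reset separately (e.g. via Gazebo's reset services)
auto zero_cmd = std_msgs::msg::Float64MultiArray();
zero_cmd.data = {500.0, 500.0, 500.0, 500.0};
motor_pub_->publish(zero_cmd);
std::this_thread::sleep_for(std::chrono::milliseconds(50));
rclcpp::spin_some(this->get_node_base_interface());
return build_state();
}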
std::tuple<Eigen::VectorXd, double, bool, std::string> DroneEnv::step(const Eigen::VectorXd& action) {
if(action.size() != ACTION_DIM) {
RCLCPP_ERROR(this->get_logger(), "Action dimension mismatch: expected %d, got %ld", ACTION_DIM, static_cast<long>(action.size()));
return {Eigen::VectorXd::Zero(STATE_DIM), -1000.0, true, "action_dim_error"};
}
// Clip to the valid motor-command range and publish
Eigen::VectorXd clipped_action = action.cwiseMax(ACTION_LOW).cwiseMin(ACTION_HIGH);
auto motor_cmd = std_msgs::msg::Float64MultiArray();
motor_cmd.data.assign(clipped_action.data(), clipped_action.data() + ACTION_DIM);
motor_pub_->publish(motor_cmd);
// Give the simulator ~5 ms, then process pending sensor callbacks
std::this_thread::sleep_for(std::chrono::milliseconds(5));
rclcpp::spin_some(this->get_node_base_interface());
Eigen::VectorXd state = build_state();
double reward = compute_reward(state);
bool done = check_done(state);
bool task_completed = false;
// Waypoint logic: state(20)/state(21) are the horizontal/vertical errors to the current target.
// Reaching the last waypoint ends the episode; otherwise advance to the next one.
if(!done && std::hypot(state(20), state(21)) <= 0.05) {
if(current_target_idx_ == static_cast<int>(target_points_.size()) - 1) {
task_completed = true;
done = true;
RCLCPP_INFO(this->get_logger(), "Task completed! All targets reached");
} else {
++current_target_idx_;
const Eigen::Vector3d& next_target = target_points_[current_target_idx_];
RCLCPP_INFO(this->get_logger(), "Switch to target %d (pos: %.2f, %.2f, %.2f)", current_target_idx_, next_target.x(), next_target.y(), next_target.z());
}
}
std::string info = collision_ ? "collision" : (task_completed ? "task_completed" : (done ? "out_of_bounds" : "running"));
return {state, reward, done, info};
}
void DroneEnv::imu_callback(const sensor_msgs::msg::Imu::SharedPtr msg) {
std::lock_guard<std::mutex> lock(data_mutex_);
imu_data_(0) = msg->angular_velocity.x;
imu_data_(1) = msg->angular_velocity.y;
imu_data_(2) = msg->angular_velocity.z;
imu_data_(3) = msg->linear_acceleration.x;
imu_data_(4) = msg->linear_acceleration.y;
imu_data_(5) = msg->linear_acceleration.z;
is_ready_ = true;
}
void DroneEnv::gps_pose_callback(const geometry_msgs::msg::PoseStamped::SharedPtr msg) {
std::lock_guard<std::mutex> lock(data_mutex_);
gps_pose_(0) = msg->pose.position.x;
gps_pose_(1) = msg->pose.position.y;
gps_pose_(2) = msg->pose.position.z;
is_ready_ = true;
}
void DroneEnv::gps_twist_callback(const geometry_msgs::msg::TwistStamped::SharedPtr msg) {
std::lock_guard<std::mutex> lock(data_mutex_);
gps_twist_(0) = msg->twist.linear.x;
gps_twist_(1) = msg->twist.linear.y;
gps_twist_(2) = msg->twist.linear.z;
is_ready_ = true;
}
void DroneEnv::lidar_callback(const sensor_msgs::msg::LaserScan::SharedPtr msg) {
std::lock_guard<std::mutex> lock(data_mutex_);
const auto& ranges = msg->ranges;
const size_t n = ranges.size();
if(n == 0) return;
// Minimum finite reading over index range [lo, hi). Assumes a full 360° scan with
// angle_min = 0 (index 0 points forward, indices increase counter-clockwise).
auto sector_min = [&](size_t lo, size_t hi) {
double m = msg->range_max;
for(size_t i = lo; i < hi && i < n; ++i) {
if(std::isfinite(ranges[i]) && ranges[i] < m) m = ranges[i];
}
return m;
};
lidar_data_(0) = std::min(sector_min(n*350/360, n), sector_min(0, n*10/360));  // front, wraps around 0°
lidar_data_(1) = sector_min(n*170/360, n*190/360);  // rear
lidar_data_(2) = sector_min(n*80/360, n*100/360);   // left
lidar_data_(3) = sector_min(n*260/360, n*280/360);  // right
// Global minimum, ignoring readings at or below 0.1 m
double global_min = msg->range_max;
for(float r : ranges) {
if(std::isfinite(r) && r > 0.1f && r < global_min) global_min = r;
}
lidar_data_(4) = global_min;
collision_ = (lidar_data_.array() < 0.1).any();
is_ready_ = true;
}
Eigen::VectorXd DroneEnv::build_state() {
std::lock_guard<std::mutex> lock(data_mutex_);
Eigen::VectorXd state(STATE_DIM);
Eigen::Vector3d current_target = target_points_[current_target_idx_];
state.segment(0,3) = gps_pose_;
state.segment(3,3) = gps_twist_;
state.segment(6,6) = imu_data_;
state.segment(12,3) = current_target - gps_pose_;
state.segment(15,5) = lidar_data_;
state(20) = (current_target.head(2) - gps_pose_.head(2)).norm();
state(21) = std::abs(current_target.z() - gps_pose_.z());
return state;
}
double DroneEnv::compute_reward(const Eigen::VectorXd& state) {
double err_xy = state(20);
double err_z = state(21);
const Eigen::VectorXd& lidar_dist = state.segment(15,5);
const Eigen::VectorXd& angular_vel = state.segment(6,3);
double track_reward = -0.5*(err_xy * err_xy + err_z * err_z);
double obstacle_reward = 0.0;
for(int i = 0; i < 5; ++i) {
obstacle_reward += (lidar_dist(i)>=0.5)?1.0:-10.0;
}
double smooth_reward = -0.1* angular_vel.norm();
double terminal_reward = (std::sqrt(err_xy*err_xy + err_z*err_z)<=0.05)?100.0:0.0;
double collision_penalty = collision_ ? -200.0 : 0.0;
return track_reward + obstacle_reward + smooth_reward + terminal_reward + collision_penalty;
}
bool DroneEnv::check_done(const Eigen::VectorXd& state) {
// Terminate on collision or when the vehicle leaves the ±10 m flight box;
// waypoint completion is handled in step()
if(collision_) return true;
return state.segment(0,3).cwiseAbs().maxCoeff() > 10.0;
}
}
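The reward in compute_reward() sums five terms. To tune the weights without running Gazebo, the same arithmetic can be pulled into a pure function. The sketch below copies the constants from the listing; the RewardTerms struct and the function name are illustrative assumptions.

```cpp
#include <array>
#include <cassert>
#include <cmath>

// Standalone copy of the reward terms in DroneEnv::compute_reward(), using the
// same weights (0.5, +1/-10 per sector, 0.1, +100, -200). Hypothetical helper.
struct RewardTerms {
  double track, obstacle, smooth, terminal, collision;
  double total() const { return track + obstacle + smooth + terminal + collision; }
};

RewardTerms compute_reward_terms(double err_xy, double err_z,
                                 const std::array<double, 5>& lidar,
                                 const std::array<double, 3>& gyro, bool collided) {
  RewardTerms r{};
  r.track = -0.5 * (err_xy * err_xy + err_z * err_z);             // quadratic tracking penalty
  for (double d : lidar) r.obstacle += (d >= 0.5) ? 1.0 : -10.0;  // per-sector clearance
  r.smooth = -0.1 * std::sqrt(gyro[0] * gyro[0] + gyro[1] * gyro[1] + gyro[2] * gyro[2]);
  r.terminal = (std::hypot(err_xy, err_z) <= 0.05) ? 100.0 : 0.0;  // waypoint bonus
  r.collision = collided ? -200.0 : 0.0;
  return r;
}
```

Breaking the sum apart makes one property visible: whenever all five LiDAR features are at least 0.5 m, the obstacle term alone adds +5 per step, so the agent is also paid for every step it survives.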
(3) TD3 network definitions (include/drone_rl_cpp/networks/TD3Networks.hpp)
The Actor (policy network) and Critic (value network) are implemented with LibTorch:
#ifndef TD3_NETWORKS_HPP_
#define TD3_NETWORKS_HPP_
#include <torch/torch.h>
#include <Eigen/Dense>
namespace drone_rl_cpp {
class ActorNetwork : public torch::nn::Module {
public:
ActorNetwork(int state_dim, int action_dim, double action_low, double action_high);
torch::Tensor forward(torch::Tensor x);
torch::Tensor eigen_to_tensor(const Eigen::VectorXd& x);
Eigen::VectorXd tensor_to_eigen(const torch::Tensor& x);
private:
torch::nn::Linear fc1_{nullptr}, fc2_{nullptr}, fc3_{nullptr};
double action_low_;
double action_high_;
};
class CriticNetwork : public torch::nn::Module {
public:
CriticNetwork(int state_dim, int action_dim);
torch::Tensor forward(torch::Tensor x, torch::Tensor a);
private:
torch::nn::Linear fc1_{nullptr}, fc2_{nullptr}, fc3_{nullptr};
};
class TwinCriticNetworks : public torch::nn::Module {
public:
TwinCriticNetworks(int state_dim, int action_dim);
std::pair<torch::Tensor, torch::Tensor> forward(torch::Tensor x, torch::Tensor a);
std::shared_ptr<CriticNetwork> get_critic1(){ return critic1_; }
std::shared_ptr<CriticNetwork> get_critic2(){ return critic2_; }
private:
std::shared_ptr<CriticNetwork> critic1_;
std::shared_ptr<CriticNetwork> critic2_;
};
}
#endif
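ActorNetwork::forward ends with tanh and an affine rescale into [action_low, action_high]. A minimal sketch of that mapping and its inverse (function names are hypothetical), which is handy when reasoning about noise magnitudes in normalized units:

```cpp
#include <cassert>
#include <cmath>

// tanh output u lies in (-1, 1); ActorNetwork::forward maps it affinely to
// the motor-command range: a = (u + 1) * (high - low) / 2 + low.
double scale_action(double u, double low, double high) {
  return (u + 1.0) * (high - low) / 2.0 + low;
}

// Inverse mapping from a motor command back to the normalized range.
double unscale_action(double a, double low, double high) {
  return 2.0 * (a - low) / (high - low) - 1.0;
}
```

With the document's range of 500 to 2000, one normalized unit corresponds to 750 command units, which is the scale any noise expressed in normalized terms has to be multiplied by.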
(4) Network implementation (src/networks/TD3Networks.cpp)
#include "drone_rl_cpp/networks/TD3Networks.hpp"
#include <cstring>
namespace drone_rl_cpp {
ActorNetwork::ActorNetwork(int state_dim, int action_dim, double action_low, double action_high) :
action_low_(action_low), action_high_(action_high) {
fc1_ = register_module("fc1", torch::nn::Linear(state_dim, 256));
fc2_ = register_module("fc2", torch::nn::Linear(256, 128));
fc3_ = register_module("fc3", torch::nn::Linear(128, action_dim));
torch::nn::init::xavier_uniform_(fc1_->weight);
torch::nn::init::xavier_uniform_(fc2_->weight);
torch::nn::init::xavier_uniform_(fc3_->weight);
torch::nn::init::constant_(fc1_->bias, 0.01);
torch::nn::init::constant_(fc2_->bias, 0.01);
torch::nn::init::constant_(fc3_->bias, 0.01);
}
torch::Tensor ActorNetwork::forward(torch::Tensor x) {
x = torch::relu(fc1_->forward(x));
x = torch::relu(fc2_->forward(x));
x = torch::tanh(fc3_->forward(x));
return (x + 1.0)*(action_high_ - action_low_)/2.0 + action_low_;
}
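torch::Tensor ActorNetwork::eigen_to_tensor(const Eigen::VectorXd& x) {
// Eigen stores doubles: wrap them as kFloat64, then convert to float32 on the GPU.
// The .to() call copies, so the returned tensor does not alias x.
return torch::from_blob(const_cast<double*>(x.data()), {1, static_cast<int64_t>(x.size())}, torch::kFloat64)
.to(torch::kCUDA, torch::kFloat32);
}
Eigen::VectorXd ActorNetwork::tensor_to_eigen(const torch::Tensor& x) {
// Copy back to the CPU as contiguous float64 so the bytes match Eigen's double storage
auto cpu_tensor = x.detach().to(torch::kCPU, torch::kFloat64).contiguous().reshape(-1);
Eigen::VectorXd eigen_vec(cpu_tensor.numel());
std::memcpy(eigen_vec.data(), cpu_tensor.data_ptr<double>(), cpu_tensor.numel() * sizeof(double));
return eigen_vec;
}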
CriticNetwork::CriticNetwork(int state_dim, int action_dim) {
fc1_ = register_module("fc1", torch::nn::Linear(state_dim + action_dim, 256));
fc2_ = register_module("fc2", torch::nn::Linear(256, 128));
fc3_ = register_module("fc3", torch::nn::Linear(128, 1));
torch::nn::init::xavier_uniform_(fc1_->weight);
torch::nn::init::xavier_uniform_(fc2_->weight);
torch::nn::init::xavier_uniform_(fc3_->weight);
torch::nn::init::constant_(fc1_->bias, 0.01);
torch::nn::init::constant_(fc2_->bias, 0.01);
torch::nn::init::constant_(fc3_->bias, 0.01);
}
torch::Tensor CriticNetwork::forward(torch::Tensor x, torch::Tensor a) {
torch::Tensor cat = torch::cat({x, a}, 1);
cat = torch::relu(fc1_->forward(cat));
cat = torch::relu(fc2_->forward(cat));
return fc3_->forward(cat);
}
TwinCriticNetworks::TwinCriticNetworks(int state_dim, int action_dim) {
critic1_ = std::make_shared<CriticNetwork>(state_dim, action_dim);
critic2_ = std::make_shared<CriticNetwork>(state_dim, action_dim);
register_module("critic1", critic1_);
register_module("critic2", critic2_);
}
std::pair<torch::Tensor, torch::Tensor> TwinCriticNetworks::forward(torch::Tensor x, torch::Tensor a) {
return {critic1_->forward(x, a), critic2_->forward(x, a)};
}
}
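When validating an exported or TensorRT-built actor, a CPU reference of the same forward pass is useful. The following is a minimal sketch assuming the weights are available as row-major float arrays in torch::nn::Linear's [out, in] layout; the Dense struct and function names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// One fully connected layer, y = W x + b, with W stored row-major as [out][in]
// (the layout torch::nn::Linear uses for its weight).
struct Dense {
  std::size_t in = 0, out = 0;
  std::vector<float> w;  // out * in
  std::vector<float> b;  // out
};

std::vector<float> dense_forward(const Dense& l, const std::vector<float>& x) {
  std::vector<float> y(l.b);  // start from the bias
  for (std::size_t o = 0; o < l.out; ++o)
    for (std::size_t i = 0; i < l.in; ++i) y[o] += l.w[o * l.in + i] * x[i];
  return y;
}

// CPU reference of ActorNetwork::forward for a single state:
// Linear -> ReLU -> Linear -> ReLU -> Linear -> tanh -> rescale to [low, high].
std::vector<float> actor_forward(const Dense& fc1, const Dense& fc2, const Dense& fc3,
                                 const std::vector<float>& state, float low, float high) {
  auto relu = [](std::vector<float> v) {
    for (float& e : v) e = std::max(e, 0.0f);
    return v;
  };
  std::vector<float> h = relu(dense_forward(fc1, state));
  h = relu(dense_forward(fc2, h));
  std::vector<float> a = dense_forward(fc3, h);
  for (float& e : a) e = (std::tanh(e) + 1.0f) * (high - low) / 2.0f + low;
  return a;
}
```

Layer sizes are read from the Dense structs, so the same code covers the 22-256-128-4 actor defined above.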
(5) Experience replay buffer (include/drone_rl_cpp/utils/ReplayBuffer.hpp)
#ifndef REPLAY_BUFFER_HPP_
#define REPLAY_BUFFER_HPP_
#include <Eigen/Dense>
#include <torch/torch.h>
#include <vector>
#include <mutex>
#include <random>
#include <tuple>
namespace drone_rl_cpp {
struct Transition {
Eigen::VectorXd state;
Eigen::VectorXd action;
double reward;
Eigen::VectorXd next_state;
bool done;
};
class ReplayBuffer {
public:
ReplayBuffer(int capacity, int state_dim, int action_dim);
~ReplayBuffer() = default;
void push(const Transition& transition);
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> sample_batch(int batch_size);
int size() const { return static_cast<int>(buffer_.size()); }
bool is_full() const { return size() >= capacity_; }
private:
int capacity_;
int state_dim_;
int action_dim_;
std::vector<Transition> buffer_;
std::mutex buffer_mutex_;
std::mt19937 rng_;
int write_idx_;
};
}
#endif
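ReplayBuffer is a fixed-capacity ring buffer with uniform sampling with replacement. A generic, standard-library-only sketch of the same indexing (the RingBuffer template is hypothetical, not part of the package):

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Fixed-capacity ring buffer with uniform sampling (with replacement),
// mirroring the indexing in ReplayBuffer::push/sample_batch.
template <typename T>
class RingBuffer {
 public:
  explicit RingBuffer(std::size_t capacity, unsigned seed = 42)
      : capacity_(capacity), rng_(seed) {
    data_.reserve(capacity);
  }

  // Append until full, then overwrite the oldest entry.
  void push(const T& item) {
    if (data_.size() < capacity_) {
      data_.push_back(item);
    } else {
      data_[write_idx_] = item;
    }
    write_idx_ = (write_idx_ + 1) % capacity_;
  }

  // Draw n items uniformly at random; returns an empty vector if the buffer is empty.
  std::vector<T> sample(std::size_t n) {
    std::vector<T> out;
    if (data_.empty()) return out;
    std::uniform_int_distribution<std::size_t> dist(0, data_.size() - 1);
    out.reserve(n);
    for (std::size_t i = 0; i < n; ++i) out.push_back(data_[dist(rng_)]);
    return out;
  }

  std::size_t size() const { return data_.size(); }
  const T& at(std::size_t i) const { return data_.at(i); }

 private:
  std::size_t capacity_;
  std::size_t write_idx_ = 0;
  std::vector<T> data_;
  std::mt19937 rng_;
};
```

Because write_idx_ keeps advancing after the buffer fills, the slot overwritten next is always the oldest one, so the buffer holds the most recent `capacity` transitions.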
(6) Buffer implementation (src/utils/ReplayBuffer.cpp)
#include "drone_rl_cpp/utils/ReplayBuffer.hpp"
#include <torch/torch.h>
namespace drone_rl_cpp {
ReplayBuffer::ReplayBuffer(int capacity, int state_dim, int action_dim) :
capacity_(capacity), state_dim_(state_dim), action_dim_(action_dim),
rng_(std::random_device{}()), write_idx_(0) {
// std::random_device{}() yields a seed value; passing the device object itself does not compile
buffer_.reserve(capacity_);
}
void ReplayBuffer::push(const Transition& transition) {
std::lock_guard<std::mutex> lock(buffer_mutex_);
if(static_cast<int>(buffer_.size()) < capacity_) {
buffer_.emplace_back(transition);
} else {
buffer_[write_idx_] = transition;
}
write_idx_ = (write_idx_ + 1) % capacity_;
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> ReplayBuffer::sample_batch(int batch_size) {
std::lock_guard<std::mutex> lock(buffer_mutex_);
int current_size = static_cast<int>(buffer_.size());
if(current_size < batch_size) {
throw std::runtime_error("Replay buffer size (" + std::to_string(current_size) + ") < batch size (" + std::to_string(batch_size) + ")");
}
std::uniform_int_distribution<> dist(0, current_size - 1);
// Fill contiguous float32 host buffers, converting double -> float element by element
// (wrapping Eigen's double storage as kFloat32 would read the wrong bytes)
std::vector<float> states(batch_size * state_dim_), actions(batch_size * action_dim_),
rewards(batch_size), next_states(batch_size * state_dim_), dones(batch_size);
for(int i = 0; i < batch_size; ++i) {
const auto& t = buffer_[dist(rng_)];
for(int j = 0; j < state_dim_; ++j) {
states[i * state_dim_ + j] = static_cast<float>(t.state(j));
next_states[i * state_dim_ + j] = static_cast<float>(t.next_state(j));
}
for(int j = 0; j < action_dim_; ++j) {
actions[i * action_dim_ + j] = static_cast<float>(t.action(j));
}
rewards[i] = static_cast<float>(t.reward);
dones[i] = t.done ? 1.0f : 0.0f;
}
// One host-to-device copy per tensor; .to() copies, so the host buffers may go out of scope
auto to_gpu = [](std::vector<float>& v, int rows, int cols) {
return torch::from_blob(v.data(), {rows, cols}, torch::kFloat32).to(torch::kCUDA);
};
return {to_gpu(states, batch_size, state_dim_), to_gpu(actions, batch_size, action_dim_),
to_gpu(rewards, batch_size, 1), to_gpu(next_states, batch_size, state_dim_),
to_gpu(dones, batch_size, 1)};
}
}
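sample_batch() has to turn Eigen::VectorXd transitions (double precision) into float32 tensors. Reinterpreting double storage as float32, for example by wrapping Eigen::VectorXd data with torch::from_blob and torch::kFloat32, reads the wrong bytes; converting element by element into one contiguous host buffer avoids that and allows a single host-to-device copy per tensor. A minimal sketch of the packing step (pack_rows is a hypothetical helper):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Pack equally sized double rows into one contiguous row-major float32 buffer
// of shape [rows.size(), dim], converting each element explicitly.
std::vector<float> pack_rows(const std::vector<std::vector<double>>& rows, std::size_t dim) {
  std::vector<float> out(rows.size() * dim);
  for (std::size_t r = 0; r < rows.size(); ++r) {
    assert(rows[r].size() == dim);  // every row must have the same length
    for (std::size_t c = 0; c < dim; ++c)
      out[r * dim + c] = static_cast<float>(rows[r][c]);
  }
  return out;
}
```

The resulting buffer has exactly the memory layout a [batch, dim] float32 tensor expects, so it can be wrapped once and moved to the GPU.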
(7) TD3 agent (include/drone_rl_cpp/TD3Agent.hpp)
#ifndef TD3_AGENT_HPP_
#define TD3_AGENT_HPP_
#include "drone_rl_cpp/networks/TD3Networks.hpp"
#include "drone_rl_cpp/utils/ReplayBuffer.hpp"
#include <torch/torch.h>
#include <Eigen/Dense>
#include <vector>
namespace drone_rl_cpp {
class TD3Agent {
public:
TD3Agent(int state_dim, int action_dim, double action_low, double action_high,
int buffer_capacity = 1000000, int batch_size = 256,
double gamma = 0.99, double tau = 0.005,
double lr_actor = 3e-4, double lr_critic = 3e-4,
double policy_noise = 0.1, double noise_clip = 0.2,
int policy_freq = 2);
~TD3Agent() = default;
Eigen::VectorXd select_action(const Eigen::VectorXd& state, bool is_training = true);
double train();
void save_model(const std::string& path);
void load_model(const std::string& path);
ReplayBuffer& get_replay_buffer() { return replay_buffer_; }
private:
std::shared_ptr<ActorNetwork> actor_;
std::shared_ptr<ActorNetwork> target_actor_;
std::shared_ptr<TwinCriticNetworks> critics_;
std::shared_ptr<TwinCriticNetworks> target_critics_;
torch::optim::Adam actor_optimizer_;
torch::optim::Adam critics_optimizer_;
ReplayBuffer replay_buffer_;
int state_dim_;
int action_dim_;
double action_low_;
double action_high_;
int batch_size_;
double gamma_;
double tau_;
double policy_noise_;
double noise_clip_;
int policy_freq_;
int update_count_;
std::mt19937 rng_;
std::normal_distribution<> noise_dist_;
};
}
#endif
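The hyperparameters above feed the TD3 critic target. A scalar sketch of that target and of target-policy smoothing, for checking the arithmetic by hand; the function names are hypothetical and the noise values are taken in action units:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Scalar TD3 critic target:
//   y = r + gamma * (1 - done) * min(Q1'(s', a~), Q2'(s', a~))
// Taking the minimum of the twin critics curbs overestimation.
double td3_target(double reward, bool done, double gamma, double q1_next, double q2_next) {
  return reward + (done ? 0.0 : gamma * std::min(q1_next, q2_next));
}

// Target policy smoothing for one action component: clip the noise to
// [-noise_clip, noise_clip], add it, then clip to the action range.
double smooth_target_action(double a, double noise, double noise_clip, double low, double high) {
  const double eps = std::max(-noise_clip, std::min(noise_clip, noise));
  return std::max(low, std::min(high, a + eps));
}
```

For example, with reward 1, gamma 0.99 and next-state estimates 10 and 5, the target uses the smaller estimate: 1 + 0.99 × 5 = 5.95.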
(8) Agent implementation (src/TD3Agent.cpp)
#include "drone_rl_cpp/TD3Agent.hpp"
#include <rclcpp/rclcpp.hpp>
#include <torch/torch.h>
#include <algorithm>
#include <filesystem>
namespace drone_rl_cpp {
namespace {
// LibTorch's C++ nn::Module has no state_dict()/load_state_dict(). Online and target
// networks share one architecture, so their parameter lists line up index by index.
void hard_update(torch::nn::Module& target, const torch::nn::Module& source) {
torch::NoGradGuard no_grad;
auto dst = target.parameters();
auto src = source.parameters();
for(size_t i = 0; i < dst.size(); ++i) { dst[i].copy_(src[i]); }
}
// Polyak averaging: target <- tau * source + (1 - tau) * target
void soft_update(torch::nn::Module& target, const torch::nn::Module& source, double tau) {
torch::NoGradGuard no_grad;
auto dst = target.parameters();
auto src = source.parameters();
for(size_t i = 0; i < dst.size(); ++i) { dst[i].mul_(1.0 - tau).add_(src[i], tau); }
}
}
TD3Agent::TD3Agent(int state_dim, int action_dim, double action_low, double action_high,
int buffer_capacity, int batch_size, double gamma, double tau,
double lr_actor, double lr_critic, double policy_noise, double noise_clip, int policy_freq) :
actor_(std::make_shared<ActorNetwork>(state_dim, action_dim, action_low, action_high)),
target_actor_(std::make_shared<ActorNetwork>(state_dim, action_dim, action_low, action_high)),
critics_(std::make_shared<TwinCriticNetworks>(state_dim, action_dim)),
target_critics_(std::make_shared<TwinCriticNetworks>(state_dim, action_dim)),
actor_optimizer_(actor_->parameters(), torch::optim::AdamOptions(lr_actor)),
critics_optimizer_(critics_->parameters(), torch::optim::AdamOptions(lr_critic)),
replay_buffer_(buffer_capacity, state_dim, action_dim),
state_dim_(state_dim), action_dim_(action_dim), action_low_(action_low), action_high_(action_high),
batch_size_(batch_size), gamma_(gamma), tau_(tau), policy_noise_(policy_noise), noise_clip_(noise_clip),
policy_freq_(policy_freq), update_count_(0), rng_(std::random_device{}()),
noise_dist_(0.0, policy_noise) {
// The initializers above follow the member declaration order in TD3Agent.hpp
actor_->to(torch::kCUDA);
target_actor_->to(torch::kCUDA);
critics_->to(torch::kCUDA);
target_critics_->to(torch::kCUDA);
hard_update(*target_actor_, *actor_);
hard_update(*target_critics_, *critics_);
for(auto& param : target_actor_->parameters()) { param.requires_grad_(false); }
for(auto& param : target_critics_->parameters()) { param.requires_grad_(false); }
}
Eigen::VectorXd TD3Agent::select_action(const Eigen::VectorXd& state, bool is_training) {
torch::NoGradGuard no_grad;
torch::Tensor action_tensor = actor_->forward(actor_->eigen_to_tensor(state));
Eigen::VectorXd action = actor_->tensor_to_eigen(action_tensor);
if(is_training) {
// policy_noise/noise_clip are treated as fractions of the half action range;
// unscaled, a std of 0.1 would be negligible next to a 500-2000 command range
const double scale = (action_high_ - action_low_) / 2.0;
for(int i = 0; i < action_dim_; ++i) {
double noise = std::max(-noise_clip_, std::min(noise_clip_, noise_dist_(rng_)));
action(i) += noise * scale;
}
action = action.cwiseMax(action_low_).cwiseMin(action_high_);
}
return action;
}
double TD3Agent::train() {
auto [states, actions, rewards, next_states, dones] = replay_buffer_.sample_batch(batch_size_);
const double scale = (action_high_ - action_low_) / 2.0;
torch::Tensor target_q_values;
{
// Critic target with clipped target-policy smoothing; no gradients through target networks
torch::NoGradGuard no_grad;
torch::Tensor noise = (torch::randn_like(actions) * policy_noise_).clamp(-noise_clip_, noise_clip_) * scale;
torch::Tensor target_actions = (target_actor_->forward(next_states) + noise).clamp(action_low_, action_high_);
auto [target_q1, target_q2] = target_critics_->forward(next_states, target_actions);
target_q_values = rewards + (1.0 - dones) * gamma_ * torch::min(target_q1, target_q2);
}
auto [q1, q2] = critics_->forward(states, actions);
torch::Tensor critic_loss = torch::mse_loss(q1, target_q_values) + torch::mse_loss(q2, target_q_values);
critics_optimizer_.zero_grad();
critic_loss.backward();
critics_optimizer_.step();
// Delayed policy update: actor and target networks every policy_freq critic updates
if(update_count_ % policy_freq_ == 0) {
torch::Tensor actor_loss = -critics_->get_critic1()->forward(states, actor_->forward(states)).mean();
actor_optimizer_.zero_grad();
actor_loss.backward();
actor_optimizer_.step();
soft_update(*target_actor_, *actor_, tau_);
soft_update(*target_critics_, *critics_, tau_);
}
update_count_++;
return critic_loss.item<double>();
}
void TD3Agent::save_model(const std::string& path) {
std::filesystem::create_directories(path);  // torch::save does not create missing directories
torch::save(actor_, path + "/actor.pt");
torch::save(critics_, path + "/critics.pt");
torch::save(actor_optimizer_, path + "/actor_optimizer.pt");
torch::save(critics_optimizer_, path + "/critics_optimizer.pt");
RCLCPP_INFO(rclcpp::get_logger("TD3Agent"), "Model saved to %s", path.c_str());
}
void TD3Agent::load_model(const std::string& path) {
torch::load(actor_, path + "/actor.pt");
torch::load(critics_, path + "/critics.pt");
torch::load(actor_optimizer_, path + "/actor_optimizer.pt");
torch::load(critics_optimizer_, path + "/critics_optimizer.pt");
actor_->to(torch::kCUDA);
target_actor_->to(torch::kCUDA);
critics_->to(torch::kCUDA);
target_critics_->to(torch::kCUDA);
hard_update(*target_actor_, *actor_);
hard_update(*target_critics_, *critics_);
RCLCPP_INFO(rclcpp::get_logger("TD3Agent"), "Model loaded from %s", path.c_str());
}
}
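TD3 refreshes its target networks by Polyak averaging, and only on every policy_freq-th update. A sketch on flattened parameter vectors (helper names are hypothetical; in LibTorch the loop would run over matching entries of the two modules' parameters()):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Polyak (soft) target update: theta_target <- tau * theta + (1 - tau) * theta_target
void soft_update(std::vector<float>& target, const std::vector<float>& online, float tau) {
  assert(target.size() == online.size());
  for (std::size_t i = 0; i < target.size(); ++i)
    target[i] = tau * online[i] + (1.0f - tau) * target[i];
}

// Delayed policy updates: the actor and the targets change only on every
// policy_freq-th critic update.
bool is_policy_update_step(int update_count, int policy_freq) {
  return update_count % policy_freq == 0;
}
```

With tau = 0.005, each update moves the target only 0.5% of the way toward the online weights, which keeps the bootstrapped critic target slowly varying.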
(9) TensorRT inference wrapper (include/drone_rl_cpp/utils/TrtInfer.hpp)
#ifndef TRT_INFER_HPP_
#define TRT_INFER_HPP_
#include <NvInfer.h>
#include <cuda_runtime_api.h>
#include <Eigen/Dense>
#include <string>
#include <memory>
#include <vector>
namespace drone_rl_cpp {
class TrtInfer {
public:
TrtInfer(const std::string& engine_path);
~TrtInfer();
Eigen::VectorXd infer(const Eigen::VectorXd& input);
private:
// TensorRT 8 objects are released with delete (destroy() is deprecated since TRT 8.0)
struct TrtDeleter {
template <typename T>
void operator()(T* obj) const { delete obj; }
};
std::unique_ptr<nvinfer1::IRuntime, TrtDeleter> runtime_;
std::unique_ptr<nvinfer1::ICudaEngine, TrtDeleter> engine_;
std::unique_ptr<nvinfer1::IExecutionContext, TrtDeleter> context_;
void* d_input_ = nullptr;
void* d_output_ = nullptr;
static constexpr int INPUT_DIM = 22;
static constexpr int OUTPUT_DIM = 4;
size_t input_size_;
size_t output_size_;
};
}
#endif
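TrtInfer relies on C++ destroying members in reverse declaration order, so the execution context is released before the engine and the engine before the runtime. The sketch below demonstrates that ordering with stand-in handles (FakeHandle, HandleDeleter, and InferHolder are hypothetical, not TensorRT types):

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Records the order in which the stand-in handles are released.
std::vector<std::string> g_destroy_log;

struct FakeHandle {
  std::string name;
};

// Custom deleter, analogous to TrtInfer's TrtDeleter.
struct HandleDeleter {
  void operator()(FakeHandle* h) const {
    g_destroy_log.push_back(h->name);
    delete h;
  }
};

// Same declaration order as TrtInfer: runtime, engine, context.
struct InferHolder {
  std::unique_ptr<FakeHandle, HandleDeleter> runtime_{new FakeHandle{"runtime"}};
  std::unique_ptr<FakeHandle, HandleDeleter> engine_{new FakeHandle{"engine"}};
  std::unique_ptr<FakeHandle, HandleDeleter> context_{new FakeHandle{"context"}};
};
```

Swapping the member declarations would tear down the engine while a context still referenced it, so the order in the header is load-bearing.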
(10) TensorRT inference implementation (src/utils/TrtInfer.cpp)
#include "drone_rl_cpp/utils/TrtInfer.hpp"
#include <fstream>
#include <iostream>
#include <rclcpp/rclcpp.hpp>
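using namespace nvinfer1;
namespace drone_rl_cpp {
namespace {
// TensorRT does not provide a ready-made logger class: createInferRuntime() needs an
// ILogger implementation that outlives the runtime, so a file-scope instance is used
class TrtLogger : public ILogger {
public:
void log(Severity severity, const char* msg) noexcept override {
if(severity <= Severity::kWARNING) {
RCLCPP_WARN(rclcpp::get_logger("TensorRT"), "%s", msg);
}
}
};
TrtLogger g_trt_logger;
}
TrtInfer::TrtInfer(const std::string& engine_path) {
std::ifstream engine_file(engine_path, std::ios::binary);
if(!engine_file) {
throw std::runtime_error("Failed to open TRT engine file: " + engine_path);
}
engine_file.seekg(0, std::ios::end);
const size_t engine_size = engine_file.tellg();
engine_file.seekg(0, std::ios::beg);
std::vector<char> engine_data(engine_size);
engine_file.read(engine_data.data(), engine_size);
IRuntime* runtime = createInferRuntime(g_trt_logger);
if(!runtime) {
throw std::runtime_error("Failed to create TRT runtime");
}
runtime_.reset(runtime);
ICudaEngine* engine = runtime->deserializeCudaEngine(engine_data.data(), engine_size);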
if(!engine) {
throw std::runtime_error("Failed to deserialize TRT engine");
}
engine_.reset(engine);
IExecutionContext* context = engine->createExecutionContext();
if(!context) {
throw std::runtime_error("Failed to create TRT execution context");
}
context_.reset(context);
input_size_ = INPUT_DIM * sizeof(float);
output_size_ = OUTPUT_DIM * sizeof(float);
cudaMalloc(&d_input_, input_size_);
cudaMalloc(&d_output_, output_size_);
RCLCPP_INFO(rclcpp::get_logger("TrtInfer"), "TRT engine initialized successfully (input dim: %d, output dim: %d)", INPUT_DIM, OUTPUT_DIM);
}
TrtInfer::~TrtInfer() {
cudaFree(d_input_);
cudaFree(d_output_);
}
Eigen::VectorXd TrtInfer::infer(const Eigen::VectorXd& input) {
if(input.size() != INPUT_DIM) {
throw std::runtime_error("Input dimension mismatch: expected " + std::to_string(INPUT_DIM) + ", got " + std::to_string(input.size()));
}
std::vector<float> input_host(INPUT_DIM);
for(int i = 0; i < INPUT_DIM; ++i) {
input_host[i] = static_cast<float>(input(i));
}
cudaMemcpy(d_input_, input_host.data(), input_size_, cudaMemcpyHostToDevice);
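// Assumes the engine's input is binding 0 and its output is binding 1; verify against the engine's I/O tensor names
void* bindings[] = {d_input_, d_output_};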
context_->executeV2(bindings);
std::vector<float> output_host(OUTPUT_DIM);
cudaMemcpy(output_host.data(), d_output_, output_size_, cudaMemcpyDeviceToHost);
Eigen::VectorXd output(OUTPUT_DIM);
for(int i = 0; i < OUTPUT_DIM; ++i) {
output(i) = static_cast<double>(output_host[i]);
}
return output;
}
}
3. ROS2 training node (src/train_node.cpp)
#include "drone_rl_cpp/env/DroneEnv.hpp"
#include "drone_rl_cpp/TD3Agent.hpp"
#include <rclcpp/rclcpp.hpp>
#include <chrono>
#include <iomanip>
using namespace drone_rl_cpp;
using namespace std::chrono;
int main(int argc, char* argv[]) {
rclcpp::init(argc, argv);
auto node = std::make_shared<rclcpp::Node>("td3_train_node");
auto env = std::make_shared<DroneEnv>();
if(!env->is_ready()) {
RCLCPP_FATAL(node->get_logger(), "Environment not ready, exit");
return -1;
}
TD3Agent agent(
DroneEnv::STATE_DIM, DroneEnv::ACTION_DIM, DroneEnv::ACTION_LOW, DroneEnv::ACTION_HIGH,
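1000000,  // replay buffer capacity
256,      // batch size
0.99,     // discount factor gamma
0.005,    // soft-update rate tau
3e-4,     // actor learning rate
3e-4,     // critic learning rate
0.1,      // policy noise (std)
0.2,      // noise clip
2         // policy update frequency (delayed actor updates)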
);
const int total_episodes = 500;
const int max_steps_per_episode = 1000;
const int start_train_steps = 10000;  // random-action warm-up before training starts
long total_steps = 0;
// Generator for warm-up actions (ReplayBuffer's rng_ is private)
std::mt19937 rng(std::random_device{}());
std::uniform_real_distribution<> dist(DroneEnv::ACTION_LOW, DroneEnv::ACTION_HIGH);
RCLCPP_INFO(node->get_logger(), "Start TD3 training (total episodes: %d)", total_episodes);
for(int ep = 0; ep < total_episodes && rclcpp::ok(); ++ep) {
Eigen::VectorXd state = env->reset();
double ep_reward = 0.0;
bool ep_done = false;
int ep_steps = 0;
std::string ep_info = "running";
auto ep_start = high_resolution_clock::now();
for(int step = 0; step < max_steps_per_episode && !ep_done; ++step) {
Eigen::VectorXd action(DroneEnv::ACTION_DIM);
if(total_steps < start_train_steps) {
for(int i = 0; i < DroneEnv::ACTION_DIM; ++i) {
action(i) = dist(rng);
}
} else {
action = agent.select_action(state, true);
}
auto [next_state, reward, done, info] = env->step(action);
ep_reward += reward;
ep_done = done;
ep_info = info;
agent.get_replay_buffer().push(Transition{state, action, reward, next_state, done});
if(total_steps >= start_train_steps) {
double loss = agent.train();
if(step % 100 == 0) {
RCLCPP_INFO(node->get_logger(), "Episode %d, Step %d, Critic Loss: %.4f", ep, step, loss);
}
}
state = next_state;
total_steps++;
ep_steps++;
}
auto ep_end = high_resolution_clock::now();
double ep_duration = duration_cast<duration<double>>(ep_end - ep_start).count();
RCLCPP_INFO(node->get_logger(), "Episode [%d/%d] | Reward: %.2f | Steps: %d | Duration: %.2fs | Info: %s | Total Steps: %ld",
ep + 1, total_episodes, ep_reward, ep_steps, ep_duration, ep_info.c_str(), total_steps);
if((ep + 1) % 50 == 0) {
std::string model_path = "./models/td3_ep" + std::to_string(ep + 1);
agent.save_model(model_path);
}
}
agent.save_model("./models/td3_final");
RCLCPP_INFO(node->get_logger(), "Training completed! Final model saved to ./models/td3_final");
rclcpp::shutdown();
return 0;
}
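The first start_train_steps steps of train_node use uniformly random motor commands to seed the replay buffer before any gradient update. A sketch of a warm-up sampler that owns its own random engine (class and function names are hypothetical):

```cpp
#include <cassert>
#include <random>
#include <vector>

// Uniform warm-up actions in [low, high); the sampler owns its generator, so it
// needs no access to other classes' private state.
class UniformActionSampler {
 public:
  UniformActionSampler(int action_dim, double low, double high, unsigned seed)
      : action_dim_(action_dim), dist_(low, high), rng_(seed) {}

  std::vector<double> sample() {
    std::vector<double> a(action_dim_);
    for (double& v : a) v = dist_(rng_);
    return a;
  }

 private:
  int action_dim_;
  std::uniform_real_distribution<double> dist_;
  std::mt19937 rng_;
};

// Warm-up phase: random actions until start_train_steps transitions are collected.
bool use_random_action(long total_steps, long start_train_steps) {
  return total_steps < start_train_steps;
}
```

A fixed seed makes the warm-up data reproducible across runs, which helps when comparing reward or hyperparameter changes.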
4. ROS2 inference control node (src/infer_node.cpp)
#include "drone_rl_cpp/env/DroneEnv.hpp"
#include "drone_rl_cpp/utils/TrtInfer.hpp"
#include <rclcpp/rclcpp.hpp>
#include <chrono>
#include <sched.h>
#include <cstring>
#include <cerrno>
using namespace drone_rl_cpp;
using namespace std::chrono;
int main(int argc, char* argv[]) {
rclcpp::init(argc, argv);
auto node = std::make_shared<rclcpp::Node>("td3_infer_node");
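// Request SCHED_FIFO real-time scheduling; this needs root or CAP_SYS_NICE,
// otherwise the node continues with a warning
struct sched_param param{};
param.sched_priority = 99;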
if(sched_setscheduler(0, SCHED_FIFO, &param) == -1) {
RCLCPP_WARN(node->get_logger(), "Failed to set real-time priority: %s", strerror(errno));
}
std::string engine_path = node->declare_parameter<std::string>("engine_path", "./models/td3_final_fp16.engine");
auto env = std::make_shared<DroneEnv>();
if(!env->is_ready()) {
RCLCPP_FATAL(node->get_logger(), "Environment not ready, exit");
return -1;
}
std::unique_ptr<TrtInfer> trt_infer;
try {
trt_infer = std::make_unique<TrtInfer>(engine_path);
} catch(const std::exception& e) {
RCLCPP_FATAL(node->get_logger(), "Failed to initialize TRT infer: %s", e.what());
return -1;
}
int total_steps = 0;
double total_delay = 0.0;
const int stat_window = 100;
RCLCPP_INFO(node->get_logger(), "Start TD3 inference (control frequency target: 200Hz)");
Eigen::VectorXd state = env->reset();
while(rclcpp::ok()) {
// steady_clock is monotonic, which makes it the appropriate clock for latency measurement
auto start = steady_clock::now();
Eigen::VectorXd action = trt_infer->infer(state);
// Clamp the network output to the valid actuator command range
action = action.cwiseMax(DroneEnv::ACTION_LOW).cwiseMin(DroneEnv::ACTION_HIGH);
auto [next_state, reward, done, info] = env->step(action);
auto end = steady_clock::now();
double delay = duration_cast<duration<double, std::milli>>(end - start).count();
total_delay += delay;
total_steps++;
state = next_state;
if(total_steps % stat_window == 0) {
double avg_delay = total_delay / stat_window;
// Rate implied by per-step latency; this loop has no sleep, so any 200 Hz pacing must come from env->step()
double freq = 1000.0 / avg_delay;
RCLCPP_INFO(node->get_logger(), "Step: %d | Avg Delay: %.2fms | Control Freq: %.1fHz | Reward: %.2f | Info: %s",
total_steps, avg_delay, freq, reward, info.c_str());
total_delay = 0.0;
}
if(done) {
RCLCPP_INFO(node->get_logger(), "Episode done, reset environment");
state = env->reset();
}
}
rclcpp::shutdown();
return 0;
}
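The startup log announces a 200 Hz target, but the loop above never sleeps. Its actual rate is therefore whatever inference plus env->step() allows. If env->step() does not block on sensor updates (its implementation is not shown in this section), the loop can be paced with absolute deadlines on a steady clock. The sketch below is minimal. run_fixed_rate and LoopStats are hypothetical names, and the step callable stands in for one infer, clamp and env->step() iteration.

```cpp
#include <chrono>
#include <cstddef>
#include <functional>
#include <thread>

// Statistics collected by the hypothetical fixed-rate loop below.
struct LoopStats {
    std::size_t iterations = 0;  // completed control steps
    std::size_t overruns = 0;    // steps that finished after their deadline
    double max_work_ms = 0.0;    // worst observed step duration
};

// Calls `step` at `hz` until it returns false or `max_iterations` is reached.
// Deadlines are absolute time points on steady_clock, advanced by one period per step,
// so jitter in one iteration does not accumulate as drift.
LoopStats run_fixed_rate(double hz, std::size_t max_iterations,
                         const std::function<bool()>& step) {
    using clock = std::chrono::steady_clock;
    const auto period = std::chrono::duration_cast<clock::duration>(
        std::chrono::duration<double>(1.0 / hz));
    LoopStats stats;
    auto deadline = clock::now() + period;
    while (stats.iterations < max_iterations) {
        const auto work_start = clock::now();
        const bool keep_going = step();
        const auto work_end = clock::now();
        const double work_ms =
            std::chrono::duration<double, std::milli>(work_end - work_start).count();
        if (work_ms > stats.max_work_ms) {
            stats.max_work_ms = work_ms;
        }
        ++stats.iterations;
        if (!keep_going) {
            break;
        }
        if (work_end > deadline) {
            ++stats.overruns;
            deadline = work_end + period;  // resynchronize: next step starts now, no catch-up burst
        } else {
            std::this_thread::sleep_until(deadline);
            deadline += period;
        }
    }
    return stats;
}
```

In infer_node, the step callable would wrap the body of the while loop and return rclcpp::ok(). The deadline advances by a fixed period rather than the loop sleeping a fixed duration after each step. This keeps the average rate at 200 Hz even when individual iterations vary in cost. The overrun counter also shows directly whether inference fits within the 5 ms budget.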
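IV. Model conversion (LibTorch → ONNX → TensorRT)
1. Exporting the LibTorch actor for ONNX conversion (C++ code)
LibTorch's C++ API has no ONNX exporter: torch.onnx.export exists only in Python. The C++ side therefore loads the trained actor and writes each parameter tensor to a raw float32 file. A short Python script then produces the ONNX file: it rebuilds the same actor architecture, loads these arrays, and calls torch.onnx.export. Alternatively, the same arrays can be passed to TensorRT's network-definition API, which keeps the pipeline entirely in C++.
Create src/utils/export_onnx.cpp:
#include "drone_rl_cpp/networks/TD3Networks.hpp"
#include <torch/torch.h>
#include <fstream>
#include <iostream>
#include <string>
int main(int argc, char* argv[]) {
if(argc != 3) {
std::cerr << "Usage: " << argv[0] << " <libtorch_model_path> <output_weights_dir>" << std::endl;
return -1;
}
std::string model_path = argv[1];
std::string out_dir = argv[2];
// Must match the dimensions used during training
int state_dim = 22;
int action_dim = 4;
double action_low = 500.0;
double action_high = 2000.0;
auto actor = std::make_shared<drone_rl_cpp::ActorNetwork>(state_dim, action_dim, action_low, action_high);
torch::load(actor, model_path + "/actor.pt");
actor->eval();
// Write each parameter (named "<layer>.weight", "<layer>.bias") as contiguous float32 from CPU memory
for(const auto& p : actor->named_parameters()) {
torch::Tensor t = p.value().detach().to(torch::kCPU, torch::kFloat32).contiguous();
std::ofstream f(out_dir + "/" + p.key() + ".bin", std::ios::binary);
if(!f) {
std::cerr << "Cannot write to " << out_dir << std::endl;
return -1;
}
f.write(reinterpret_cast<const char*>(t.data_ptr<float>()), static_cast<std::streamsize>(t.numel() * sizeof(float)));
std::cout << p.key() << " " << t.sizes() << std::endl;
}
std::cout << "Actor weights exported to: " << out_dir << std::endl;
return 0;
}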
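export_onnx constructs ActorNetwork with action_low = 500.0 and action_high = 2000.0. These are the same bounds the inference node enforces with the DroneEnv::ACTION_LOW / ACTION_HIGH clamp. The TD3Networks header is not shown in this section. A common TD3 design, assumed here, ends the actor with tanh and maps (-1, 1) affinely onto [action_low, action_high]. Under that design the scaling is part of the exported network, and the engine outputs commands already in range. The mapping is sketched below as standalone code. The function names are hypothetical.

```cpp
#include <array>
#include <cmath>
#include <cstddef>

// Assumed actor output stage: tanh squashes each raw network output into (-1, 1),
// then an affine map rescales it onto [low, high].
double scale_tanh_to_range(double raw, double low, double high) {
    const double squashed = std::tanh(raw);              // in (-1, 1)
    return low + 0.5 * (squashed + 1.0) * (high - low);  // in [low, high]
}

// Applies the mapping to a whole action vector (4 motor commands in this project),
// defaulting to the bounds passed to ActorNetwork in export_onnx.
template <std::size_t N>
std::array<double, N> scale_action(const std::array<double, N>& raw,
                                   double low = 500.0, double high = 2000.0) {
    std::array<double, N> out{};
    for (std::size_t i = 0; i < N; ++i) {
        out[i] = scale_tanh_to_range(raw[i], low, high);
    }
    return out;
}
```

With these bounds, a raw output of 0 maps to the midpoint 1250, and every output stays within [500, 2000] before the clamp in infer_node. The clamp is still worthwhile as a guard against an engine built from a network without the final tanh.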
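2. Converting ONNX to a TensorRT engine (C++ code)
Create src/utils/convert_trt.cpp. The code targets the TensorRT 8.5 C++ API that ships with JetPack. On JetPack, the headers are installed under /usr/include/aarch64-linux-gnu, and TensorRT requires the application to supply its own ILogger:
#include <NvInfer.h>
#include <NvOnnxParser.h>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
using namespace nvinfer1;
using namespace nvonnxparser;
// TensorRT does not provide a logger implementation; the application implements ILogger
class Logger : public ILogger {
public:
explicit Logger(Severity min_severity = Severity::kWARNING) : min_severity_(min_severity) {}
void log(Severity severity, const char* msg) noexcept override {
if(severity <= min_severity_) std::cerr << "[TensorRT] " << msg << std::endl;
}
private:
Severity min_severity_;
};
int main(int argc, char* argv[]) {
if(argc != 3) {
std::cerr << "Usage: " << argv[0] << " <onnx_path> <output_engine_path>" << std::endl;
return -1;
}
std::string onnx_path = argv[1];
std::string engine_path = argv[2];
Logger logger(ILogger::Severity::kWARNING);
// In TensorRT 8.x objects are released with delete (destroy() is deprecated), so unique_ptr handles cleanup
std::unique_ptr<IBuilder> builder(createInferBuilder(logger));
const uint32_t flags = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
std::unique_ptr<INetworkDefinition> network(builder->createNetworkV2(flags));
std::unique_ptr<IParser> parser(createParser(*network, logger));
if(!parser->parseFromFile(onnx_path.c_str(), static_cast<int>(ILogger::Severity::kWARNING))) {
std::cerr << "Failed to parse ONNX model" << std::endl;
return -1;
}
std::unique_ptr<IBuilderConfig> config(builder->createBuilderConfig());
// 1 GiB workspace (setMemoryPoolLimit replaces the deprecated setMaxWorkspaceSize)
config->setMemoryPoolLimit(MemoryPoolType::kWORKSPACE, 1ULL << 30);
if(builder->platformHasFastFp16()) {
config->setFlag(BuilderFlag::kFP16);
}
// buildSerializedNetwork returns the serialized engine directly
std::unique_ptr<IHostMemory> serialized_engine(builder->buildSerializedNetwork(*network, *config));
if(!serialized_engine) {
std::cerr << "Failed to build TensorRT engine" << std::endl;
return -1;
}
std::ofstream engine_file(engine_path, std::ios::binary);
engine_file.write(reinterpret_cast<const char*>(serialized_engine->data()), static_cast<std::streamsize>(serialized_engine->size()));
std::cout << "TensorRT engine saved to: " << engine_path << std::endl;
return 0;
}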
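On the inference side, TrtInfer (its source is not part of this section) must read the .engine file back into memory. It then hands the bytes to TensorRT's nvinfer1::IRuntime::deserializeCudaEngine(const void* blob, std::size_t size). The file-reading part needs only the standard library. A minimal sketch follows. read_engine_file is a hypothetical helper name.

```cpp
#include <cstddef>
#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>

// Reads a serialized TensorRT engine (or any binary blob) completely into memory.
// Throws std::runtime_error if the file is missing, empty, or unreadable, so a wrong
// engine_path fails at startup instead of inside TensorRT.
std::vector<char> read_engine_file(const std::string& path) {
    std::ifstream file(path, std::ios::binary | std::ios::ate);  // open positioned at the end
    if (!file) {
        throw std::runtime_error("cannot open engine file: " + path);
    }
    const std::streamsize size = file.tellg();  // end position == file size
    if (size <= 0) {
        throw std::runtime_error("engine file is empty: " + path);
    }
    std::vector<char> buffer(static_cast<std::size_t>(size));
    file.seekg(0, std::ios::beg);
    if (!file.read(buffer.data(), size)) {
        throw std::runtime_error("failed to read engine file: " + path);
    }
    return buffer;
}

// With TensorRT 8.5 (not compiled here), the buffer is then deserialized:
//   std::vector<char> blob = read_engine_file(engine_path);
//   std::unique_ptr<nvinfer1::IRuntime> runtime(nvinfer1::createInferRuntime(logger));
//   std::unique_ptr<nvinfer1::ICudaEngine> engine(
//       runtime->deserializeCudaEngine(blob.data(), blob.size()));
```

A serialized engine is tied to the TensorRT version and GPU it was built on. For that reason, convert_trt should be run on the Orin NX itself rather than on a desktop machine.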
V. CMakeLists.txt configuration
cmake_minimum_required(VERSION 3.20)
project(drone_rl_cpp LANGUAGES CXX)
# LibTorch 2.1 requires C++17
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
add_compile_options(-Wall -Wextra -Wpedantic -O3)
endif()
# Find dependencies
find_package(ament_cmake REQUIRED)
find_package(rclcpp REQUIRED)
find_package(sensor_msgs REQUIRED)
find_package(geometry_msgs REQUIRED)
find_package(std_msgs REQUIRED)
find_package(gazebo_ros2_control REQUIRED)
find_package(ros2_control REQUIRED)
find_package(Eigen3 REQUIRED)
find_package(OpenCV REQUIRED)
find_package(PCL REQUIRED COMPONENTS common io)
# LibTorch (located through Torch_DIR)
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
message(STATUS "LibTorch found: ${Torch_FOUND}, Version: ${Torch_VERSION}")
# TensorRT and CUDA
find_package(CUDAToolkit REQUIRED)
find_library(NVINFER_LIB nvinfer HINTS /usr/lib/aarch64-linux-gnu/ REQUIRED)
find_library(NVONNXPARSER_LIB nvonnxparser HINTS /usr/lib/aarch64-linux-gnu/ REQUIRED)
message(STATUS "TensorRT libs: ${NVINFER_LIB}, ${NVONNXPARSER_LIB}")
# Include directories
include_directories(
include
${EIGEN3_INCLUDE_DIRS}
${OpenCV_INCLUDE_DIRS}
${PCL_INCLUDE_DIRS}
${TORCH_INCLUDE_DIRS}
)
# Shared code goes into one static library; each file that defines main() becomes its own executable
add_library(drone_rl_core STATIC
src/env/DroneEnv.cpp
src/networks/TD3Networks.cpp
src/utils/ReplayBuffer.cpp
src/utils/TrtInfer.cpp
src/TD3Agent.cpp
)
# ROS2 packages are attached with ament_target_dependencies
ament_target_dependencies(drone_rl_core
rclcpp
sensor_msgs
geometry_msgs
std_msgs
gazebo_ros2_control
ros2_control
)
# Libraries found by full path or as imported targets are linked directly
target_link_libraries(drone_rl_core
${TORCH_LIBRARIES}
${NVINFER_LIB}
${NVONNXPARSER_LIB}
CUDA::cudart
${PCL_LIBRARIES}
${OpenCV_LIBS}
Eigen3::Eigen
)
# Executables
add_executable(td3_train_node src/train_node.cpp)
add_executable(td3_infer_node src/infer_node.cpp)
add_executable(export_onnx src/utils/export_onnx.cpp)
add_executable(convert_trt src/utils/convert_trt.cpp)
foreach(exe td3_train_node td3_infer_node export_onnx convert_trt)
target_link_libraries(${exe} drone_rl_core)
endforeach()
# Install rules
install(TARGETS td3_train_node td3_infer_node export_onnx convert_trt DESTINATION lib/${PROJECT_NAME})
ament_package()
