π0 源码深度解析
前言
π0 (OpenPI) 是一套用于通用机器人控制的视觉 - 语言 - 动作(VLA)模型框架。本文旨在深入剖析其源码实现,从模型架构设计、训练配置到 C/S 架构下的实际部署流程,帮助开发者理解如何基于 PaLI-Gemma 和扩散策略去噪生成动作。
第一部分 π0 模型架构的实现
1.1 models/model.py:核心基础模型定义
这是模型框架的核心文件,定义了基础的抽象类和数据结构。
BaseModelConfig: 所有模型配置的抽象基类BaseModel: 所有模型实现的抽象基类Observation: 保存模型输入的数据类Actions: 定义动作数据格式
1.1.1 基础组件和关键常量
模型类型枚举定义了两种支持的模型类型:
class ModelType(enum.Enum):
"""Supported model types."""
PI0 = "pi0"
PI0_FAST = "pi0_fast"
图像输入配置定义了模型期望接收三个视角的图像:
IMAGE_KEYS = (
"base_0_rgb",
"left_wrist_0_rgb",
"right_wrist_0_rgb",
)
1.1.2 Observation 类详解
Observation 类是 OpenPI 框架中的核心数据结构,用于存储和管理模型的输入数据。它包含了机器人感知系统收集的所有必要信息,包括低维度的机器人状态、图像数据及掩码等。
class Observation(Generic[ArrayT]):
"""Holds observations, i.e., inputs to the model."""
images: dict[str, at.Float[ArrayT, "*b h w c"]]
image_masks: dict[str, at.Bool[ArrayT, "*b"]]
state: at.Float[ArrayT, "*b s"]
# Tokenized prompt fields for FAST autoregressive model...
from_dict 方法负责从非结构化的字典数据创建 Observation 对象,并处理图像格式的转换(如将 uint8 转换为 float32)。
1.2 models/pi0.py 的实现
Pi0 是一个多模态扩散模型,继承自 BaseModel,使用 SigLIP 处理视觉输入、使用 Gemma 处理语言输入,实现了基于扩散的动作生成系统。
1.2.1 make_attn_mask:注意力掩码生成函数
该函数生成 Transformer 中使用的注意力掩码,控制 token 之间的注意力流动方式,支持纯因果注意力、前缀语言模型注意力等多种模式。
1.2.2 posemb_sincos:位置编码函数
使用正弦余弦函数实现位置编码,确保模型能感知序列顺序。
1.2.3 Pi0Config:模型配置
定义了动作专家底层结构(如 300M 大小的 Gemma 变体)及输入规格。
class Pi0Config(_model.BaseModelConfig):
dtype: str = "bfloat16"
paligemma_variant: _gemma.Variant = "gemma_2b"
action_expert_variant: _gemma.Variant = "gemma_300m"
action_dim: int = 32
action_horizon: int = 50
1.2.4 class Pi0:核心模型类
实现了多模态输入处理、扩散过程(训练去噪、推理采样)及注意力机制。
1.2.4.1 初始化方法
组合了 PaLI-Gemma 模型(Gemma LLM + SigLIP 视觉模型)、线性投影层等核心组件。
1.2.4.2 特征嵌入方法
embed_prefix: 处理图像和文本输入,创建前缀 token,采用双向注意力。embed_suffix: 处理机器人状态信息和噪声化动作信息,创建后缀 token。
1.2.4.3 损失函数 compute_loss
训练时,对原始动作添加噪声,让模型学习预测噪声。计算预测噪声与实际噪声之间的均方误差。
# 创建带噪声的动作
x_t = time_expanded * noise + (1 - time_expanded) * actions
# 计算真实噪声
u_t = noise - actions
1.2.4.4 推理函数 sample_actions
基于扩散模型的逆向采样过程,从纯噪声开始,通过多步骤逐渐去噪,最终生成符合条件的机器人动作序列。使用了 KV 缓存优化推理速度。
1.3 语言模型与视觉模型实现
models/gemma.py: 实现了 Gemma 语言模型的核心组件。models/siglip.py: 实现了基于 Vision Transformer 的视觉编码器。
1.4 tokenizer.py
提供了文本 tokenization 功能,包括 PaligemmaTokenizer 和 FASTTokenizer。
第二部分 策略适配接口
src/openpi/policies 目录包含策略适配逻辑,针对不同机器人平台(ALOHA、DROID、LIBERO)定义了特定的输入输出转换函数。
2.1 policy.py
定义了基本的 Policy 类和 PolicyRecorder 类,负责调用模型进行推理及记录行为。
2.2 policy_config.py
定义了 PolicyConfig 数据和工厂函数 create_trained_policy,用于从检查点加载模型并创建可用策略实例。
2.3 policies/aloha_policy.py
实现了 ALOHA 策略的输入输出处理和数据转换,包括状态向量、图像数据的标准化及特定摄像头的映射处理。
第三部分 模型训练的配置
training 模块负责训练相关功能,包括检查点管理、配置系统、数据加载器及优化器等。
3.1 配置系统 (config.py)
定义了 TrainConfig、DataConfig 等配置类,支持预定义多种常用配置(如 ALOHA、Libero)。
3.2 数据加载系统 data_loader.py
实现了数据集接口,支持真实数据集(LeRobot)和模拟数据(FakeDataset),并提供数据归一化和转换功能。
3.3 优化器系统 (optimizer.py)
定义了多种学习率调度策略(如余弦衰减)和优化器配置(如 AdamW)。
3.4 检查点系统 (checkpoints.py)
负责模型状态的保存和恢复,使用 Orbax 库实现高效的检查点存储。
3.5 模型分片系统 (sharding.py)
实现分布式训练时的模型参数分片,支持 FSDP 全参数数据并行。
第四部分 模型的训练与部署
4.1 packages/openpi-client
提供与策略服务器通信的接口,使用 WebSocketClientPolicy 连接服务器,处理观察数据和动作序列。
4.2 scripts(策略服务器)
包含数据处理、模型训练和服务部署脚本。
4.2.3 serve_policy.py
启动策略服务,提供模型推理接口。支持定义特定任务的文本指令 prompt。
Prompt 流转流程:
- 设定 Prompt: 在
serve_policy.py中通过default_prompt参数或客户端请求传入。 - 服务器初始化: 启动 WebSocket 服务器,加载策略模型,将元数据(含默认 prompt)发送给客户端。
- 客户端请求: 客户端发送观察数据(图像 + 状态),若未包含 prompt 则使用服务器下发的默认值。
- 模型推理: 模型获取全部输入数据,生成动作序列返回给客户端。
第五部分 examples 示例
examples 模块提供了各种机器人平台及策略客户端的示例实现。
5.1 aloha_real
用于控制真实 ALOHA 双臂机器人的完整实现,包含硬件驱动、环境接口及主控制流程。
部署方式:
- Docker 部署: 直接运行 compose 命令。
- 本地部署: 需启动三个终端分别负责 ROS 硬件层、主控进程和推理服务。
# 启动推理服务
uv run scripts/serve_policy.py --env ALOHA --default_prompt='take the toast out of the toaster'
# 启动 ROS 驱动
roslaunch aloha ros_nodes.launch
# 运行机器人控制主程序
python -m examples.aloha_real.main
综上,OpenPI 项目通过清晰的模块化设计,实现了从模型训练到真实机器人部署的全流程闭环,为具身智能领域的开发提供了坚实基础。


