跳到主要内容
π0 源码深度剖析:从 PaLI-Gemma 扩散策略到 C/S 架构部署 | 极客日志
Python AI 算法
π0 源码深度剖析:从 PaLI-Gemma 扩散策略到 C/S 架构部署 综述由AI生成 OpenPI 项目源码解析涵盖模型架构、策略接口、训练配置及 C/S 部署流程。核心基于 PaLI-Gemma 视觉语言模型与扩散策略,通过多视角图像与状态输入生成机器人动作序列。文章深入剖析了 Pi0 模型的注意力掩码机制、LoRA 微调策略、数据加载管道以及基于 WebSocket 的客户端 - 服务器通信架构。内容涉及 Aloha、Libero 等平台的适配实现,为具身智能领域的模型复现与工程落地提供详细的技术参考。
观心 发布于 2026/3/16 更新于 2026/4/26 4 浏览π0 源码深度剖析
前言
随着大模型技术的快速发展,具身智能领域迎来了新的突破。OpenPI(π0)项目作为一套基于视觉语言动作(VLA)的通用机器人控制框架,展示了如何利用 PaLI-Gemma 和扩散策略实现高效的动作生成。本文将对 OpenPI 的源码结构进行系统性解读,涵盖模型架构、策略适配、训练配置以及基于客户端 - 服务器(C/S)架构的部署流程。
第一部分 π0 模型架构的实现:src 下 models 的全面分析与解读
1.1 models/model.py:核心基础模型的定义
这是模型框架的核心文件,定义了基础的抽象类和数据结构:
BaseModelConfig: 所有模型配置的抽象基类
BaseModel: 所有模型实现的抽象基类
Observation: 保存模型输入的数据类
Actions: 定义动作数据格式
提供了通用功能如 preprocess_observation 和 restore_params
1.1.1 基础组件和关键常量
首先是模型类型枚举,定义了两种支持的模型类型:
class ModelType (enum.Enum):
"""Supported model types."""
PI0 = "pi0"
PI0_FAST = "pi0_fast"
接下来是图像输入配置,定义了模型期望的图像输入的键名。这表明模型设计为同时接收三个视角的图像:
一个基础视图(机器人环境的全局视图)
左手腕视图(来自左手腕摄像头)
右手腕视图(来自右手腕摄像头)
IMAGE_KEYS = (
"base_0_rgb" ,
"left_wrist_0_rgb" ,
"right_wrist_0_rgb" ,
)
再其次,是图像分辨率设置——定义了模型处理图像的标准分辨率为 224×224 像素。
IMAGE_RESOLUTION = (224 , 224 )
1.1.2 Observation 类与 Actions 类型的详解
Observation 类是 OpenPI 框架中的一个核心数据结构,用于存储和管理模型的输入数据。
首先,它包含了机器人感知系统收集的所有必要信息:
PI0-FAST 模型特有字段
token_ar_mask: 自回归模型的标记掩码
token_loss_mask: 损失计算的标记掩码
语言提示相关字段
tokenized_prompt: 已经 tokenized 的语言提示
: 语言提示的掩码
tokenized_prompt_mask
state: at.Float[ArrayT, "*b s" ]
类型:at.Float[ArrayT, "*b s"]
用途:存储低维度的机器人状态向量
维度:*b 表示批量维度,s 表示状态向量维度
class Observation (Generic [ArrayT]):
"""Holds observations, i.e., inputs to the model..."""
images: dict [str , at.Float[ArrayT, "*b h w c" ]]
接下来,定义了 from_dict 方法,用于从非结构化的字典数据创建 Observation 对象:
return cls(
images=data["image" ],
image_masks=data["image_mask" ],
state=data["state" ],
tokenized_prompt=data.get("tokenized_prompt" ),
tokenized_prompt_mask=data.get("tokenized_prompt_mask" ),
token_ar_mask=data.get("token_ar_mask" ),
token_loss_mask=data.get("token_loss_mask" ),
)
最后,在类外定义了 Actions 类型,用于表示模型的输出动作:
Actions = at.Float[ArrayT, "*b ah ad" ]
关于 State 和 Action 的区别:State 代表机器人当前的状态信息(如关节角度、末端位置),而 Action 代表机器人应该执行的下一步控制命令(如目标关节角度或增量变化)。State 描述'我在哪里',Action 描述'我要去哪里'。
1.1.3 preprocess_observation
1.1.4 BaseModelConfig(abc.ABC)
1.1.5 class BaseModel(nnx.Module, abc.ABC)
1.1.6 restore_params
1.2 models/pi0.py 的实现 Pi0 是一个多模态扩散模型:继承自 BaseModel,使用 SigLIP 处理视觉输入、使用 Gemma 处理语言输入,实现了基于扩散的动作生成系统,且包含 compute_loss 和 sample_actions 方法的实现。
1.2.1 make_attn_mask:注意力掩码生成函数 这个函数生成 transformer 中使用的注意力掩码,控制 token 之间的注意力流动方式。
def make_attn_mask (input_mask, mask_ar ):
"""
从 big_vision 项目改编的注意力掩码生成函数
Token 可以关注那些累积 mask_ar 小于等于自己的有效输入 token。
"""
mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape)
cumsum = jnp.cumsum(mask_ar, axis=1 )
attn_mask = cumsum[:, None , :] <= cumsum[:, :, None ]
valid_mask = input_mask[:, None , :] * input_mask[:, :, None ]
return jnp.logical_and(attn_mask, valid_mask)
它支持多种注意力模式:纯因果注意力、前缀语言模型注意力、块状因果注意力。
1.2.2 posemb_sincos:位置编码函数 def posemb_sincos (
pos: at.Real[at.Array, Any ],
embedding_dim: int ,
min_period: float ,
max_period: float
) -> at.Float[at.Array, f"b {embedding_dim} " ]:
"""计算标量位置的正弦余弦位置嵌入向量"""
if embedding_dim % 2 != 0 :
raise ValueError(f"embedding_dim ({embedding_dim} ) must be divisible by 2" )
fraction = jnp.linspace(0.0 , 1.0 , embedding_dim // 2 )
period = min_period * (max_period / min_period) ** fraction
sinusoid_input = jnp.einsum(
"i,j->ij" , pos, 1.0 / period * 2 * jnp.pi,
precision=jax.lax.Precision.HIGHEST,
)
return jnp.concatenate([jnp.sin(sinusoid_input), jnp.cos(sinusoid_input)], axis=-1 )
1.2.3 class Pi0Config:含 inputs_spec、get_freeze_filter
1.2.3.1 模型配置参数的定义 class Pi0Config (_model.BaseModelConfig):
dtype: str = "bfloat16"
paligemma_variant: _gemma.Variant = "gemma_2b"
action_expert_variant: _gemma.Variant = "gemma_300m"
action_dim: int = 32
action_horizon: int = 50
max_token_len: int = 48
1.2.3.2 inputs_spec:定义了π0 模型本身接收的输入数据格式 def inputs_spec (self, *, batch_size: int = 1 ) -> Tuple [Type [_model.Observation], Type [_model.Actions]]:
具体操作包括创建图像规格、图像掩码规格、观察规格(包含视觉输入、机器人状态、指令输入)以及动作规格。
1.2.3.3 get_freeze_filter:参数冻结器 该配置类实现了 get_freeze_filter 函数,作用是如果选择 LoRA 微调,则需要对模型中的某些参数做冻结。
只对动作专家使用 LoRA
对两者都使用 LoRA
只对 PaLI-Gemma 使用 LoRA
def get_freeze_filter (self ) -> nnx.filterlib.Filter:
"""返回基于模型配置的冻结过滤器"""
return nnx.All(*filters)
值得注意的是,如果需要只调整动作专家的参数,可以通过修改 get_freeze_filter 方法来冻结 VLM 的参数。
1.2.4 class Pi0:含特征嵌入、损失函数、推理 核心模型类,继承自 _model.BaseModel,实现了多模态输入处理、扩散过程、注意力机制。
1.2.4.1 初始化方法 __init__ 其组合了多个核心组件:PaLI-Gemma 模型(结合了 Gemma 语言模型和 SigLIP 视觉模型)、线性投影层等。
class Pi0 (_model.BaseModel):
def __init__ (self, config: Pi0Config, rngs: nnx.Rngs ):
super ().__init__(config.action_dim, config.action_horizon, config.max_token_len)
paligemma_config = _gemma.get_config(config.paligemma_variant)
action_expert_config = _gemma.get_config(config.action_expert_variant)
1.2.4.2 特征嵌入方法:embed_prefix/embed_suffix
embed_prefix: 处理图像和文本输入,创建前缀 token,皆为双向注意力。
embed_suffix: 处理机器人状态信息和噪声化的动作信息,创建后缀 token。
def embed_prefix (self, obs: _model.Observation ) -> Tuple [at.Float[at.Array, Any ], at.Bool[at.Array, Any ], at.Bool[at.Array, Any ]]:
tokens = jnp.concatenate(tokens, axis=1 )
return tokens, input_mask, ar_mask
1.2.4.3 损失函数 compute_loss:训练模型去噪的准确率 训练的时候,对其中的「原始动作 action」数据加噪,最后去预测所添加的真实噪声。计算预测噪声与实际噪声间的均方误差。
def compute_loss (self, rng, observation, actions, *, train: bool = False ):
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
v_t = self .action_out_proj(suffix_out[:, -self .action_horizon :])
return jnp.mean(jnp.square(v_t - u_t), axis=-1 )
数据集主要使用 LeRobotDataset,包含真实机器人操作数据,如 ALOHA 数据集和 Libero 数据集。
1.2.4.4 推理函数 sample_actions:基于扩散模型逆向采样 sample_actions 函数是 Pi0 模型的核心推理方法,实现了基于扩散模型的逆向采样过程。
def sample_actions (self, rng, observation, *, num_steps: int = 10 ):
prefix_tokens, prefix_mask, prefix_ar_mask = self .embed_prefix(observation)
_, kv_cache = self .PaliGemma.llm([prefix_tokens, None ], mask=prefix_attn_mask, positions=positions)
def step (carry ):
x_t, time = carry
suffix_tokens, suffix_mask, suffix_ar_mask = self .embed_suffix(observation, x_t, jnp.broadcast_to(time, batch_size))
(prefix_out, suffix_out), _ = self .PaliGemma.llm([None , suffix_tokens], mask=full_attn_mask, positions=positions, kv_cache=kv_cache)
v_t = self .action_out_proj(suffix_out[:, -self .action_horizon :])
return x_t + dt * v_t, time + dt
x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0 ))
return x_0
1.3 语言模型实现:models/gemma.py src/openpi/models/gemma.py 实现了 Gemma 语言模型的核心组件,定义了 RMSNorm、Embedder、Attention、FeedForward 等模块。
1.4 视觉模型实现:models/siglip.py siglip.py 实现了视觉编码器,基于 Vision Transformer (ViT),定义了位置编码、注意力池化等组件。
1.5 tokenizer.py: 提供文本 tokenization 功能 这段代码实现了两个相关但功能不同的 tokenizer 类:PaligemmaTokenizer 和 FASTTokenizer。
1.5.1 PaligemmaTokenizer 类 专门处理文本 prompt,下载 SentencePiece 模型并加载。
1.5.2 FASTTokenizer 类 可同时处理文本和动作数据,通过映射将 FAST token 转换为 PaliGemma token。
1.6 lora.py:实现了 LoRA (Low-Rank Adaptation) 微调方法
1.6.1 Einsum 类中的 setup 负责初始化模块所需的所有参数,包括 LoRA 参数 A 和 B 矩阵。
1.6.2 Einsum 类中的__call__ 实现了支持 LoRA 的前向传播逻辑,将标准 einsum 结果与 LoRA 修正项相加。
1.6.3 Einsum 类中的_make_lora_eqns 负责将标准的 Einstein 求和表达式转换为两个新的表达式,以支持 LoRA 的低秩分解计算。
1.7 vit.py: Vision Transformer 实现
第二部分 策略适配接口:src 下 policy 的全面分析与解读 src/openpi/policies 目录包含以下文件:
BasePolicy (policy.py)
AlohaPolicy (aloha_policy.py)
DroidPolicy (droid_policy.py)
LiberoPolicy (libero_policy.py)
这些文件定义了特定于机器人的输入和输出转换函数,处理数据格式、规范化和特定的转换需求。
2.1 policy.py:实现了 Policy 类和 PolicyRecorder 类
2.1.1 Policy 类 定义了基本的 Policy 类,继承自 openpi_client.base_policy.BasePolicy。
class Policy (BasePolicy ):
def __init__ (self, model, *, rng=None , transforms=( ), output_transforms=( ), ... ):
self ._sample_actions = nnx_utils.module_jit(model.sample_actions)
self ._input_transform = _transforms.compose(transforms)
self ._output_transform = _transforms.compose(output_transforms)
def infer (self, obs: dict ) -> dict :
inputs = self ._input_transform(inputs)
outputs = {
"state" : inputs["state" ],
"actions" : self ._sample_actions(sample_rng, _model.Observation.from_dict(inputs), **self ._sample_kwargs),
}
return self ._output_transform(outputs)
2.1.2 PolicyRecorder 装饰器类,包装了一个基础策略,并在执行策略的同时将所有的输入和输出保存到磁盘。
2.2 policy_config.py 定义了 PolicyConfig 类和 create_trained_policy 函数。
2.2.1 PolicyConfig 数据类 @dataclasses.dataclass
class PolicyConfig :
model: _model.BaseModel
norm_stats: dict [str , transforms.NormStats]
input_layers: Sequence [transforms.DataTransformFn]
output_layers: Sequence [transforms.DataTransformFn]
model_type: _model.ModelType = _model.ModelType.PI0
default_prompt: str | None = None
2.2.2 create_trained_policy 函数 def create_trained_policy (train_config, checkpoint_dir, repack_transforms=None , ... ):
model = train_config.model.load(_model.restore_params(checkpoint_dir / "params" , dtype=jnp.bfloat16))
return _policy.Policy(
model=model,
transforms=[...],
output_transforms=[...],
sample_kwargs=sample_kwargs,
metadata=train_config.policy_metadata,
)
2.3 policies/aloha_policy.py 实现了一个用于 Aloha 策略的输入输出处理和数据转换的模块。
2.3.1 make_aloha_example 创建随机输入示例,包括状态向量和四个摄像头的图像数据。
2.3.2 AlohaInputs 定义 Aloha 策略的输入数据结构,处理图像、状态和动作数据的标准化。
2.3.3 AlohaOutputs 定义 Aloha 策略的输出数据结构,仅返回前 14 个维度的动作数据。
第三部分 模型训练的配置:src 下 training 模块的全面分析与解读 training 模块是 OpenPI 项目中负责训练相关功能的核心部分。
3.1 配置系统 (config.py) 定义了训练过程的各种配置类型,包括 TrainConfig、DataConfigFactory、AssetsConfig 等。
3.1.1 基础配置类 AssetsConfig、DataConfig
3.1.2 数据集配置:包含 ALOHA、Libero 两套数据集 涉及 LeRobotLiberoDataConfig 和 LeRobotAlohaDataConfig。
3.1.3 训练配置 TrainConfig class TrainConfig :
name: str
project_name: str = "openpi"
exp_name: str
model: _model.BaseModelConfig
batch_size: int = 32
num_train_steps: int = 30_000
lr_schedule: _optimizer.LRScheduleConfig
optimizer: _optimizer.OptimizerConfig
3.1.4 预定义配置 基于 ALOHA/Libero 数据集微调π0 的配置示例。
3.2 数据加载系统 data_loader.py 定义了数据集和数据加载器的接口,实现了数据转换管道。
3.2.1 FakeDataset 类
3.2.2 create_dataset:创建适合训练的数据集 根据配置参数创建适合模型训练的数据集,支持真实数据集和模拟数据。
3.2.3 transform_dataset:对数据集应用转换 负责对原始数据集应用一系列转换操作,包括重新打包、清洗、归一化等。
3.2.4 create_data_loader:创建用于训练的数据加载器 协调多个模块共同工作,创建一个用于模型训练的数据加载器。
3.3 优化器系统 (optimizer.py) 定义了多种学习率调度策略和优化器配置,如 AdamW、SGD。
3.4 检查点系统 (checkpoints.py) 负责模型状态的保存和恢复,使用 Orbax 库实现高效的检查点存储。
3.5 模型分片系统(sharding.py) 实现分布式训练时的模型参数分片,提供 FSDP 的实现。
3.6 权重加载系统 (weight_loaders.py) 定义了 WeightLoader 协议,用于加载预训练权重。
3.7 辅助工具(utils.py) 定义了 TrainState 数据类,封装了训练过程的状态。
第四部分 模型的训练与部署:基于客户端 - 服务器 C/S 架构 packages/openpi-client 是一个独立的客户端库,scripts 模块提供了服务器端的各种工具和脚本。
4.1 packages/openpi-client:帮真机或 Sim 与策略服务器进行通信和交互 该模块主要用于连接到 OpenPI 服务器,处理观察数据和动作序列。
4.1.1 核心接口层
4.1.2 通信层 WebsocketClientPolicy
4.1.3 数据处理层
4.1.4 运行时系统层
4.2 scripts(策略服务器) scripts 目录包含多个 Python 脚本,用于数据处理、模型训练和服务部署等任务。
4.2.1 init .py
4.2.2 compute_norm_stats.py:计算数据的归一化统计信息
4.2.3 serve_policy.py:启动策略服务,用于模型推理 serve_policy.py 是 openpi 中的策略推理服务端脚本,作用为启动一个 WebSocket 服务器,加载预训练策略模型,等待外部请求,然后执行动作推理并返回结果。
def main (args: Args ) -> None :
policy = create_policy(args)
server = websocket_policy_server.WebsocketPolicyServer(
policy=policy, host="0.0.0.0" , port=args.port, metadata=policy_metadata,
)
server.serve_forever()
4.2.3.1 分别启动 WebSocket 服务器、WebSocket 客户端并互联 设定 prompt,随后分别启动 WebSocket 服务器、WebSocket 客户端并互联。服务端发送 metadata 给客户端。
4.2.3.2 客户端发送推理请求、服务端处理推理请求 服务器处理推理请求,调用 policy.infer(obs)。如果传入的 obs 字典没有 "prompt" 键,策略会使用默认 prompt。
4.2.3.3 模型获得全部输入数据,生成动作序列 获取到的 prompt 被传递给分词器 Tokenizer,将文本指令转换为 token ID 序列,与图像数据、状态数据一起被输入到π0 中,生成预测的动作序列。
4.2.4 train_test.py:训练和测试模型
4.2.5 train.py:训练模型 基于 JAX 的分布式训练脚本,集成了模型初始化、训练循环、日志记录等功能。
4.2.6 scripts/docker 包含与 Docker 相关的脚本和配置文件,用于构建和管理 Docker 容器。
第五部分 examples:各种机器人平台及策略客户端的示例实现
5.1 aloha_real aloha_real 模块是 OpenPI 项目中用于控制真实 ALOHA 双臂机器人的完整实现。
5.1.1 核心架构
硬件常量定义 (constants.py)
数据转换工具
环境接口 (env.py 和 real_env.py)
主控制流程 (main.py)
5.1.2 系统工作流程与部署方式
初始化阶段:启动 ROS 节点 → 初始化双臂机器人 → 连接摄像头 → 建立 WebSocket 连接
运行时循环 (50Hz):获取观察 → 发送到策略服务器 → 接收动作序列 → 执行动作
部署方式:Docker 部署或本地部署(需启动 3 个终端:ROS 节点、主控进程、推理服务)
uv run scripts/serve_policy.py --env ALOHA --default_prompt='take the toast out of the toaster'
roslaunch aloha ros_nodes.launch
uv pip install -e packages/openpi-client
python -m examples.aloha_real.main
相关免费在线工具 加密/解密文本 使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
RSA密钥对生成器 生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
Mermaid 预览与可视化编辑 基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
随机西班牙地址生成器 随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online
Gemini 图片去水印 基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online
curl 转代码 解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online