跳到主要内容π0 源码剖析:基于 PaLI-Gemma 的扩散策略与 C/S 架构部署 | 极客日志PythonAI算法
π0 源码剖析:基于 PaLI-Gemma 的扩散策略与 C/S 架构部署
开源项目 openpi 的 π0 模型源码深度剖析。内容涵盖模型架构设计,包括 PaLI-Gemma 多模态输入处理、扩散策略去噪生成动作的核心逻辑。解析训练配置系统、数据加载管道及基于 C/S 架构的模型部署方案。重点讲解 Observation 数据结构、LoRA 微调策略、推理时的 KV 缓存优化以及客户端与服务器的通信机制。适合希望复现或二次开发具身智能模型的开发者参考。
蜜桃汽水7 浏览 引言
OpenPI 是 Physical Intelligence 推出的开源通用机器人控制框架,核心是基于 PaLI-Gemma 和扩散策略的 π0 模型。本文将深入剖析其源码结构,涵盖从模型架构实现到 C/S 架构下的训练与部署全流程。
第一部分 π0 模型架构的实现
核心位于 src/openpi/models。首先是 models/model.py,定义了基础抽象类。
1.1 基础组件与数据结构
这是模型框架的核心文件,定义了基础的抽象类和数据结构:
BaseModelConfig: 所有模型配置的抽象基类
BaseModel: 所有模型实现的抽象基类
Observation: 保存模型输入的数据类
Actions: 定义动作数据格式
提供了通用功能如 preprocess_observation 和 restore_params。
基础组件和关键常量
首先是模型类型枚举,定义了两种支持的模型类型:
class ModelType(enum.Enum):
"""Supported model types."""
PI0 = "pi0"
PI0_FAST = "pi0_fast"
接下来是图像输入配置,定义了模型期望的图像输入的键名。这表明模型设计为同时接收三个视角的图像:一个基础视图(机器人环境的全局视图)、左手腕视图、右手腕视图。
IMAGE_KEYS = (
"base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb",
)
再其次,是图像分辨率设置——定义了模型处理图像的标准分辨率为 224×224 像素。
IMAGE_RESOLUTION = (224, 224)
Observation 类与 Actions 类型的详解
Observation 类是 OpenPI 框架中的一个核心数据结构,用于存储和管理模型的输入数据。
首先,它包含了机器人感知系统收集的所有必要信息:
- PI0-FAST 模型特有字段:
token_ar_mask(自回归模型的标记掩码)、token_loss_mask(损失计算的标记掩码)。
- 语言提示相关字段:
tokenized_prompt(已经 tokenized 的语言提示)、tokenized_prompt_mask(语言提示的掩码)。
- 机器人状态 (
state):低维度的机器人状态向量。
- 图像掩码 (
image_masks):标记对应的图像是否有效。
- 图像数据 (
images):存储多个摄像头视角的图像数据。
接下来,定义了 from_dict 方法,用于从非结构化的字典数据创建 Observation 对象:
return cls(
images=data["image"],
image_masks=data["image_mask"],
state=data["state"],
tokenized_prompt=data.get("tokenized_prompt"),
tokenized_prompt_mask=data.get("tokenized_prompt_mask"),
token_ar_mask=data.get("token_ar_mask"),
token_loss_mask=data.get("token_loss_mask"),
)
最后,在类外定义了 Actions 类型,用于表示模型的输出动作:
Actions = at.Float[ArrayT, "*b ah ad"]
维度:*b 表示批量维度,ah 表示动作时间步长,ad 表示每个动作的维度。
在实际任务中,State 代表机器人当前的状态信息(关节角度、末端位置等),而 Action 代表机器人应该执行的下一步控制命令(目标关节角度或增量变化)。
BaseModelConfig 与 BaseModel
BaseModelConfig 继承自 abc.ABC,定义了配置参数。
BaseModel 继承自 nnx.Module 和 abc.ABC,实现了模型的基础逻辑。
1.2 models/pi0.py 的实现
Pi0 是一个多模态扩散模型:继承自 BaseModel,使用 SigLIP 处理视觉输入、使用 Gemma 处理语言输入,实现了基于扩散的动作生成系统,且包含 compute_loss 和 sample_actions 方法的实现。
make_attn_mask:注意力掩码生成函数
这个函数生成 transformer 中使用的注意力掩码,控制 token 之间的注意力流动方式。
def make_attn_mask(input_mask, mask_ar):
mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape)
cumsum = jnp.cumsum(mask_ar, axis=1)
attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]
valid_mask = input_mask[:, None, :] * input_mask[:, :, None]
return jnp.logical_and(attn_mask, valid_mask)
它支持多种注意力模式:纯因果注意力、前缀语言模型注意力、块状因果注意力。
posemb_sincos:位置编码函数
def posemb_sincos( pos: at.Real[at.Array, Any], embedding_dim: int, min_period: float, max_period: float ) -> at.Float[at.Array, f"b {embedding_dim}"]:
if embedding_dim % 2 != 0:
raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2")
fraction = jnp.linspace(0.0, 1.0, embedding_dim // 2)
period = min_period * (max_period / min_period) ** fraction
sinusoid_input = jnp.einsum(
"i,j->ij", pos, 1.0 / period * 2 * jnp.pi,
precision=jax.lax.Precision.HIGHEST,
)
return jnp.concatenate([jnp.sin(sinusoid_input), jnp.cos(sinusoid_input)], axis=-1)
Pi0Config:模型配置参数
该类定义了模型的配置参数,比如 PaLI-Gemma 变体:gemma_2b,尤其值得注意的是在本 π0 的官方实现中,动作专家的底层结构用的 300M 大小的 gemma 模型变体。
class Pi0Config(_model.BaseModelConfig):
dtype: str = "bfloat16"
paligemma_variant: _gemma.Variant = "gemma_2b"
action_expert_variant: _gemma.Variant = "gemma_300m"
action_dim: int = 32
action_horizon: int = 50
max_token_len: int = 48
inputs_spec:定义了 π0 模型本身接收的输入数据格式
通过 inputs_spec 函数定义了 π0 模型本身接收的输入数据格式,返回一个包含观察规格和动作规格的元组。
- 其支持多种输入,比如视觉输入 (三个不同视角的 RGB 图像)、语言输入 (分词后的文本 prompt)、状态输入 (当前机器人状态)。
- 输出上则是一个时序动作序列 (包含 50 个连续的动作向量,每个动作向量有 32 个维度)。
image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)
image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)
三、创建观察规格:包含视觉输入、机器人状态、指令输入
state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),
tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),
tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool),
action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)
get_freeze_filter:参数冻结器
该配置类还实现了 get_freeze_filter 这个函数,作用是如果选择 LoRA 微调 (冻结原始预训练模型的参数,只更新新添加的低秩适应层参数),则需要对模型中的某些参数做冻结。
- 只对动作专家使用 LoRA
- 对两者都使用 LoRA
- 只对 PaLI-Gemma 使用 LoRA
具体而言,该 get_freeze_filter 分为 4 大阶段:定义函数本身、初始化变量并创建参数过滤器;分情况添加 LoRA 权重;针对需要 LoRA 微调的少量参数处理;返回所有需要被冻结/被过滤的参数。
class Pi0:含特征嵌入、损失函数、推理
核心模型类,继承自 _model.BaseModel,实现了多模态输入处理、扩散过程、注意力机制。
初始化方法 __init__
其组合了多个核心组件:一个是 PaLI-Gemma 模型:结合了 Gemma 语言模型和 SigLIP 视觉模型;另一个是线性投影层:用于时间 - 动作混合等。
特征嵌入方法:embed_prefix/embed_suffix
embed_prefix:处理图像和文本输入 (图像通过 SigLip 模型编码,文本通过 Gemma LLM 编码),创建前缀 token,皆为双向注意力。
embed_suffix:处理机器人状态信息、噪声化的动作信息 (状态和噪声动作经过线性投影和 MLP 处理),创建后缀 token。
其中状态为单个 token,和第一个动作 token 均设置为单向注意力,用 ar_mask = true 表示;其余动作 tokens 之间设置为双向注意力,用 ar_mask = false 表示。
损失函数 compute_loss:训练模型去噪的准确率
训练的时候,对其中的「原始动作 action」数据加噪,最后去预测所添加的真实噪声。计算预测噪声与实际噪声间的均方误差。
return jnp.mean(jnp.square(v_t - u_t), axis=-1)
预测噪声 v_t 由模型输出投影回动作空间得到。模型前向传播调用 PaliGemma 进行推理,处理前缀和后缀 token。
创建带噪动作序列 x_t,相当于 x_t 是噪声化的动作,随着时间从 0 到 1,原始动作逐渐添加真实噪声,变为纯噪声。所添加的噪声即 = 加满噪声的动作 - 原始动作。
推理函数 sample_actions:基于扩散模型逆向采样
sample_actions 函数是 Pi0 模型的核心推理方法,实现了基于扩散模型的逆向采样过程——说白了 就是去噪,它从纯噪声开始,通过多步骤逐渐'去噪',最终生成符合条件分布的机器人动作序列。
函数的核心是一个基于 while 循环的迭代过程,每一步都使用训练好的神经网络预测从当前噪声化动作到目标动作的方向。
首先,函数对输入观察数据进行预处理,包括标准化图像大小等操作。然后设置时间步长 dt 为负值(因为是从 t=1 向 t=0 方向演化),生成初始随机噪声作为起点。
dt = -1.0 / num_steps
noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim))
处理观察数据,得到前缀表示和相关掩码。然后使用 PaliGemma 语言模型进行一次前向传递,生成 Key-Value 缓存(kv_cache)——这是一个性能优化:因为前缀部分在整个采样过程中保持不变,预先计算并缓存它们的表示可以避免重复计算。
prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
positions = jnp.cumsum(prefix_mask, axis=1) - 1
_, kv_cache = self.PaliGemma.llm([prefix_tokens, None], mask=prefix_attn_mask, positions=positions)
第三,通过 step 函数构建注意力掩码系统并让 PaliGemma 做推理
核心迭代通过 jax.lax.while_loop 实现。step 函数通过 embed_suffix 处理当前状态,包括状态信息嵌入、噪声化动作、时间步编码。接着构建复杂的注意力掩码系统,处理前缀 - 后缀之间的注意力关系。之后,模型推理,使用 PaliGemma 语言模型进行推理,利用缓存的前缀信息(kv_cache)提高效率。
第四,step 函数中做最后的速度预测与动作更新 (去噪)
在每一步中,模型预测速度场 v_t(从噪声到目标的方向),并通过类欧拉法更新动作表示。
v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
return x_t + dt * v_t, time + dt
至于 cond 函数确定何时停止迭代,通过检查时间是否接近零。
1.3 语言模型实现:models/gemma.py
src/openpi/models/gemma.py 实现了 Gemma 语言模型的核心组件,定义了 RMSNorm、Embedder、Attention、FeedForward 等模块,且提供了不同规模 Gemma 模型的配置(300M, 2B 等)。
1.4 视觉模型实现:models/siglip.py
siglip.py: 实现了视觉编码器,基于 Vision Transformer (ViT),定义了位置编码、注意力池化等组件,支持不同大小的模型变体。
1.5 tokenizer.py: 提供文本 tokenization 功能
这段代码实现了两个相关但功能不同的 tokenizer 类:PaligemmaTokenizer 和 FASTTokenizer。
PaligemmaTokenizer 类:专门处理文本 prompt
PaligemmaTokenizer 是一个相对简单的 Tokenizer,专门处理文本 prompt。
第一方面,在初始化阶段下载完成后,代码以二进制读取模式打开文件,并使用 SentencePiece 处理器加载模型。
with path.open("rb") as f:
self._tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
第二方面,tokenize 方法是处理文本输入的核心,它执行以下步骤:文本清理、Tokenizer 将清理后的文本送入 SentencePiece 编码器、长度处理、返回结果。
FASTTokenizer 类
FASTTokenizer 是一个更复杂的 Tokenizer,可同时处理文本和动作数据。
首先是初始化过程,设置 _fast_skip_tokens = 128 以跳过 PaliGemma 词汇表末尾的特殊 token,加载专门的 FAST Tokenizer,同样下载 PaliGemma Tokenizer 模型。
其次是 Tokenizer 流程,处理所有 token 序列和掩码的填充或截断,创建三种掩码:token_mask、ar_mask、loss_mask。如果提供了动作,使用 FAST Tokenizer 对动作进行 Tokenizer,通过 _act_tokens_to_paligemma_tokens 将这些动作 token 映射到 PaliGemma 词汇表中。
1.6 lora.py:实现了 LoRA (Low-Rank Adaptation) 微调方法
Einsum 类中的 setup 负责初始化模块所需的所有参数,Einsum 类中的 call 实现了支持 LoRA 技术的前向传播逻辑,Einsum 类中的_make_lora_eqns 负责将标准的 Einstein 求和表达式转换为两个新的表达式。
1.7 vit.py: Vision Transformer 实现
vit.py 实现了 Vision Transformer 的核心组件。
第二部分 策略适配接口:src 下 policy 的全面分析与解读
src/openpi/policies 目录包含以下文件:BasePolicy (policy.py)、AlohaPolicy (aloha_policy.py)、DroidPolicy (droid_policy.py)、LiberoPolicy (libero_policy.py)。
此外,每个特定机器人都有自己的策略文件,如 aloha_policy.py、droid_policy.py、libero_policy.py。这些文件定义了特定于机器人的输入和输出转换函数,处理数据格式、规范化和特定的转换需求。
2.1 policy.py:实现了 Policy 类和 PolicyRecorder 类
2.1.1 Policy 类
policy.py 定义了基本的 Policy 类和 PolicyRecorder 类,它们继承自 openpi_client.base_policy.BasePolicy。
首先,做一系列初始化,包括模型参数、随机数生成器、输入输出转换函数序列等。
其次,对于 infer 方法——在策略内部流程上:应用输入转换、复制输入观察数据、生成新的随机数键、模型推理、解除批处理并转换为 NumPy 数组、输出转换。
2.1.2 PolicyRecorder
PolicyRecorder 是一个装饰器类,它包装了一个基础策略,并在执行策略的同时将所有的输入和输出保存到磁盘,用于记录策略的行为。
2.2 policy_config.py
policy_config.py 定义了 PolicyConfig 类和 create_trained_policy 函数。
2.2.1 PolicyConfig 数据类
PolicyConfig 是一个使用 @dataclasses.dataclass 装饰的数据类,用于存储创建策略所需的所有配置信息。
2.2.2 create_trained_policy 函数
create_trained_policy 函数是从训练好的检查点创建可用策略的工厂函数。
函数的核心流程是:构建并返回 Policy 实例,将所有转换函数组织为有序的处理流程。如果未提供 norm_stats,从检查点加载归一化统计信息。使用 train_config 加载模型参数。处理输入参数,确保 repack_transforms 不为空。
2.3 policies/aloha_policy.py
这段代码实现了一个用于 Aloha 策略的输入输出处理和数据转换的模块。
2.3.1 make_aloha_example:输入示例
make_aloha_example 函数创建了一个随机的输入示例,包括一个 14 维的状态向量和四个摄像头的图像数据,以及一个文本提示信息。
2.3.2 AlohaInputs:定义 Aloha 策略的输入数据结构
AlohaInputs 类定义了 Aloha 策略的输入数据结构。call 方法实现了对 Aloha 策略输入数据的标准化处理。该方法将原始输入数据转换为模型可接受的格式,包括多项关键处理步骤,比如进行必要的解码和填充操作,并检查图像数据是否包含预期的摄像头视角。
2.3.3 AlohaOutputs:定义 Aloha 策略的输出数据结构
AlohaOutputs 类定义了 Aloha 策略的输出数据结构,同样使用 dataclasses.dataclass 装饰器。
2.3.4 辅助函数
此外,代码中还包含多个辅助函数,用于数据的标准化、反标准化、关节角度翻转、夹持器位置的线性和角度转换等。
第三部分 模型训练的配置:src 下 training 模块的全面分析与解读
training 模块是 OpenPI 项目中负责训练相关功能的核心部分,该目录下包含了以下主要文件:checkpoints.py、config.py、data_loader.py、optimizer.py、sharding.py、utils.py、weight_loaders.py。
3.1 配置系统 (config.py)
定义了训练过程的各种配置类型,包括 TrainConfig、DataConfigFactory、AssetsConfig 等。
3.1.1 基础配置类 AssetsConfig、DataConfig
AssetsConfig 用于确定数据 pipeline 所需资产的位置。DataConfig 包含 repo_id、asset_id、norm_stats 等。
3.1.2 数据集配置
涉及 LeRobotLiberoDataConfig 和 LeRobotAlohaDataConfig。
对于后者的结构,LeRobotLiberoDataConfig 是一个用于机器人控制系统的数据配置类,它负责定义整个数据管道中不同阶段的数据转换操作。
特别值得注意的是关于动作表示的转换:该配置支持将绝对动作转换为相对动作。
3.1.3 训练配置 TrainConfig
TrainConfig 包含 name、project_name、exp_name、model、batch_size 等参数。
3.1.4 预定义配置
文件最后定义了多个具体的训练配置,比如 pi0_libero、pi0_aloha_pen_uncap 等。
3.2 数据加载系统 data_loader.py
定义了数据集和数据加载器的接口(Dataset 和 DataLoader)。
3.2.1 FakeDataset 类
3.2.2 create_dataset:创建适合训练的数据集
create_dataset 函数是一个关键的数据准备工具,负责根据配置参数创建适合模型训练的数据集。
最后,如果 data_config.prompt_from_task 设置为 True,函数会将原始数据集包装在 TransformedDataset 中,并应用 PromptFromLeRobotTask 转换。
3.2.3 transform_dataset:对数据集应用转换
transform_dataset 函数是数据预处理管道中的关键组件,负责对原始数据集应用一系列转换操作。
3.2.4 create_data_loader:创建用于训练的数据加载器
create_data_loader 函数是整个数据处理流水线的核心组件,它协调多个模块共同工作,创建一个用于模型训练的数据加载器。
3.3 优化器系统 (optimizer.py)
3.4 检查点系统 (checkpoints.py)
负责模型状态的保存和恢复,使用 Orbax 库实现高效的检查点存储。
3.5 模型分片系统 (sharding.py)
实现分布式训练时的模型参数分片,提供 fsdp_sharding 函数用于全参数数据并行 (FSDP) 的实现。
3.6 权重加载系统 (weight_loaders.py)
定义了 WeightLoader 协议,用于加载预训练权重。
3.7 辅助工具 (utils.py)
定义了 TrainState 数据类,封装了训练过程的状态。
第四部分 模型的训练与部署:基于客户端 - 服务器 C/S 架构
packages/openpi-client,是一个独立的客户端库 openpi-client 库,主要负责提供与策略服务器通信的接口。
scripts 这个模块提供了服务器端的各种工具和脚本,主要包括 serve_policy.py、train.py、compute_norm_stats.py 等。
总的来说,这是一个典型的分布式系统设计:packages/openpi-client 提供轻量级的客户端接口,而 scripts/ 则提供服务器端的功能实现,两者通过 WebSocket 协议进行通信。
4.1 packages/openpi-client:帮真机或 Sim 与策略服务器进行通信和交互
该模块的目录结构如下,主要用于连接到 OpenPI 服务器,处理观察数据和动作序列。
4.1.1 核心接口层
BasePolicy、Environment、Agent。
4.1.2 通信层 WebsocketClientPolicy
WebsocketClientPolicy 实现与服务器的 WebSocket 通信。
4.1.3 数据处理层
ActionChunkBroker、image_tools。
4.1.4 运行时系统层
Runtime、Subscriber、agents。
4.1.5 工具支持
4.2 scripts(策略服务器)
scripts 目录包含多个 Python 脚本,这些脚本用于数据处理、模型训练和服务部署等任务。
4.2.1 init.py
4.2.2 compute_norm_stats.py:计算数据的归一化统计信息
4.2.3 serve_policy.py:启动策略服务
serve_policy.py 是 openpi 中的策略推理服务端脚本,作用为:启动一个 WebSocket 服务器,加载预训练策略模型,等待外部请求,然后执行动作推理并返回结果。
main 函数是脚本的入口点,它首先调用 create_policy 函数创建策略,然后记录策略的元数据。接着获取主机名和本地 IP 地址,并创建一个 WebSocket 服务器来提供策略服务。
load_policy 函数根据传入的参数创建策略,如果参数中指定了检查点,则从检查点加载策略,否则使用默认策略。
4.2.4 train_test.py:训练和测试模型
4.2.5 train.py:训练模型
这段代码是一个基于 JAX 的分布式训练脚本,集成了模型初始化、训练循环、日志记录、实验跟踪和检查点管理等功能。
4.2.6 scripts/docker
docker 目录通包含与 Docker 相关的脚本和配置文件,用于构建和管理 Docker 容器。
第五部分 examples:各种机器人平台及策略客户端的示例实现
examples 模块的结构涉及以下模块:aloha_real、aloha_sim、droid、libero、simple_client、ur5、inference.ipynb、policy_records.ipynb。
5.1 aloha_real
aloha_real 模块是 OpenPI 项目中用于控制真实 ALOHA 双臂机器人的完整实现。
5.1.1 核心架构
硬件常量定义、数据转换工具、环境接口、主控制流程。
5.1.2 系统工作流程与部署方式
- 初始化阶段:启动 ROS 节点 → 初始化双臂机器人 → 连接摄像头 → 建立 WebSocket 连接。
- 运行时循环 (50Hz):获取观察 → 发送到策略服务器 → 接收动作序列 → 执行动作 → 更新状态。
- 动作执行:策略预测 25 步动作序列,ActionChunkBroker 管理动作缓冲和执行。
本地部署需要启动 3 个终端:机器人客户端、ROS 硬件层、openpi 策略服务器。
综上,三进程间的协同流程可以总结为:[ROS 系统(终端 2)] <== 硬件数据 ==> [主控进程 main.py(终端 1)] <== 请求 ==> [推理服务 serve_policy.py(终端 3)]。
相关免费在线工具
- 加密/解密文本
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
- RSA密钥对生成器
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
- Mermaid 预览与可视化编辑
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
- 随机西班牙地址生成器
随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online
- Gemini 图片去水印
基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online
- curl 转代码
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online