SAM 的掩码生成机制设计巧妙,它让计算机能够'读懂'图像并勾勒出物体轮廓。要理解其核心,我们需要拆解组件协作、多掩码逻辑以及底层的数学推导。
核心组件与协作关系
SAM 主要由三个模块构成,它们共同完成从输入到输出的转换:
| 组件模块 | 核心功能 | 关键实现机制 |
|---|---|---|
| 图像编码器 | 提取图像特征,生成图像嵌入(Image Embedding) | 使用基于 Transformer 的视觉骨干网络(如 ViT),将图像转换为高维特征表示。 |
| 提示编码器 | 将各种提示(点、框、掩码)转换为提示嵌入(Prompt Embedding) | 为点、框(视为点对)和掩码(通过卷积)分别设计编码方式,统一嵌入空间。 |
| 掩码解码器 | 核心:综合图像和提示嵌入,预测输出掩码 | 采用 Transformer 解码器,通过交叉注意力融合图像与提示信息;上采样层还原掩码分辨率;MLP 预测掩码质量分数。 |
| 提示采样策略 (自动掩码生成) | 在无人工提示时,自动生成有效提示以分割图中所有物体 | 原始 SAM:在图像上生成密集的网格点作为前景点提示。改进方案(如 MobileSAMv2):使用目标检测器 (如 YOLOv8) 生成目标感知的框提示,提升效率。 |
工作流程解析
上述组件是如何协作生成掩码的呢?整个过程主要分为两种模式:
- 交互式分割 (SegAny):当你提供点或框等提示时,模型会基于该特定提示生成一个或多个候选掩码及其质量分数。
- 自动分割 (SegEvery):当没有人工提示时,模型会使用提示采样策略(例如网格点或目标检测框)自动生成大量提示。对于每一个候选提示,都会经过流程生成候选掩码,最后再通过非极大值抑制(NMS)等后处理步骤过滤掉高度重叠的掩码,输出最终结果。
多掩码输出机制详解
1. 设计初衷:处理模糊性
当用户只提供一个点提示时,这个点可能对应多个合理的分割结果。例如点击在一个区域,可能是整个人体、上半身或是头部区域。为了应对这种不确定性,SAM 默认输出三个掩码。
2. 三个掩码的具体含义
每个掩码代表对提示的不同尺度或范围的解释:
| 掩码 | 通常含义 | 适用场景 | 示例 |
|---|---|---|---|
| Mask 1 | 最完整对象 (最大范围) | 需要整个物体 | 整个人、整辆车 |
| Mask 2 | 中等范围 (部分对象) | 物体主要部分 | 上半身、车体外壳 |
| Mask 3 | 最小范围 (核心区域) | 精细分割 | 头部、车轮 |
3. 代码层面的实现
在调用预测接口时,multimask_output 参数决定了是否启用多掩码输出:
def predict(
self,
point_coords: Optional[np.ndarray] = None,
point_labels: Optional[np.ndarray] = None,
box: Optional[np.ndarray] = None,
mask_input: Optional[np.ndarray] = None,
multimask_output: bool = True, # 关键参数!True 输出 3 个掩码处理模糊性,False 输出 1 个最佳掩码
return_logits: bool = False,
):
"""
multimask_output:
- True: 输出 3 个掩码处理模糊性
- False: 输出 1 个最佳掩码
"""
4. 实际案例分析
我们可以通过可视化来观察这三个掩码的区别:
import matplotlib.pyplot as plt
def demonstrate_multimask(image, point_coords, point_labels):
masks, scores, logits = predictor.predict(
point_coords=point_coords,
point_labels=point_labels,
multimask_output=True
)
fig, axes = plt.subplots(1, 4, figsize=(20, 5))
# 原图
axes[0].imshow(image)
axes[0].scatter(point_coords[:, 0], point_coords[:, 1], color='red', marker='*', s=200, edgecolor='white')
axes[0].set_title('Original Image with Prompt')
# 三个掩码
for i, (mask, score) in enumerate(zip(masks, scores)):
axes[i+1].imshow(image)
show_mask(mask, axes[i+1])
axes[i+1].set_title(f'Mask {i+1}\nScore: {score:.3f}')
axes[i+1].axis('off')
plt.tight_layout()
plt.show()
return masks, scores
# 使用示例
masks, scores = demonstrate_multimask(image, input_point, input_label)
5. 评分机制解析
分数 score 表示模型对每个掩码质量的置信度。这有助于在自动生成的大量掩码中筛选出高质量的结果。评分通常基于 IoU 预测、稳定性分数以及与图像特征的匹配度。
6. 如何选择掩码
如果你不需要多个选项,可以强制只输出一个掩码:
# 只输出最佳掩码
single_mask, single_score, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=False # 关键参数!
)
print(f"最佳掩码分数:{single_score:.3f}")
如果仍需从多掩码中选择,可以根据策略进行筛选:
def select_best_mask(masks, scores, strategy="auto"):
"""选择最佳掩码的策略"""
if strategy == "highest_score":
best_idx = np.argmax(scores)
return masks[best_idx]
elif strategy == "balanced":
areas = [mask.sum() for mask in masks]
normalized_areas = areas / np.max(areas)
combined_scores = scores * (1 - 0.2 * np.abs(normalized_areas - 0.5))
best_idx = np.argmax(combined_scores)
return masks[best_idx]
elif strategy == "largest":
areas = [mask.sum() for mask in masks]
return masks[np.argmax(areas)]
elif strategy == "smallest":
areas = [mask.sum() for mask in masks]
return masks[np.argmin(areas)]
# 使用示例
best_mask = select_best_mask(masks, scores, strategy="balanced")
核心算法与数学推导
SAM 的掩码生成基于 Transformer 架构,其核心数学过程可以表示为:
$$M = \text{Decoder}(\text{Encoder}(I), \text{Encoder}(P))$$
其中 $I$ 是输入图像,$P$ 是提示,$M$ 是输出掩码。
1. 图像编码器的数学原理
1.1 Vision Transformer (ViT) 基础
图像分块与线性投影:
$$X = [x_{class}; x_p^1W; x_p^2W; ...; x_p^NW] + E_{pos}$$
- $x_p^i \in \mathbb{R}^{(P^2 \cdot C)}$:第 $i$ 个图像块
- $W \in \mathbb{R}^{(P^2 \cdot C \times D)}$:线性投影矩阵
- $E_{pos} \in \mathbb{R}^{((N+1) \times D)}$:位置编码
- $N = HW/P^2$:块的数量
1.2 自注意力机制
查询 - 键 - 值计算:
$$Q = XW_Q, K = XW_K, V = XW_V$$
自注意力公式:
$$\text{Attention}(Q, K, V) = \text{softmax}(QK^T/\sqrt{d_k})V$$
多头注意力则是将上述过程并行化后拼接:
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W_O$$
2. 提示编码器的数学推导
2.1 点提示编码
对于点 $p = (x, y)$,使用正弦位置编码:
$$PE(pos, 2i) = \sin(pos/10000^{(2i/D)})$$ $$PE(pos, 2i+1) = \cos(pos/10000^{(2i/D)})$$
点嵌入计算:
$$E_{point} = \text{MLP}(PE(x) \oplus PE(y)) + E_{type}$$
其中 $\oplus$ 表示拼接,$E_{type}$ 是点类型嵌入(前景/背景)。
2.2 框提示编码
框 $b = (x_1, y_1, x_2, y_2)$ 编码为两个角点:
$$E_{box} = \text{MLP}(PE(x_1)\oplus PE(y_1)) + \text{MLP}(PE(x_2)\oplus PE(y_2)) + E_{box_type}$$
3. 掩码解码器的核心算法
3.1 掩码解码器架构
掩码解码器是一个轻量级 Transformer,其输入为:
$$X = [E_{mask}^1, E_{mask}^2, E_{mask}^3, E_{prompt}, E_{iou}] \in \mathbb{R}^{(N_{tokens} \times D)}$$
3.2 交叉注意力机制
图像到提示的交叉注意力:
$$Q = XW_Q, K = F_{img}W_K, V = F_{img}W_V$$ $$\text{CrossAttn}(X, F_{img}) = \text{softmax}(QK^T/\sqrt{d_k})V$$
其中 $F_{img} \in \mathbb{R}^{(H'W' \times D)}$ 是图像特征。
3.3 掩码生成公式
动态掩码预测:
$$M_i = \text{Sigmoid}(\text{Conv}(\text{Upsample}(\text{Linear}(h_i) \odot F_{img})))$$
其中 $h_i$ 是第 $i$ 个掩码 token 的隐藏状态,$\odot$ 是逐元素乘法(特征调制)。
4. 训练目标与损失函数
总损失是多个损失项的加权和:
$$L_{total} = L_{mask} + \lambda_1 L_{iou} + \lambda_2 L_{consistency}$$
掩码损失通常结合 Focal Loss 和 Dice Loss:
$$L_{mask} = L_{focal} + L_{dice}$$
$$L_{focal} = -\alpha(1-p_t)^\gamma \log(p_t)$$
$$L_{dice} = 1 - \frac{(2\sum p_i y_i + \epsilon)}{(\sum p_i + \sum y_i + \epsilon)}$$
5. 自动掩码生成的算法原理
5.1 网格点采样
对于图像尺寸 $H \times W$,生成网格点:
$$points = {(i \cdot \Delta x, j \cdot \Delta y) \mid i=0,...,N_x-1; j=0,...,N_y-1}$$
5.2 非极大值抑制 (NMS)
掩码 IoU 计算:
$$IoU(M_i, M_j) = \frac{|M_i \cap M_j|}{|M_i \cup M_j|}$$
NMS 算法伪代码:
def nms(masks, scores, iou_threshold=0.8):
order = argsort(scores)[::-1]
keep = []
while order:
i = order[0]
keep.append(i)
ious = [IoU(masks[i], masks[j]) for j in order[1:]]
order = [j for j, iou in zip(order[1:], ious) if iou < iou_threshold]
return keep
6. 关键文件指引
阅读代码时,建议重点关注以下文件(以官方 SAM 仓库为例):
| 代码文件 | 功能描述 |
|---|---|
modeling/mask_decoder.py | 掩码解码器核心实现,包括 Transformer 结构、上采样、MLP 头。 |
modeling/prompt_encoder.py | 提示编码器,处理点、框、掩码的编码。 |
modeling/sam.py | SAM 模型整体结构,整合图像编码器、提示编码器、掩码解码器。 |
utils/amg.py | 自动掩码生成(AMG) 的具体实现,包含提示点网格生成、掩码后处理(如 NMS)等。 |
在阅读代码时,建议你重点关注 mask_decoder.py 中的 forward 函数,以及 amg.py 中生成提示点和过滤掩码的逻辑。一些基于 SAM 的第三方库(如 samtool)对原始接口进行了封装,可能更易于理解和使用。
总结
三个掩码的设计体现了 SAM 对视觉分割模糊性的深刻理解:
- 应对不确定性:一个点提示可能有多种合理解释
- 提供选择余地:用户可以根据具体需求选择最合适的尺度
- 分数指导选择:质量分数帮助用户做出明智决定
- 灵活性:可以通过
multimask_output参数控制这个行为
这种设计使得 SAM 在处理真实世界的复杂场景时更加鲁棒和实用。

