PyTorch scatter() 与 scatter_() 详解
在 PyTorch 中,函数名末尾带下划线(如 scatter_)通常表示原地操作(in-place),即直接修改原 Tensor;而不带下划线的版本(scatter)则返回一个新的 Tensor,不改变输入数据。
函数参数说明
scatter(dim, index, src) 的核心逻辑是根据索引将源数据填充到目标位置:
- dim: 指定沿着哪个维度进行索引填充。例如 0 代表行,1 代表列。
- index: 一个长整型张量,形状需与
src一致,用于指示每个元素的目标位置。 - src: 源数据,可以是标量或张量。
按行填充示例
我们通过一个具体的例子来理解 dim=0 时的行为。假设我们要把随机生成的矩阵 x 填充到一个全零矩阵的特定行上。
import torch
# 创建一个 2x5 的随机张量作为源数据
x = torch.rand(2, 5)
# 定义索引张量,形状必须与 src 一致
# 这里的值代表目标张量在 dim=0 方向上的行号
index = torch.LongTensor([
[0, 1, 2, 0, 0],
[2, 0, 0, 1, 2]
])
# 创建一个 3x5 的全零张量作为目标
# scatter_ 会按照 index 的指引,把 x 的值填进去
result = torch.zeros(3, 5).scatter_(0, index, x)
在这个例子中,index 里的第 1 行第 0 列是 2,意味着 x[0, 0] 的值会被填入 result 的第 2 行第 0 列。注意,目标张量的列数必须与 src 相同,而 index 中的最大值不能超过目标张量在 dim 维度上的大小减一。
运行后,result 的形状依然是 (3, 5),但部分位置已被 x 的数据覆盖。这种机制常用于 One-hot 编码转换、动态掩码设置等场景,比循环赋值高效得多。


