torch.view() 基础用法
在 PyTorch 中,如果你需要改变一个 Tensor 的大小或者形状,torch.view() 是最常用的工具之一。它的核心作用就是重构张量的维度,功能上类似于 NumPy 中的 resize()。需要注意的是,view() 返回的数据和传入的 Tensor 是同一个对象,只是形状不同而已。
-1 参数的含义
在使用 view() 时,经常会在某个维度位置看到 -1。这个符号的意思是让 PyTorch 帮你计算该维度的具体大小。只要保证所有维度相乘后的总数与原 Tensor 的元素数量一致即可,系统会自动推算出缺失的那一项。
内存共享机制
这一点非常关键:view() 返回的 Tensor 和传入的 Tensor 共享底层内存。这意味着它们指向的是同一块数据区域。如果你修改了视图中的数据,原始 Tensor 也会随之改变;反之亦然。这种设计避免了不必要的内存拷贝,提升了效率,但也要求我们在操作时格外小心。
代码示例
import torch
# 创建一个 4x4 的张量
t = torch.arange(16).view(4, 4)
print(t.shape)
# 使用 -1 自动推断行数为 2,列数由系统计算
new_t = t.view(-1, 8)
print(new_t.shape) # 输出:torch.Size([2, 8])
注意事项
虽然 view() 很灵活,但它对张量的存储顺序有要求。如果张量不是连续的(contiguous),直接调用 view() 可能会报错。遇到这种情况,可以先调用 .contiguous() 方法处理一下再重塑形状。

