transformer图像切块与还原(window_partition+window_unpartition)
文章目录
前言
假如b ,h,w,c=(3,32,32,768)需将h w按照14尺寸切割,32/14无法整除,需pad为(3,42,42,768)完成固定尺寸块切割,进而完成transformer结构,最终摒弃pad数据还原为(3,32,32,768)。在使用Transformer结构提取特征时,通常会使用window_partition和window_unpartition来划分和还原图像块的过程。这两个步骤是为了将图像分割成小块,送入Transformer网络进行处理,然后再将处理后的特征重新组合成原始图像的尺寸。为此,我摘录TAM大模型处理方法代码,记录图像尺寸切割与还原。
一、切割图像(window_partition)
这一步骤是将原始图像按照设定的窗口大小划分成多个块,并将这些块重新排列成一个较大的矩阵,以便送入Transformer网络。
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: """ Partition into non-overlapping windows with padding if needed. Args: x (tensor): input tokens with [B, H, W, C]. window_size (int): window size. Returns: windows: windows after partition with [B * num_windows, window_size, window_size, C]. (Hp, Wp): padded height and width before partition """ B, H, W, C = x.shape pad_h = (window_size - H % window_size) % window_size pad_w = (window_size - W % window_size) % window_size if pad_h > 0 or pad_w > 0: x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) Hp, Wp = H + pad_h, W + pad_w x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows, (Hp, Wp) 二、还原图像(window_unpartition)
在完成特征提取后,使用window_unpartition将处理后的特征重新还原为原始图像的尺寸。这样可以保持特征与原始图像之间的对应关系。
def window_unpartition( windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] ) -> torch.Tensor: """ Window unpartition into original sequences and removing padding. Args: windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. window_size (int): window size. pad_hw (Tuple): padded height and width (Hp, Wp). hw (Tuple): original height and width (H, W) before padding. Returns: x: unpartitioned sequences with [B, H, W, C]. """ Hp, Wp = pad_hw H, W = hw B = windows.shape[0] // (Hp * Wp // window_size // window_size) x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) if Hp > H or Wp > W: x = x[:, :H, :W, :].contiguous() return x 三、整体代码
import torch from typing import Optional, Tuple, Type import torch.nn.functional as F def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: """ Partition into non-overlapping windows with padding if needed. Args: x (tensor): input tokens with [B, H, W, C]. window_size (int): window size. Returns: windows: windows after partition with [B * num_windows, window_size, window_size, C]. (Hp, Wp): padded height and width before partition """ B, H, W, C = x.shape pad_h = (window_size - H % window_size) % window_size pad_w = (window_size - W % window_size) % window_size if pad_h > 0 or pad_w > 0: x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) Hp, Wp = H + pad_h, W + pad_w x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows, (Hp, Wp) def window_unpartition( windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] ) -> torch.Tensor: """ Window unpartition into original sequences and removing padding. Args: windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. window_size (int): window size. pad_hw (Tuple): padded height and width (Hp, Wp). hw (Tuple): original height and width (H, W) before padding. Returns: x: unpartitioned sequences with [B, H, W, C]. """ Hp, Wp = pad_hw H, W = hw B = windows.shape[0] // (Hp * Wp // window_size // window_size) x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) if Hp > H or Wp > W: x = x[:, :H, :W, :].contiguous() return x if __name__ == '__main__': x=torch.randn((3,32,32,768)) # b,h,w,c window_size=14 H, W = x.shape[1], x.shape[2] x, pad_hw = window_partition(x, window_size) # 使用window_size尺寸划分图像块 print("使用window_partition填充,修改尺寸格式为:",x.shape) y = window_unpartition(x, window_size, pad_hw, (H, W)) # 在返回原有尺寸 print("window_unpartition,返回原有尺寸格式为:",y.shape) 结果显示:
