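High-precision accumulation is performed in the FP32 registers of the CUDA Cores, and the result is then multiplied by the scaling factors, i.e., dequantized.
Repeat steps 1-3 until all matrix multiplications and accumulations are complete.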
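The steps above can be written per output element. Here A and B are the FP8 tensors, As and Bs are their scales, and b_n, b_k are the tile sizes along N and K. These symbols are notation assumed for this formula and follow the variable names of the native implementation below.

$$
C_{m,n} = \sum_{i=0}^{\lceil K/b_k \rceil - 1} \mathrm{As}_{m,i}\,\mathrm{Bs}_{\lfloor n/b_n \rfloor,\,i} \sum_{k=i\,b_k}^{\min((i+1)\,b_k,\;K)-1} A_{m,k}\,B_{n,k}
$$

The inner sum is the product of one K group of FP8 values. The outer sum is the dequantize-and-accumulate step carried out in FP32. Each scale depends only on the group index i and on the row m or the N block, so it can be factored out of the inner sum. This is what lets each group be computed as a plain low-precision matrix multiply.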
Python Native Implementation
Core code:
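import torch


def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
    """This function performs matrix multiplication with block-wise quantization using native torch.
    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.
    """
    # torch.matmul has no FP8 kernel, so upcast first (FP8 -> FP32 is exact)
    A = A.to(torch.float32)
    B = B.to(torch.float32)
    M, K = A.shape
    N, K_B = B.shape
    assert K == K_B, "A and B must share the reduction dim K"
    # block_size = [block_n, block_k], the same order as the Triton config below
    block_n, block_k = block_size
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k
    assert As.shape == (M, k_tiles)  # one scale per row of A per K group
    assert n_tiles == Bs.shape[0]  # one scale per block_n x block_k block of B
    assert k_tiles == Bs.shape[1]

    C_shape = (M, N)
    # Accumulate in FP32
    C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)

    # Split A along K into k_tiles column groups
    A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
    # Split B into an n_tiles x k_tiles grid of 2-D blocks
    B_tiles = [
        [
            B[
                j * block_n : min((j + 1) * block_n, N),
                i * block_k : min((i + 1) * block_k, K),
            ]
            for i in range(k_tiles)
        ]
        for j in range(n_tiles)
    ]
    # Column blocks of C; these are views, so in-place updates write into C
    C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
    As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]

    for i in range(k_tiles):
        for j in range(n_tiles):
            a = A_tiles[i]  # [M, block_k]
            b = B_tiles[j][i]  # [block_n, block_k]
            c = C_tiles[j]  # [M, block_n]
            s = As_tiles[i] * Bs[j][i]  # [M, 1], combined dequantization scale
            c[:, :] += torch.matmul(a, b.t()) * s

    C = C.reshape(C_shape).to(output_dtype)
    return C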
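Read this together with the comments in the code above. Before multiplying, A and B are partitioned according to their respective quantization granularities. A is split along K into groups of block_k columns, with one scale per row per group stored in As. B is split into block_n × block_k blocks, with one scale per block stored in Bs. Each pair of blocks is multiplied, and the partial product is multiplied by the combined scale As[:, i] * Bs[j][i] to dequantize it. The resulting FP32 block is then accumulated into the matching column block of C.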
Triton Implementation
The code follows the implementation in sglang:
1. Function interface
from typing import List

import torch


def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """This function performs matrix multiplication with block-wise quantization.
    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.

    Args:
        A: The input tensor, e.g., activation.
        B: The input tensor, e.g., weight.
        As: The per-token-group quantization scale for `A`.
        Bs: The per-block quantization scale for `B`.
        block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
        output_dtype: The dtype of the returned tensor.

    Returns:
        torch.Tensor: The result of matmul.
    """
    pass  # body omitted in this excerpt
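The docstring defines what the scales mean but not what shapes they must have. The hypothetical helper check_block_fp8_shapes below spells out the contract implied by the native implementation: As holds one scale per row of A per K group, and Bs holds one scale per block_n × block_k block of B. Two assumptions are not taken from the excerpt. First, block_size is read as [block_n, block_k], matching the default Triton config where BLOCK_SIZE_N = block_size[0] and BLOCK_SIZE_K = block_size[1]. Second, any leading dimensions of A (for example batch and sequence) are flattened into M.

```python
from typing import List, Tuple

import torch


def check_block_fp8_shapes(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
) -> Tuple[int, int, int]:
    """Hypothetical helper: validate scale shapes and return (M, N, K).

    Assumes block_size = [block_n, block_k].
    """
    assert len(block_size) == 2
    block_n, block_k = block_size
    assert B.ndim == 2 and Bs.ndim == 2
    N, K = B.shape
    assert A.shape[-1] == K, "A and B must share the reduction dim K"
    # Assumption: leading dims of A (e.g. [batch, seq, K]) are flattened into M rows
    M = A.numel() // K
    # As: one scale per token (row of A) per K group of size block_k
    assert As.shape[:-1] == A.shape[:-1]
    assert As.shape[-1] == (K + block_k - 1) // block_k
    # Bs: one scale per block_n x block_k block of B
    assert tuple(Bs.shape) == ((N + block_n - 1) // block_n, (K + block_k - 1) // block_k)
    return M, N, K


A = torch.empty(2, 3, 512)  # batch 2, seq 3, K 512, so M = 6
B = torch.empty(256, 512)
As = torch.empty(2, 3, 4)  # ceil(512 / 128) = 4 groups per token
Bs = torch.empty(2, 4)  # ceil(256 / 128) x ceil(512 / 128) blocks
print(check_block_fp8_shapes(A, B, As, Bs, [128, 128]))  # (6, 256, 512)
```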
2. Triton kernel configuration
# M, N, K come from the input shapes: A is [M, K], B is [N, K]
# Try to load the best configuration previously obtained by tuning
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
if configs:
    # If an optimal configuration map has been found, look up the
    # optimal config for the tuned M closest to the current M
    config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
    # Default config
    # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
    config = {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": block_size[0],
        "BLOCK_SIZE_K": block_size[1],
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3,
    }
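The lookup picks the tuned entry whose key is nearest to the current M, so a single tuned table can serve every batch size. The sketch below isolates that logic in a hypothetical select_config function. The tuned table is made up, and each entry is truncated to one field, purely to show the nearest-key behavior. get_w8a8_block_fp8_configs is assumed to return a dict keyed by integer M, or an empty or None value when no tuned file exists.

```python
from typing import Dict, List, Optional


def select_config(configs: Optional[Dict[int, dict]], M: int, block_size: List[int]) -> dict:
    """Hypothetical helper mirroring the lookup above."""
    if configs:
        # Nearest tuned M wins; on a tie, min() keeps the key it sees first
        return configs[min(configs.keys(), key=lambda x: abs(x - M))]
    # Fallback: tile N and K exactly like the quantization blocks
    return {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": block_size[0],
        "BLOCK_SIZE_K": block_size[1],
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3,
    }


# Made-up tuned table keyed by M (entries truncated to one field)
tuned = {1: {"BLOCK_SIZE_M": 16}, 64: {"BLOCK_SIZE_M": 64}, 256: {"BLOCK_SIZE_M": 128}}
print(select_config(tuned, 100, [128, 128]))  # {'BLOCK_SIZE_M': 64}
print(select_config(tuned, 200, [128, 128]))  # {'BLOCK_SIZE_M': 128}
print(select_config({}, 100, [128, 128])["BLOCK_SIZE_K"])  # 128
```

Because the table is matched on the nearest M rather than requiring an exact hit, it only needs entries at a handful of representative batch sizes.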