大模型微调对训练效率格外敏感。在消费级 GPU 上跑 7B/13B 模型时,一次迭代快零点几秒都可能决定实验能不能按计划完成。这里面两个组件出镜率很高:一个是能显著降低注意力计算开销的 Flash Attention,另一个是门槛极低的一站式微调框架 Llama-Factory。不少人问:在 Llama-Factory 里面启动训练,注意力层到底有没有用上 Flash Attention?
答案很简单:可以,但得配对版本、打通链路。
先看为什么 Flash Attention 是刚需。Transformer 的自注意力原本有严重的访存瓶颈——QK^T、Softmax、AV 被切分成多个 CUDA kernel,中间结果频繁在 HBM 和计算单元之间倒腾。序列越长,这个开销就越疼。Flash Attention 把这些操作融合成一个 kernel,在共享内存里一次做完,只加载一次 Q、K、V,中间状态不再落回全局内存。这带来的提速在序列长度超过 1024 时尤其明显,而且算法精确,没有精度折损。
import torch
from flash_attn import flash_attn_func
q = torch.randn(1, 2048, 16, 64, device='cuda', dtype=torch.float16)
k = torch.randn(1, 2048, 16, 64, device='cuda', dtype=torch.float16)
v = torch.randn(1, 2048, 16, 64, device='cuda', dtype=torch.float16)
out = flash_attn_func(q, k, v, causal=True) # 一次融合调用
那么 Llama-Factory 怎么接上这波红利?它自己没写注意力实现,底层靠的是 Hugging Face Transformers。所以只要 Transformers 支持,装好 flash-attn 库,给 use_flash_attention_2=True 就能启用。比如加载 Llama-2 或 Qwen:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype=torch.float16,
use_flash_attention_2=True, # 这里
device_map="auto"
)
但这开关不是一设就灵的。不少用户碰过这样的提示:'Flash Attention is not available',接着就退回到默认实现,速度没变化。常见原因几个:
- flash-attn 没装对。需要 ≥2.0 的版本,最好从源码或预编译 wheel 安装。
- CUDA 版本太低。起码得 11.8+,再低就不支持必需指令集。
- 显卡架构偏老。比如 V100 缺少相关的 Tensor Core 指令。
- Transformers 版本旧了。建议 4.36 以上。

