语音数据天然携带隐私,把录音上传到云端训练越来越不被接受。联邦学习解决了这个矛盾:模型更新在本地计算,只把加密的梯度发给中心服务器聚合,原始音频从不离开设备。OpenAI 的 Whisper 在语音识别上能力很强,把它放进联邦学习框架,能在保护用户数据的同时迭代模型。这篇文章记录我们尝试的过程——不是什么完美方案,但可跑、可改,适合在敏感领域做原型。
为什么要把 Whisper 和联邦学习放在一起
Whisper 是个多语言语音识别模型,由 OpenAI 开源,识别准确率不错。但它传统的训练方式依赖集中式数据:把所有录音收集起来,统一训练。对医疗、金融、企业内部会议这些场景,原始语音不能出域。联邦学习提供了一种'数据不动模型动'的思路:每个客户端在自己的数据上训练本地模型,只上传梯度(或参数差值),服务器做聚合,形成新的全局模型,再发下去。
这里的核心点:客户端保留语音数据;上传的更新量可以加噪(差分隐私)或加密(如同态加密)来进一步加强保护。严格来说,仅上传梯度仍可能泄漏部分信息,叠加上差分隐私后,攻击者很难反推个体数据。
这张图概括了整个流程:
graph TD
Client1[客户端 1: 本地语音数据] --> Train1[本地 Whisper 训练]
Client2[客户端 2: 本地语音数据] --> Train2[本地 Whisper 训练]
Client3[客户端 3: 本地语音数据] --> Train3[本地 Whisper 训练]
Train1 --> Upload1[上传加密梯度]
Train2 --> Upload2[上传加密梯度]
Train3 --> Upload3[上传加密梯度]
Server[联邦服务器] --> Aggregate[安全聚合]
Upload1 --> Server
Upload2 --> Server
Upload3 --> Server
Aggregate --> Privacy[应用差分隐私]
Privacy --> Update[更新全局模型]
Update --> Distribute[分发新模型]
Distribute --> Client1
Distribute --> Client2
Distribute --> Client3
算法骨架:FedAvg 加差分隐私
联邦学习里最常用的优化目标是联邦平均(FedAvg):
$$\min_w \sum_{k=1}^K \frac{n_k}{N} F_k(w)$$
其中 $w$ 是全局模型参数,$K$ 是参与客户端数,$n_k$ 是客户端 $k$ 的数据量,$N$ 是总数据量,$F_k$ 是客户端损失。每一轮训练,选取部分客户端,各自用本地 SGD 更新模型,把参数差值发给服务器,服务器按数据量加权平均得到全局更新。
差分隐私的保证如下(对相邻数据集 $D$ 和 $D'$,任意输出集合 $S$):
$$\Pr[\mathcal{A}(D) \in S] \leq e^\epsilon \cdot \Pr[\mathcal{A}(D') \in S] + \delta$$
实现上,最简单的方式是在聚合后的梯度上叠加高斯噪声:
$$\Delta w_{priv} = \Delta w + \mathcal{N}(0, \sigma^2 I)$$
$\sigma$ 会由隐私预算 $(\epsilon, \delta)$ 确定。注意,$\epsilon$ 越小隐私越强,但模型收敛会变慢,这是典型的权衡。
代码实现:从环境搭建到完整训练
环境用 conda 建一个干净隔离:
conda create -n fl-whisper python=3.8
conda activate fl-whisper
pip install torch transformers datasets soundfile librosa
核心逻辑拆成几块。第一个是数据准备,我们用 LibriSpeech 模拟客户端分片:
import torch
from datasets import load_dataset
from transformers import WhisperFeatureExtractor, WhisperTokenizer
def prepare_dataset(client_id, num_clients=5):
librispeech = load_dataset("librispeech_asr", , split=)
total_samples = (librispeech)
samples_per_client = total_samples // num_clients
start = client_id * samples_per_client
end = (client_id + ) * samples_per_client client_id != num_clients - total_samples
client_data = librispeech.select((start, end))
feature_extractor = WhisperFeatureExtractor.from_pretrained()
tokenizer = WhisperTokenizer.from_pretrained(, language=, task=)
():
audio = example[]
inputs = feature_extractor(audio[], sampling_rate=audio[], return_tensors=).input_features[]
labels = tokenizer(example[]).input_ids
{: inputs, : labels}
client_data.(prepare_example)

