预训练语言模型与 BERT 实战应用
介绍预训练语言模型的发展及核心思想,详解 BERT 架构与预训练任务(MLM、NSP)。通过 Hugging Face Transformers 库,演示了基于 BERT-base-chinese 的中文新闻文本分类实战流程,包括数据预处理、模型搭建、微调训练及推理预测,并总结了优化技巧与改进方向。

介绍预训练语言模型的发展及核心思想,详解 BERT 架构与预训练任务(MLM、NSP)。通过 Hugging Face Transformers 库,演示了基于 BERT-base-chinese 的中文新闻文本分类实战流程,包括数据预处理、模型搭建、微调训练及推理预测,并总结了优化技巧与改进方向。

💡 学习目标:掌握预训练语言模型的核心思想、BERT 模型的架构原理,以及基于 BERT 的文本分类任务实战流程。 💡 学习重点:理解 BERT 的双向注意力机制与掩码语言模型预训练任务,学会使用 Hugging Face Transformers 库调用 BERT 模型并完成微调。
💡 传统的自然语言处理模型(如 LSTM+ 词嵌入)存在两个核心痛点:一是需要大量标注数据才能训练出高性能模型,二是模型对语言上下文的理解能力有限。 预训练语言模型的出现解决了这些问题。它的核心思路是先在大规模无标注文本语料上进行预训练,学习通用的语言知识和语义表示,再针对特定任务进行微调。这种'预训练 + 微调'的范式,极大降低了对标注数据的依赖,同时显著提升了模型在各类 NLP 任务上的性能。
预训练语言模型的发展可以分为三个阶段:
预训练语言模型的应用流程分为两个关键步骤:
⚠️ 注意:预训练阶段通常需要海量的计算资源和数据,一般由大厂或研究机构完成。普通开发者只需下载预训练好的模型权重,直接进行微调即可。
💡 BERT 的全称是Bidirectional Encoder Representations from Transformers,即基于 Transformer 编码器的双向表示模型。它的核心架构是多层双向 Transformer 编码器,没有解码器部分。 BERT 的模型结构有两个版本,满足不同的算力需求:
BERT 的输入表示是三种嵌入的求和:
import tensorflow as tf
from transformers import BertConfig, BertModel
# 加载 BERT-Base 配置
config = BertConfig.from_pretrained('bert-base-uncased')
# 初始化 BERT 模型
bert_model = BertModel.from_pretrained('bert-base-uncased')
# 模拟输入:batch_size=2,sequence_length=10
input_ids = tf.random.randint(0, config.vocab_size, (2, 10))
attention_mask = tf.ones((2, 10)) # 1 表示有效 token,0 表示填充 token
token_type_ids = tf.zeros((2, 10)) # 0 表示第一个句子,1 表示第二个句子
# 获取 BERT 输出
outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
last_hidden_state = outputs.last_hidden_state # 最后一层隐藏状态,shape=(2,10,768)
pooler_output = outputs.pooler_output # 特殊 token [CLS] 的输出,shape=(2,768)
print("最后一层隐藏状态形状:", last_hidden_state.shape)
print("CLS token 输出形状:", pooler_output.shape)
BERT 的预训练包含两个核心任务,通过这两个任务让模型学习双向上下文信息:
[MASK],10% 的概率替换为随机 token,10% 的概率保持原 token 不变。⚠️ 注意:后续的研究发现,NSP 任务对部分下游任务的提升有限,甚至可能带来负面影响。因此,一些改进版的 BERT 模型(如 RoBERTa)取消了 NSP 任务。
💡 Hugging Face Transformers是目前最流行的预训练语言模型工具库,它提供了包括 BERT、GPT、RoBERTa、T5 等在内的数百种预训练模型的实现,支持 TensorFlow 和 PyTorch 两种框架,极大简化了预训练模型的使用流程。
首先安装 Transformers 库和相关依赖:
pip install transformers datasets tensorflow
Transformers 库的核心组件包括:
BertModel、BertForSequenceClassification等。💡 本次实战任务是中文新闻文本分类。我们使用 THUCNews 数据集的子集,包含 10 个新闻类别:体育、娱乐、家居、房产、教育、时尚、时政、游戏、科技、财经。我们的目标是基于 BERT-base-chinese 模型,搭建文本分类模型,实现对新闻类别的自动判断。
① 加载 THUCNews 子集数据集,划分训练集、验证集和测试集
② 使用BertTokenizer对文本进行分词处理,转换为模型可接受的输入格式
③ 设置序列最大长度为 128,对过长的文本进行截断,过短的文本进行填充
from datasets import load_dataset
from transformers import BertTokenizerFast
# 加载数据集(这里使用本地的 THUCNews 子集,也可以使用 Hugging Face Hub 上的公开数据集)
dataset = load_dataset('csv', data_files={
'train': 'thucnews_train.csv',
'val': 'thucnews_val.csv',
'test': 'thucnews_test.csv'
})
# 加载中文 BERT 分词器
tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')
# 定义文本预处理函数
def preprocess_function(examples):
# 对文本进行分词、转换为 token id、填充和截断
return tokenizer(
examples['text'],
max_length=128,
padding='max_length',
truncation=True
)
# 对数据集进行预处理
tokenized_dataset = dataset.map(preprocess_function, batched=True)
# 重命名标签列,适配模型输入
tokenized_dataset = tokenized_dataset.rename_column('label', 'labels')
# 设置数据集格式为 TensorFlow 格式
tokenized_dataset.set_format(type='tensorflow', columns=['input_ids', 'attention_mask', 'labels'])
# 生成训练集和验证集的 tf.data.Dataset
batch_size = 32
train_dataset = tokenized_dataset['train'].to_tf_dataset(
columns=['input_ids', 'attention_mask'],
label_cols=['labels'],
batch_size=batch_size,
shuffle=True
)
val_dataset = tokenized_dataset[].to_tf_dataset(
columns=[, ],
label_cols=[],
batch_size=batch_size,
shuffle=
)
💡 我们使用BertForSequenceClassification类,它是 BERT 模型针对序列分类任务的专用版本。它在 BERT 的输出层后,添加了一个全连接层,用于将[CLS]token 的输出映射到分类标签空间。
from transformers import TFBertForSequenceClassification
# 加载 BERT 中文预训练模型,指定分类类别数为 10
model = TFBertForSequenceClassification.from_pretrained(
'bert-base-chinese',
num_labels=10,
problem_type='single_label_classification'
)
# 编译模型
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy('accuracy')]
)
# 查看模型结构
model.summary()
⚠️ 注意:BERT 模型的学习率通常设置为 2e-5 或 5e-5,远小于普通深度学习模型的学习率。这是因为预训练模型已经学习了丰富的语言知识,过高的学习率会破坏预训练的权重。
① 设置训练参数,训练轮数设置为 3 轮(BERT 模型微调通常不需要太多轮数,否则容易过拟合) ② 使用验证集监控模型性能,保存最佳模型 ③ 在测试集上评估模型的最终性能
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
# 定义回调函数
callbacks = [
# 早停:当验证集损失不再下降时停止训练
EarlyStopping(monitor='val_loss', patience=1, restore_best_weights=True),
# 保存最佳模型
ModelCheckpoint('best_bert_thucnews.h5', monitor='val_accuracy', save_best_only=True)
]
# 开始微调模型
history = model.fit(
train_dataset,
validation_data=val_dataset,
epochs=3,
callbacks=callbacks
)
# 加载测试集
test_dataset = tokenized_dataset['test'].to_tf_dataset(
columns=['input_ids', 'attention_mask'],
label_cols=['labels'],
batch_size=batch_size,
shuffle=False
)
# 在测试集上评估模型
test_loss, test_acc = model.evaluate(test_dataset)
print(f"测试集损失:{test_loss:.4f}")
print(f"测试集准确率:{test_acc:.4f}")
训练完成后,我们可以使用模型对新的中文文本进行分类预测:
# 定义预测函数
def predict_text_category(text):
# 预处理文本
inputs = tokenizer(
text,
max_length=128,
padding='max_length',
truncation=True,
return_tensors='tf'
)
# 获取预测结果
outputs = model(inputs)
logits = outputs.logits
# 转换为类别概率
probabilities = tf.nn.softmax(logits, axis=-1)
# 获取预测类别
predicted_label = tf.argmax(probabilities, axis=-1).numpy()[0]
# 类别映射字典
label_map = {
0: '体育', 1: '娱乐', 2: '家居', 3: '房产', 4: '教育',
5: '时尚', 6: '时政', 7: '游戏', 8: '科技', 9: '财经'
}
return label_map[predicted_label]
# 测试预测
test_text = "北京时间 10 月 1 日,2024 年巴黎奥运会男篮决赛在法兰西体育场举行,美国队以 102-87 击败法国队,夺得金牌。"
print(f"文本内容:{test_text}")
print(f"预测类别:{predict_text_category(test_text)}")
💡 技巧 1:使用学习率调度器。在微调过程中,使用线性学习率衰减策略,让学习率随着训练轮数的增加而逐渐降低,提升模型的泛化能力。 💡 技巧 2:使用梯度累积。当显存不足时,可以使用梯度累积技术,将多个小批次的梯度累积起来,再进行一次参数更新,相当于增大了批次大小。 💡 技巧 3:使用知识蒸馏。将大模型(如 BERT-Large)的知识蒸馏到小模型(如 DistilBERT)中,在保证性能损失较小的前提下,显著提升模型的推理速度。
BERT 提出后,研究者们提出了许多改进版本,进一步提升了模型性能:
✅ 预训练语言模型采用'预训练 + 微调'的范式,先在大规模无标注语料上学习通用语言知识,再针对具体任务进行微调。 ✅ BERT 是基于双向 Transformer 编码器的预训练模型,通过掩码语言模型和下一句预测任务,实现了双向上下文理解。 ✅ 使用 Hugging Face Transformers 库可以快速调用 BERT 模型,只需少量代码即可完成中文文本分类等任务的微调。 ✅ BERT 模型的微调需要注意学习率的设置,通常使用 2e-5 或 5e-5 的小学习率,避免破坏预训练权重。

微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online