跳到主要内容
极客日志极客日志
首页博客AI提示词GitHub精选代理工具
搜索
|注册
博客列表
PythonAI算法

预训练语言模型与 BERT 实战应用

预训练语言模型采用“预训练 + 微调”范式,大幅降低对标注数据的依赖。BERT 作为双向 Transformer 编码器代表,通过掩码语言模型和下一句预测任务实现上下文理解。基于 Hugging Face Transformers 库,演示中文文本分类实战流程,包括数据预处理、模型构建、微调训练及推理预测。同时分享学习率调度、梯度累积等优化技巧,并介绍 RoBERTa 等改进模型,助力开发者高效落地 NLP 任务。

极客零度发布于 2026/3/23更新于 2026/5/43 浏览
预训练语言模型与 BERT 实战应用

预训练语言模型与 BERT 实战应用

核心目标与重点

掌握预训练语言模型的核心思想、BERT 模型的架构原理,以及基于 BERT 的文本分类任务实战流程。理解 BERT 的双向注意力机制与掩码语言模型预训练任务,学会使用 Hugging Face Transformers 库调用 BERT 模型并完成微调。

为什么需要预训练语言模型

传统的自然语言处理模型(如 LSTM+ 词嵌入)存在两个核心痛点:一是需要大量标注数据才能训练出高性能模型,二是模型对语言上下文的理解能力有限。预训练语言模型的出现解决了这些问题。它的核心思路是先在大规模无标注文本语料上进行预训练,学习通用的语言知识和语义表示,再针对特定任务进行微调。这种'预训练 + 微调'的范式,极大降低了对标注数据的依赖,同时显著提升了模型在各类 NLP 任务上的性能。

预训练语言模型的发展可以分为三个阶段:

  1. 单向语言模型阶段:以 ELMo 为代表,通过双向 LSTM 分别学习正向和反向的语言表示,再拼接得到词向量。但 ELMo 本质还是基于 RNN 的特征提取器,无法捕捉深层的上下文依赖。
  2. 自回归语言模型阶段:以 GPT 为代表,采用单向 Transformer 解码器架构,通过自回归的方式预测下一个词。但单向模型只能利用前文信息,无法利用后文信息,在理解类任务上表现受限。
  3. 双向语言模型阶段:以 BERT 为代表,采用双向 Transformer 编码器架构,通过掩码语言模型任务,让模型同时学习前文和后文的信息,真正实现了双向上下文理解。

预训练 + 微调的核心流程

预训练语言模型的应用流程分为两个关键步骤:

  1. 预训练阶段:在大规模无标注语料(如维基百科、书籍语料)上,通过设计特定的预训练任务(如掩码语言模型、下一句预测),让模型学习语言的语法、语义和常识知识,得到通用的语言表示模型。
  2. 微调阶段:针对具体的 NLP 任务(如文本分类、命名实体识别、机器翻译),在预训练模型的基础上,添加少量任务相关的输出层,使用少量标注数据进行训练,得到任务专用模型。

这里得提醒一下,预训练阶段通常需要海量的计算资源和数据,一般由大厂或研究机构完成。普通开发者只需下载预训练好的模型权重,直接进行微调即可。

BERT 模型架构与预训练任务详解

BERT 的核心架构

BERT 的全称是 Bidirectional Encoder Representations from Transformers,即基于 Transformer 编码器的双向表示模型。它的核心架构是多层双向 Transformer 编码器,没有解码器部分。BERT 的模型结构有两个版本,满足不同的算力需求:

  • BERT-Base:12 层 Transformer 编码器,12 个注意力头,隐藏层维度 768,参数量约 110M。
  • BERT-Large:24 层 Transformer 编码器,16 个注意力头,隐藏层维度 1024,参数量约 340M。

BERT 的输入表示是三种嵌入的求和:

  1. 词嵌入(Token Embedding):表示每个词的基础语义信息。
  2. 分段嵌入(Segment Embedding):用于区分两个句子(如判断句子是否为上下文关系),取值为 0 或 1。
  3. 位置嵌入(Position Embedding):和 Transformer 一样,用于注入词的位置信息,因为 Transformer 本身是无序的。
import tensorflow as tf
from transformers import BertConfig, BertModel

# 加载 BERT-Base 配置
config = BertConfig.from_pretrained('bert-base-uncased')
# 初始化 BERT 模型
bert_model = BertModel.from_pretrained('bert-base-uncased')

# 模拟输入:batch_size=2,sequence_length=10
input_ids = tf.random.randint(0, config.vocab_size, (2, 10))
attention_mask = tf.ones((2, 10))  # 1 表示有效 token,0 表示填充 token
token_type_ids = tf.zeros((2, 10))  # 0 表示第一个句子,1 表示第二个句子

# 获取 BERT 输出
outputs = bert_model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids
)

last_hidden_state = outputs.last_hidden_state  # 最后一层隐藏状态,shape=(2,10,768)
pooler_output = outputs.pooler_output  # 特殊 token [CLS] 的输出,shape=(2,768)

print("最后一层隐藏状态形状:", last_hidden_state.shape)
print("CLS token 输出形状:", pooler_output.shape)

BERT 的预训练任务

BERT 的预训练包含两个核心任务,通过这两个任务让模型学习双向上下文信息:

  1. 掩码语言模型(Masked Language Model, MLM)
    • 随机选择 15% 的 token 进行掩码处理:80% 的概率替换为 [MASK],10% 的概率替换为随机 token,10% 的概率保持原 token 不变。
    • 模型的任务是预测被掩码的 token 的原始值。这个任务强制模型学习上下文的双向依赖关系,因为要预测掩码 token,必须同时考虑前后文的信息。
  2. 下一句预测(Next Sentence Prediction, NSP)
    • 输入一对句子(A 和 B),50% 的概率 B 是 A 的真实下一句,50% 的概率 B 是随机选择的句子。
    • 模型的任务是判断 B 是否是 A 的下一句。这个任务让模型学习句子之间的逻辑关系,适用于问答、文本摘要等需要理解句子关系的任务。

注意:后续的研究发现,NSP 任务对部分下游任务的提升有限,甚至可能带来负面影响。因此,一些改进版的 BERT 模型(如 RoBERTa)取消了 NSP 任务。

Hugging Face Transformers 库快速上手

Hugging Face Transformers 是目前最流行的预训练语言模型工具库,它提供了包括 BERT、GPT、RoBERTa、T5 等在内的数百种预训练模型的实现,支持 TensorFlow 和 PyTorch 两种框架,极大简化了预训练模型的使用流程。

安装与环境配置

首先安装 Transformers 库和相关依赖:

pip install transformers datasets tensorflow

核心组件介绍

Transformers 库的核心组件包括:

  • Config:存储模型的配置信息,如层数、隐藏层维度、注意力头数等。
  • Tokenizer:负责文本的预处理,包括分词、转换为 token id、添加特殊 token、填充和截断等。
  • Model:预训练模型的核心代码,不同的模型对应不同的 Model 类,如 BertModel、BertForSequenceClassification 等。

实战:基于 BERT 的中文文本分类任务

任务介绍与数据集准备

本次实战任务是中文新闻文本分类。我们使用 THUCNews 数据集的子集,包含 10 个新闻类别:体育、娱乐、家居、房产、教育、时尚、时政、游戏、科技、财经。我们的目标是基于 BERT-base-chinese 模型,搭建文本分类模型,实现对新闻类别的自动判断。

  1. 加载 THUCNews 子集数据集,划分训练集、验证集和测试集
  2. 使用 BertTokenizer 对文本进行分词处理,转换为模型可接受的输入格式
  3. 设置序列最大长度为 128,对过长的文本进行截断,过短的文本进行填充
from datasets import load_dataset
from transformers import BertTokenizerFast

# 加载数据集(这里使用本地的 THUCNews 子集,也可以使用 Hugging Face Hub 上的公开数据集)
dataset = load_dataset(
    'csv',
    data_files={
        'train': 'thucnews_train.csv',
        'val': 'thucnews_val.csv',
        'test': 'thucnews_test.csv'
    }
)

# 加载中文 BERT 分词器
tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')

# 定义文本预处理函数
def preprocess_function(examples):
    # 对文本进行分词、转换为 token id、填充和截断
    return tokenizer(
        examples['text'],
        max_length=128,
        padding='max_length',
        truncation=True
    )

# 对数据集进行预处理
tokenized_dataset = dataset.map(preprocess_function, batched=True)

# 重命名标签列,适配模型输入
tokenized_dataset = tokenized_dataset.rename_column('label', 'labels')

# 设置数据集格式为 TensorFlow 格式
tokenized_dataset.set_format(type='tensorflow', columns=['input_ids', 'attention_mask', 'labels'])

# 生成训练集和验证集的 tf.data.Dataset
batch_size = 32
train_dataset = tokenized_dataset['train'].to_tf_dataset(
    columns=['input_ids', 'attention_mask'],
    label_cols=['labels'],
    batch_size=batch_size,
    shuffle=True
)
val_dataset = tokenized_dataset['val'].to_tf_dataset(
    columns=['input_ids', 'attention_mask'],
    label_cols=['labels'],
    batch_size=batch_size,
    shuffle=False
)

搭建 BERT 文本分类模型

我们使用 BertForSequenceClassification 类,它是 BERT 模型针对序列分类任务的专用版本。它在 BERT 的输出层后,添加了一个全连接层,用于将 [CLS] token 的输出映射到分类标签空间。

from transformers import TFBertForSequenceClassification

# 加载 BERT 中文预训练模型,指定分类类别数为 10
model = TFBertForSequenceClassification.from_pretrained(
    'bert-base-chinese',
    num_labels=10,
    problem_type='single_label_classification'
)

# 编译模型
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy('accuracy')]
)

# 查看模型结构
model.summary()

注意:BERT 模型的学习率通常设置为 2e-5 或 5e-5,远小于普通深度学习模型的学习率。这是因为预训练模型已经学习了丰富的语言知识,过高的学习率会破坏预训练的权重。

模型微调与评估

  1. 设置训练参数,训练轮数设置为 3 轮(BERT 模型微调通常不需要太多轮数,否则容易过拟合)
  2. 使用验证集监控模型性能,保存最佳模型
  3. 在测试集上评估模型的最终性能
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# 定义回调函数
callbacks = [
    # 早停:当验证集损失不再下降时停止训练
    EarlyStopping(monitor='val_loss', patience=1, restore_best_weights=True),
    # 保存最佳模型
    ModelCheckpoint('best_bert_thucnews.h5', monitor='val_accuracy', save_best_only=True)
]

# 开始微调模型
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=3,
    callbacks=callbacks
)

# 加载测试集
test_dataset = tokenized_dataset['test'].to_tf_dataset(
    columns=['input_ids', 'attention_mask'],
    label_cols=['labels'],
    batch_size=batch_size,
    shuffle=False
)

# 在测试集上评估模型
test_loss, test_acc = model.evaluate(test_dataset)
print(f"测试集损失:{test_loss:.4f}")
print(f"测试集准确率:{test_acc:.4f}")

模型预测与推理

训练完成后,我们可以使用模型对新的中文文本进行分类预测:

# 定义预测函数
def predict_text_category(text):
    # 预处理文本
    inputs = tokenizer(
        text,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors='tf'
    )
    # 获取预测结果
    outputs = model(inputs)
    logits = outputs.logits
    # 转换为类别概率
    probabilities = tf.nn.softmax(logits, axis=-1)
    # 获取预测类别
    predicted_label = tf.argmax(probabilities, axis=-1).numpy()[0]
    # 类别映射字典
    label_map = {
        0: '体育', 1: '娱乐', 2: '家居', 3: '房产', 4: '教育',
        5: '时尚', 6: '时政', 7: '游戏', 8: '科技', 9: '财经'
    }
    return label_map[predicted_label]

# 测试预测
test_text = "北京时间 10 月 1 日,2024 年巴黎奥运会男篮决赛在法兰西体育场举行,美国队以 102-87 击败法国队,夺得金牌。"
print(f"文本内容:{test_text}")
print(f"预测类别:{predict_text_category(test_text)}")

BERT 模型的优化与改进方向

模型优化技巧

  1. 使用学习率调度器:在微调过程中,使用线性学习率衰减策略,让学习率随着训练轮数的增加而逐渐降低,提升模型的泛化能力。
  2. 使用梯度累积:当显存不足时,可以使用梯度累积技术,将多个小批次的梯度累积起来,再进行一次参数更新,相当于增大了批次大小。
  3. 使用知识蒸馏:将大模型(如 BERT-Large)的知识蒸馏到小模型(如 DistilBERT)中,在保证性能损失较小的前提下,显著提升模型的推理速度。

BERT 的改进模型

BERT 提出后,研究者们提出了许多改进版本,进一步提升了模型性能:

  • RoBERTa:取消了 NSP 任务,使用更大的批次大小和更多的训练数据,性能全面超越 BERT。
  • ALBERT:通过参数共享技术,大幅减少模型参数量,提升训练效率。
  • ERNIE:百度提出的中文增强版 BERT,通过引入实体级和短语级的掩码策略,提升了模型对中文语义的理解能力。
  • SpanBERT:将掩码单位从单个 token 改为连续的 token span,提升了模型对短语和实体的建模能力。

总结

预训练语言模型采用'预训练 + 微调'的范式,先在大规模无标注语料上学习通用语言知识,再针对具体任务进行微调。BERT 是基于双向 Transformer 编码器的预训练模型,通过掩码语言模型和下一句预测任务,实现了双向上下文理解。使用 Hugging Face Transformers 库可以快速调用 BERT 模型,只需少量代码即可完成中文文本分类等任务的微调。BERT 模型的微调需要注意学习率的设置,通常使用 2e-5 或 5e-5 的小学习率,避免破坏预训练权重。

目录

  1. 预训练语言模型与 BERT 实战应用
  2. 核心目标与重点
  3. 为什么需要预训练语言模型
  4. 预训练 + 微调的核心流程
  5. BERT 模型架构与预训练任务详解
  6. BERT 的核心架构
  7. 加载 BERT-Base 配置
  8. 初始化 BERT 模型
  9. 模拟输入:batchsize=2,sequencelength=10
  10. 获取 BERT 输出
  11. BERT 的预训练任务
  12. Hugging Face Transformers 库快速上手
  13. 安装与环境配置
  14. 核心组件介绍
  15. 实战:基于 BERT 的中文文本分类任务
  16. 任务介绍与数据集准备
  17. 加载数据集(这里使用本地的 THUCNews 子集,也可以使用 Hugging Face Hub 上的公开数据集)
  18. 加载中文 BERT 分词器
  19. 定义文本预处理函数
  20. 对数据集进行预处理
  21. 重命名标签列,适配模型输入
  22. 设置数据集格式为 TensorFlow 格式
  23. 生成训练集和验证集的 tf.data.Dataset
  24. 搭建 BERT 文本分类模型
  25. 加载 BERT 中文预训练模型,指定分类类别数为 10
  26. 编译模型
  27. 查看模型结构
  28. 模型微调与评估
  29. 定义回调函数
  30. 开始微调模型
  31. 加载测试集
  32. 在测试集上评估模型
  33. 模型预测与推理
  34. 定义预测函数
  35. 测试预测
  36. BERT 模型的优化与改进方向
  37. 模型优化技巧
  38. BERT 的改进模型
  39. 总结
  • 💰 8折买阿里云服务器限时8折了解详情
  • GPT-5.5 超高智商模型1元抵1刀ChatGPT中转购买
  • 代充Chatgpt Plus/pro 帐号了解详情
  • 🤖 一键搭建Deepseek满血版了解详情
  • 一键打造专属AI 智能体了解详情
极客日志微信公众号二维码

微信扫一扫,关注极客日志

微信公众号「极客日志V2」,在微信中扫描左侧二维码关注。展示文案:极客日志V2 zeeklog

更多推荐文章

查看全部
  • JSON-java CDL转换终极指南:快速掌握逗号分隔列表与JSONArray互转技巧
  • 图像畸变矫正原理及 MATLAB 与 FPGA 实现
  • 护网行动与红蓝对抗详解
  • FPGA与DSP协同通信系统设计与接口选型
  • 使用 frontend-design Skill 提升大模型前端设计审美
  • 前端流式输出实战:从原理到框架落地
  • WebODM 开源无人机地图制作完全指南
  • 小智 AI 设备绑定与解绑操作指南
  • 前端流式输出技术详解:原理与实战方案
  • GitHub 开源游戏项目与引擎资源汇总
  • C++ 红黑树实现与优化详解
  • 基于回调接口将 AI 小助手接入企业微信实现群聊机器人
  • 前端代码分割与懒加载实战指南
  • 机器人系统架构详解:2026 年最新技术路线
  • SpringAI Agent 实战:Java 开发者接入 Agent Skills 指南
  • OpenClaw 开源项目实现机器人空间记忆与具身智能突破
  • AI 产品架构设计:从 0 到 1 搭建信息架构与核心业务流程
  • 户外机器人 GNSS 仿真测试:双天线定向与 RTK 高精度定位实战
  • 大语言模型在线连续知识学习(OCKL)框架与方法研究
  • SystemVerilog 硬件验证实战:从基础语法到高级特性应用

相关免费在线工具

  • 加密/解密文本

    使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online

  • RSA密钥对生成器

    生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online

  • Mermaid 预览与可视化编辑

    基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online

  • 随机西班牙地址生成器

    随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online

  • Gemini 图片去水印

    基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online

  • curl 转代码

    解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online