跳到主要内容
极客日志极客日志
首页博客AI提示词GitHub精选代理工具
搜索
|注册
博客列表
PythonAI算法

预训练语言模型与 BERT 实战应用

综述由AI生成预训练语言模型通过大规模无标注语料学习通用语义,结合微调范式解决 NLP 任务数据依赖问题。BERT 采用双向 Transformer 编码器架构,利用掩码语言模型与下一句预测任务实现上下文理解。基于 Hugging Face Transformers 库,演示了如何使用 BERT 进行中文文本分类实战,涵盖环境配置、模型构建、微调策略及推理预测全流程,并探讨了 RoBERTa 等改进方向与优化技巧。

宁静发布于 2026/3/21更新于 2026/4/304 浏览
预训练语言模型与 BERT 实战应用

预训练语言模型与 BERT 实战应用

学习目标与重点

掌握预训练语言模型的核心思想、BERT 模型的架构原理,以及基于 BERT 的文本分类任务实战流程。重点理解 BERT 的双向注意力机制与掩码语言模型预训练任务,学会使用 Hugging Face Transformers 库调用 BERT 模型并完成微调。

预训练语言模型的发展历程与核心思想

为什么需要预训练语言模型

传统的自然语言处理模型(如 LSTM+ 词嵌入)存在两个核心痛点:一是需要大量标注数据才能训练出高性能模型,二是模型对语言上下文的理解能力有限。预训练语言模型的出现解决了这些问题。它的核心思路是先在大规模无标注文本语料上进行预训练,学习通用的语言知识和语义表示,再针对特定任务进行微调。这种'预训练 + 微调'的范式,极大降低了对标注数据的依赖,同时显著提升了模型在各类 NLP 任务上的性能。

预训练语言模型的发展可以分为三个阶段:

  1. 单向语言模型阶段:以 ELMo 为代表,通过双向 LSTM 分别学习正向和反向的语言表示,再拼接得到词向量。但 ELMo 本质还是基于 RNN 的特征提取器,无法捕捉深层的上下文依赖。
  2. 自回归语言模型阶段:以 GPT 为代表,采用单向 Transformer 解码器架构,通过自回归的方式预测下一个词。但单向模型只能利用前文信息,无法利用后文信息,在理解类任务上表现受限。
  3. 双向语言模型阶段:以 BERT 为代表,采用双向 Transformer 编码器架构,通过掩码语言模型任务,让模型同时学习前文和后文的信息,真正实现了双向上下文理解。
预训练 + 微调的核心流程

预训练语言模型的应用流程分为两个关键步骤:

  1. 预训练阶段:在大规模无标注语料(如维基百科、书籍语料)上,通过设计特定的预训练任务(如掩码语言模型、下一句预测),让模型学习语言的语法、语义和常识知识,得到通用的语言表示模型。
  2. 微调阶段:针对具体的 NLP 任务(如文本分类、命名实体识别、机器翻译),在预训练模型的基础上,添加少量任务相关的输出层,使用少量标注数据进行训练,得到任务专用模型。

注意:预训练阶段通常需要海量的计算资源和数据,一般由大厂或研究机构完成。普通开发者只需下载预训练好的模型权重,直接进行微调即可。

BERT 模型架构与预训练任务详解

BERT 的核心架构

BERT 的全称是 Bidirectional Encoder Representations from Transformers,即基于 Transformer 编码器的双向表示模型。它的核心架构是多层双向 Transformer 编码器,没有解码器部分。BERT 的模型结构有两个版本,满足不同的算力需求:

  • BERT-Base:12 层 Transformer 编码器,12 个注意力头,隐藏层维度 768,参数量约 110M。
  • BERT-Large:24 层 Transformer 编码器,16 个注意力头,隐藏层维度 1024,参数量约 340M。

BERT 的输入表示是三种嵌入的求和:

  1. 词嵌入(Token Embedding):表示每个词的基础语义信息。
  2. 分段嵌入(Segment Embedding):用于区分两个句子(如判断句子是否为上下文关系),取值为 0 或 1。
  3. 位置嵌入(Position Embedding):和 Transformer 一样,用于注入词的位置信息,因为 Transformer 本身是无序的。
import tensorflow as tf
from transformers import BertConfig, BertModel

# 加载 BERT-Base 配置
config = BertConfig.from_pretrained('bert-base-uncased')
# 初始化 BERT 模型
bert_model = BertModel.from_pretrained('bert-base-uncased')

# 模拟输入:batch_size=2,sequence_length=10
input_ids = tf.random.randint(0, config.vocab_size, (2, 10))
attention_mask = tf.ones((2, 10))  # 1 表示有效 token,0 表示填充 token
token_type_ids = tf.zeros((2, 10))  # 0 表示第一个句子,1 表示第二个句子

# 获取 BERT 输出
outputs = bert_model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids
)
last_hidden_state = outputs.last_hidden_state  # 最后一层隐藏状态,shape=(2,10,768)
pooler_output = outputs.pooler_output  # 特殊 token [CLS] 的输出,shape=(2,768)

print("最后一层隐藏状态形状:", last_hidden_state.shape)
print("CLS token 输出形状:", pooler_output.shape)
BERT 的预训练任务

BERT 的预训练包含两个核心任务,通过这两个任务让模型学习双向上下文信息:

  1. 掩码语言模型(Masked Language Model, MLM)
    • 随机选择 15% 的 token 进行掩码处理:80% 的概率替换为 [MASK],10% 的概率替换为随机 token,10% 的概率保持原 token 不变。
    • 模型的任务是预测被掩码的 token 的原始值。这个任务强制模型学习上下文的双向依赖关系,因为要预测掩码 token,必须同时考虑前后文的信息。
  2. 下一句预测(Next Sentence Prediction, NSP)
    • 输入一对句子(A 和 B),50% 的概率 B 是 A 的真实下一句,50% 的概率 B 是随机选择的句子。
    • 模型的任务是判断 B 是否是 A 的下一句。这个任务让模型学习句子之间的逻辑关系,适用于问答、文本摘要等需要理解句子关系的任务。

注意:后续的研究发现,NSP 任务对部分下游任务的提升有限,甚至可能带来负面影响。因此,一些改进版的 BERT 模型(如 RoBERTa)取消了 NSP 任务。

Hugging Face Transformers 库快速上手

Hugging Face Transformers 是目前最流行的预训练语言模型工具库,它提供了包括 BERT、GPT、RoBERTa、T5 等在内的数百种预训练模型的实现,支持 TensorFlow 和 PyTorch 两种框架,极大简化了预训练模型的使用流程。

安装与环境配置

我们先准备好环境,安装 Transformers 库和相关依赖:

pip install transformers datasets tensorflow
核心组件介绍

Transformers 库的核心组件包括:

  • Config:存储模型的配置信息,如层数、隐藏层维度、注意力头数等。
  • Tokenizer:负责文本的预处理,包括分词、转换为 token id、添加特殊 token、填充和截断等。
  • Model:预训练模型的核心代码,不同的模型对应不同的 Model 类,如 BertModel、BertForSequenceClassification 等。

实战:基于 BERT 的中文文本分类任务

任务介绍与数据集准备

本次实战任务是中文新闻文本分类。我们使用 THUCNews 数据集的子集,包含 10 个新闻类别:体育、娱乐、家居、房产、教育、时尚、时政、游戏、科技、财经。我们的目标是基于 BERT-base-chinese 模型,搭建文本分类模型,实现对新闻类别的自动判断。

首先加载 THUCNews 子集数据集,划分训练集、验证集和测试集。然后使用 BertTokenizer 对文本进行分词处理,转换为模型可接受的输入格式。设置序列最大长度为 128,对过长的文本进行截断,过短的文本进行填充。

from datasets import load_dataset
from transformers import BertTokenizerFast

# 加载数据集(这里使用本地的 THUCNews 子集,也可以使用 Hugging Face Hub 上的公开数据集)
dataset = load_dataset('csv', data_files={
    'train': 'thucnews_train.csv',
    'val': 'thucnews_val.csv',
    'test': 'thucnews_test.csv'
})

# 加载中文 BERT 分词器
tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')

# 定义文本预处理函数
def preprocess_function(examples):
    # 对文本进行分词、转换为 token id、填充和截断
    return tokenizer(
        examples['text'],
        max_length=128,
        padding='max_length',
        truncation=True
    )

# 对数据集进行预处理
tokenized_dataset = dataset.map(preprocess_function, batched=True)

# 重命名标签列,适配模型输入
tokenized_dataset = tokenized_dataset.rename_column('label', 'labels')

# 设置数据集格式为 TensorFlow 格式
tokenized_dataset.set_format(type='tensorflow', columns=['input_ids','attention_mask','labels'])

# 生成训练集和验证集的 tf.data.Dataset
batch_size = 32
train_dataset = tokenized_dataset['train'].to_tf_dataset(
    columns=['input_ids','attention_mask'],
    label_cols=['labels'],
    batch_size=batch_size,
    shuffle=True
)
val_dataset = tokenized_dataset['val'].to_tf_dataset(
    columns=['input_ids','attention_mask'],
    label_cols=['labels'],
    batch_size=batch_size,
    shuffle=False
)
搭建 BERT 文本分类模型

我们使用 BertForSequenceClassification 类,它是 BERT 模型针对序列分类任务的专用版本。它在 BERT 的输出层后,添加了一个全连接层,用于将 [CLS] token 的输出映射到分类标签空间。

from transformers import TFBertForSequenceClassification

# 加载 BERT 中文预训练模型,指定分类类别数为 10
model = TFBertForSequenceClassification.from_pretrained(
    'bert-base-chinese',
    num_labels=10,
    problem_type='single_label_classification'
)

# 编译模型
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy('accuracy')]
)

# 查看模型结构
model.summary()

注意:BERT 模型的学习率通常设置为 2e-5 或 5e-5,远小于普通深度学习模型的学习率。这是因为预训练模型已经学习了丰富的语言知识,过高的学习率会破坏预训练的权重。

模型微调与评估

设置训练参数,训练轮数设置为 3 轮(BERT 模型微调通常不需要太多轮数,否则容易过拟合)。使用验证集监控模型性能,保存最佳模型。最后在测试集上评估模型的最终性能。

from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# 定义回调函数
callbacks = [
    # 早停:当验证集损失不再下降时停止训练
    EarlyStopping(monitor='val_loss', patience=1, restore_best_weights=True),
    # 保存最佳模型
    ModelCheckpoint('best_bert_thucnews.h5', monitor='val_accuracy', save_best_only=True)
]

# 开始微调模型
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=3,
    callbacks=callbacks
)

# 加载测试集
test_dataset = tokenized_dataset['test'].to_tf_dataset(
    columns=['input_ids','attention_mask'],
    label_cols=['labels'],
    batch_size=batch_size,
    shuffle=False
)

# 在测试集上评估模型
test_loss, test_acc = model.evaluate(test_dataset)
print(f"测试集损失:{test_loss:.4f}")
print(f"测试集准确率:{test_acc:.4f}")
模型预测与推理

训练完成后,我们可以使用模型对新的中文文本进行分类预测:

# 定义预测函数
def predict_text_category(text):
    # 预处理文本
    inputs = tokenizer(
        text,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors='tf'
    )
    # 获取预测结果
    outputs = model(inputs)
    logits = outputs.logits
    # 转换为类别概率
    probabilities = tf.nn.softmax(logits, axis=-1)
    # 获取预测类别
    predicted_label = tf.argmax(probabilities, axis=-1).numpy()[0]
    # 类别映射字典
    label_map = {
        0:'体育', 1:'娱乐', 2:'家居', 3:'房产', 4:'教育',
        5:'时尚', 6:'时政', 7:'游戏', 8:'科技', 9:'财经'
    }
    return label_map[predicted_label]

# 测试预测
test_text = "北京时间 10 月 1 日,2024 年巴黎奥运会男篮决赛在法兰西体育场举行,美国队以 102-87 击败法国队,夺得金牌。"
print(f"文本内容:{test_text}")
print(f"预测类别:{predict_text_category(test_text)}")

BERT 模型的优化与改进方向

模型优化技巧

技巧 1:使用学习率调度器。在微调过程中,使用线性学习率衰减策略,让学习率随着训练轮数的增加而逐渐降低,提升模型的泛化能力。

技巧 2:使用梯度累积。当显存不足时,可以使用梯度累积技术,将多个小批次的梯度累积起来,再进行一次参数更新,相当于增大了批次大小。

技巧 3:使用知识蒸馏。将大模型(如 BERT-Large)的知识蒸馏到小模型(如 DistilBERT)中,在保证性能损失较小的前提下,显著提升模型的推理速度。

BERT 的改进模型

BERT 提出后,研究者们提出了许多改进版本,进一步提升了模型性能:

  • RoBERTa:取消了 NSP 任务,使用更大的批次大小和更多的训练数据,性能全面超越 BERT。
  • ALBERT:通过参数共享技术,大幅减少模型参数量,提升训练效率。
  • ERNIE:百度提出的中文增强版 BERT,通过引入实体级和短语级的掩码策略,提升了模型对中文语义的理解能力。
  • SpanBERT:将掩码单位从单个 token 改为连续的 token span,提升了模型对短语和实体的建模能力。

本章总结

预训练语言模型采用'预训练 + 微调'的范式,先在大规模无标注语料上学习通用语言知识,再针对具体任务进行微调。

BERT 是基于双向 Transformer 编码器的预训练模型,通过掩码语言模型和下一句预测任务,实现了双向上下文理解。

使用 Hugging Face Transformers 库可以快速调用 BERT 模型,只需少量代码即可完成中文文本分类等任务的微调。

BERT 模型的微调需要注意学习率的设置,通常使用 2e-5 或 5e-5 的小学习率,避免破坏预训练权重。

目录

  1. 预训练语言模型与 BERT 实战应用
  2. 学习目标与重点
  3. 预训练语言模型的发展历程与核心思想
  4. 为什么需要预训练语言模型
  5. 预训练 + 微调的核心流程
  6. BERT 模型架构与预训练任务详解
  7. BERT 的核心架构
  8. 加载 BERT-Base 配置
  9. 初始化 BERT 模型
  10. 模拟输入:batchsize=2,sequencelength=10
  11. 获取 BERT 输出
  12. BERT 的预训练任务
  13. Hugging Face Transformers 库快速上手
  14. 安装与环境配置
  15. 核心组件介绍
  16. 实战:基于 BERT 的中文文本分类任务
  17. 任务介绍与数据集准备
  18. 加载数据集(这里使用本地的 THUCNews 子集,也可以使用 Hugging Face Hub 上的公开数据集)
  19. 加载中文 BERT 分词器
  20. 定义文本预处理函数
  21. 对数据集进行预处理
  22. 重命名标签列,适配模型输入
  23. 设置数据集格式为 TensorFlow 格式
  24. 生成训练集和验证集的 tf.data.Dataset
  25. 搭建 BERT 文本分类模型
  26. 加载 BERT 中文预训练模型,指定分类类别数为 10
  27. 编译模型
  28. 查看模型结构
  29. 模型微调与评估
  30. 定义回调函数
  31. 开始微调模型
  32. 加载测试集
  33. 在测试集上评估模型
  34. 模型预测与推理
  35. 定义预测函数
  36. 测试预测
  37. BERT 模型的优化与改进方向
  38. 模型优化技巧
  39. BERT 的改进模型
  40. 本章总结
  • 💰 8折买阿里云服务器限时8折了解详情
  • GPT-5.5 超高智商模型1元抵1刀ChatGPT中转购买
  • 代充Chatgpt Plus/pro 帐号了解详情
  • 🤖 一键搭建Deepseek满血版了解详情
  • 一键打造专属AI 智能体了解详情
极客日志微信公众号二维码

微信扫一扫,关注极客日志

微信公众号「极客日志V2」,在微信中扫描左侧二维码关注。展示文案:极客日志V2 zeeklog

更多推荐文章

查看全部
  • MySQL 数据类型核心指南:选型、实战与避坑
  • 前端老鸟血泪史:搞定那些让人头秃的报错,让线上稳如老狗
  • 使用 Llama3 与 MaxKB 搭建本地私有 AI 知识库
  • Python Emoji 库使用教程
  • Windows 本地部署 Ollama 大模型:Qwen3.5 实战与并发优化
  • VSCode Copilot 登录失败的常见原因与排查方案
  • 基于 OpenClaw 与飞书搭建多 Agent AI 助理团队
  • MAVROS 安装与基础知识梳理及 ROS C++ 仿真案例
  • ChatGLM 医药行业舆情精选策略与大模型微调指南
  • Spec-Kit 与 Copilot 实现 AI 规格驱动开发
  • H.265 网页播放方案:WebAssembly + FFmpeg 实现硬解软解兼容
  • Docker Compose 多实例 Tomcat 部署示例
  • Python 开发环境搭建与安装指南(Windows 版)
  • 人工智能、机器学习与深度学习的区别及关系
  • AI 技术综述与产品经理角色重塑
  • Qwen3 与 Qwen Agent 智能体开发实战:接入 MCP 工具
  • Agent 上下文注入原理与 Web 架构映射实战
  • 二分答案专题实战:木材加工与砍树问题解析
  • 2026 年国际主流 AI IDE 排行榜与选型指南
  • Claude Code 辅助 Verilog 编程与 FPGA 设计实战

相关免费在线工具

  • 加密/解密文本

    使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online

  • RSA密钥对生成器

    生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online

  • Mermaid 预览与可视化编辑

    基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online

  • 随机西班牙地址生成器

    随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online

  • Gemini 图片去水印

    基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online

  • curl 转代码

    解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online