跳到主要内容
极客日志极客日志
首页博客AI提示词GitHub精选代理工具
搜索
|注册
博客列表
PythonAI算法

预训练语言模型与 BERT 实战应用

预训练语言模型通过大规模无标注语料学习通用语义,结合微调技术解决特定 NLP 任务。BERT 采用双向 Transformer 编码器,利用掩码语言模型和下一句预测任务实现上下文理解。演示基于 Hugging Face Transformers 库的中文文本分类实战,涵盖数据预处理、模型构建、微调训练及推理评估全流程,并探讨学习率设置与模型优化技巧。

LinuxPan发布于 2026/3/22更新于 2026/5/49 浏览
预训练语言模型与 BERT 实战应用

核心目标与重点

掌握预训练语言模型的核心思想、BERT 模型的架构原理,以及基于 BERT 的文本分类任务实战流程。理解 BERT 的双向注意力机制与掩码语言模型预训练任务,学会使用 Hugging Face Transformers 库调用 BERT 模型并完成微调。

预训练模型的发展脉络

传统的自然语言处理模型(如 LSTM+ 词嵌入)存在两个核心痛点:一是需要大量标注数据才能训练出高性能模型,二是模型对语言上下文的理解能力有限。预训练语言模型的出现解决了这些问题。它的核心思路是先在大规模无标注文本语料上进行预训练,学习通用的语言知识和语义表示,再针对特定任务进行微调。这种'预训练 + 微调'的范式,极大降低了对标注数据的依赖,同时显著提升了模型在各类 NLP 任务上的性能。

预训练语言模型的发展可以分为三个阶段:

  1. 单向语言模型阶段:以 ELMo 为代表,通过双向 LSTM 分别学习正向和反向的语言表示,再拼接得到词向量。但 ELMo 本质还是基于 RNN 的特征提取器,无法捕捉深层的上下文依赖。
  2. 自回归语言模型阶段:以 GPT 为代表,采用单向 Transformer 解码器架构,通过自回归的方式预测下一个词。但单向模型只能利用前文信息,无法利用后文信息,在理解类任务上表现受限。
  3. 双向语言模型阶段:以 BERT 为代表,采用双向 Transformer 编码器架构,通过掩码语言模型任务,让模型同时学习前文和后文的信息,真正实现了双向上下文理解。

预训练 + 微调的核心流程

预训练语言模型的应用流程分为两个关键步骤:

  1. 预训练阶段:在大规模无标注语料(如维基百科、书籍语料)上,通过设计特定的预训练任务(如掩码语言模型、下一句预测),让模型学习语言的语法、语义和常识知识,得到通用的语言表示模型。
  2. 微调阶段:针对具体的 NLP 任务(如文本分类、命名实体识别、机器翻译),在预训练模型的基础上,添加少量任务相关的输出层,使用少量标注数据进行训练,得到任务专用模型。

注意:预训练阶段通常需要海量的计算资源和数据,一般由大厂或研究机构完成。普通开发者只需下载预训练好的模型权重,直接进行微调即可。

BERT 模型架构与预训练任务详解

BERT 的核心架构

BERT 的全称是 Bidirectional Encoder Representations from Transformers,即基于 Transformer 编码器的双向表示模型。它的核心架构是多层双向 Transformer 编码器,没有解码器部分。

BERT 的模型结构有两个版本,满足不同的算力需求:

  • BERT-Base:12 层 Transformer 编码器,12 个注意力头,隐藏层维度 768,参数量约 110M。
  • BERT-Large:24 层 Transformer 编码器,16 个注意力头,隐藏层维度 1024,参数量约 340M。

BERT 的输入表示是三种嵌入的求和:

  1. 词嵌入(Token Embedding):表示每个词的基础语义信息。
  2. 分段嵌入(Segment Embedding):用于区分两个句子(如判断句子是否为上下文关系),取值为 0 或 1。
  3. 位置嵌入(Position Embedding):和 Transformer 一样,用于注入词的位置信息,因为 Transformer 本身是无序的。
import tensorflow as tf
from transformers import BertConfig, BertModel

# 加载 BERT-Base 配置
config = BertConfig.from_pretrained('bert-base-uncased')
# 初始化 BERT 模型
bert_model = BertModel.from_pretrained('bert-base-uncased')

# 模拟输入:batch_size=2,sequence_length=10
input_ids = tf.random.randint(0, config.vocab_size, (2, 10))
attention_mask = tf.ones((2, 10))  # 1 表示有效 token,0 表示填充 token
token_type_ids = tf.zeros((2, 10))  # 0 表示第一个句子,1 表示第二个句子

# 获取 BERT 输出
outputs = bert_model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids
)
last_hidden_state = outputs.last_hidden_state  # 最后一层隐藏状态,shape=(2,10,768)
pooler_output = outputs.pooler_output  # 特殊 token [CLS] 的输出,shape=(2,768)

print("最后一层隐藏状态形状:", last_hidden_state.shape)
print("CLS token 输出形状:", pooler_output.shape)

BERT 的预训练任务

BERT 的预训练包含两个核心任务,通过这两个任务让模型学习双向上下文信息:

  1. 掩码语言模型(Masked Language Model, MLM)
    • 随机选择 15% 的 token 进行掩码处理:80% 的概率替换为 [MASK],10% 的概率替换为随机 token,10% 的概率保持原 token 不变。
    • 模型的任务是预测被掩码的 token 的原始值。这个任务强制模型学习上下文的双向依赖关系,因为要预测掩码 token,必须同时考虑前后文的信息。
  2. 下一句预测(Next Sentence Prediction, NSP)
    • 输入一对句子(A 和 B),50% 的概率 B 是 A 的真实下一句,50% 的概率 B 是随机选择的句子。
    • 模型的任务是判断 B 是否是 A 的下一句。这个任务让模型学习句子之间的逻辑关系,适用于问答、文本摘要等需要理解句子关系的任务。

注意:后续的研究发现,NSP 任务对部分下游任务的提升有限,甚至可能带来负面影响。因此,一些改进版的 BERT 模型(如 RoBERTa)取消了 NSP 任务。

Hugging Face Transformers 库快速上手

Hugging Face Transformers 是目前最流行的预训练语言模型工具库,它提供了包括 BERT、GPT、RoBERTa、T5 等在内的数百种预训练模型的实现,支持 TensorFlow 和 PyTorch 两种框架,极大简化了预训练模型的使用流程。

安装与环境配置

首先安装 Transformers 库和相关依赖:

pip install transformers datasets tensorflow

核心组件介绍

Transformers 库的核心组件包括:

  • Config:存储模型的配置信息,如层数、隐藏层维度、注意力头数等。
  • Tokenizer:负责文本的预处理,包括分词、转换为 token id、添加特殊 token、填充和截断等。
  • Model:预训练模型的核心代码,不同的模型对应不同的 Model 类,如 BertModel、BertForSequenceClassification 等。

实战:基于 BERT 的中文文本分类任务

任务介绍与数据集准备

本次实战任务是中文新闻文本分类。我们使用 THUCNews 数据集的子集,包含 10 个新闻类别:体育、娱乐、家居、房产、教育、时尚、时政、游戏、科技、财经。我们的目标是基于 BERT-base-chinese 模型,搭建文本分类模型,实现对新闻类别的自动判断。

  1. 加载 THUCNews 子集数据集,划分训练集、验证集和测试集
  2. 使用 BertTokenizer 对文本进行分词处理,转换为模型可接受的输入格式
  3. 设置序列最大长度为 128,对过长的文本进行截断,过短的文本进行填充
from datasets import load_dataset
from transformers import BertTokenizerFast

# 加载数据集(这里使用本地的 THUCNews 子集,也可以使用 Hugging Face Hub 上的公开数据集)
dataset = load_dataset('csv', data_files={
    'train': 'thucnews_train.csv',
    'val': 'thucnews_val.csv',
    'test': 'thucnews_test.csv'
})

# 加载中文 BERT 分词器
tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')

# 定义文本预处理函数
def preprocess_function(examples):
    # 对文本进行分词、转换为 token id、填充和截断
    return tokenizer(
        examples['text'],
        max_length=128,
        padding='max_length',
        truncation=True
    )

# 对数据集进行预处理
tokenized_dataset = dataset.map(preprocess_function, batched=True)

# 重命名标签列,适配模型输入
tokenized_dataset = tokenized_dataset.rename_column('label', 'labels')

# 设置数据集格式为 TensorFlow 格式
tokenized_dataset.set_format(type='tensorflow', columns=['input_ids','attention_mask','labels'])

# 生成训练集和验证集的 tf.data.Dataset
batch_size = 32
train_dataset = tokenized_dataset['train'].to_tf_dataset(
    columns=['input_ids','attention_mask'],
    label_cols=['labels'],
    batch_size=batch_size,
    shuffle=True
)
val_dataset = tokenized_dataset['val'].to_tf_dataset(
    columns=['input_ids','attention_mask'],
    label_cols=['labels'],
    batch_size=batch_size,
    shuffle=False
)

搭建 BERT 文本分类模型

我们使用 BertForSequenceClassification 类,它是 BERT 模型针对序列分类任务的专用版本。它在 BERT 的输出层后,添加了一个全连接层,用于将 [CLS] token 的输出映射到分类标签空间。

from transformers import TFBertForSequenceClassification

# 加载 BERT 中文预训练模型,指定分类类别数为 10
model = TFBertForSequenceClassification.from_pretrained(
    'bert-base-chinese',
    num_labels=10,
    problem_type='single_label_classification'
)

# 编译模型
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy('accuracy')]
)

# 查看模型结构
model.summary()

注意:BERT 模型的学习率通常设置为 2e-5 或 5e-5,远小于普通深度学习模型的学习率。这是因为预训练模型已经学习了丰富的语言知识,过高的学习率会破坏预训练的权重。

模型微调与评估

  1. 设置训练参数,训练轮数设置为 3 轮(BERT 模型微调通常不需要太多轮数,否则容易过拟合)
  2. 使用验证集监控模型性能,保存最佳模型
  3. 在测试集上评估模型的最终性能
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# 定义回调函数
callbacks = [
    # 早停:当验证集损失不再下降时停止训练
    EarlyStopping(monitor='val_loss', patience=1, restore_best_weights=True),
    # 保存最佳模型
    ModelCheckpoint('best_bert_thucnews.h5', monitor='val_accuracy', save_best_only=True)
]

# 开始微调模型
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=3,
    callbacks=callbacks
)

# 加载测试集
test_dataset = tokenized_dataset['test'].to_tf_dataset(
    columns=['input_ids','attention_mask'],
    label_cols=['labels'],
    batch_size=batch_size,
    shuffle=False
)

# 在测试集上评估模型
test_loss, test_acc = model.evaluate(test_dataset)
print(f"测试集损失:{test_loss:.4f}")
print(f"测试集准确率:{test_acc:.4f}")

模型预测与推理

训练完成后,我们可以使用模型对新的中文文本进行分类预测:

# 定义预测函数
def predict_text_category(text):
    # 预处理文本
    inputs = tokenizer(
        text,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors='tf'
    )
    # 获取预测结果
    outputs = model(inputs)
    logits = outputs.logits
    # 转换为类别概率
    probabilities = tf.nn.softmax(logits, axis=-1)
    # 获取预测类别
    predicted_label = tf.argmax(probabilities, axis=-1).numpy()[0]
    # 类别映射字典
    label_map = {
        0:'体育', 1:'娱乐', 2:'家居', 3:'房产', 4:'教育',
        5:'时尚', 6:'时政', 7:'游戏', 8:'科技', 9:'财经'
    }
    return label_map[predicted_label]

# 测试预测
test_text = "北京时间 10 月 1 日,2024 年巴黎奥运会男篮决赛在法兰西体育场举行,美国队以 102-87 击败法国队,夺得金牌。"
print(f"文本内容:{test_text}")
print(f"预测类别:{predict_text_category(test_text)}")

BERT 模型的优化与改进方向

模型优化技巧

技巧 1:使用学习率调度器。在微调过程中,使用线性学习率衰减策略,让学习率随着训练轮数的增加而逐渐降低,提升模型的泛化能力。

技巧 2:使用梯度累积。当显存不足时,可以使用梯度累积技术,将多个小批次的梯度累积起来,再进行一次参数更新,相当于增大了批次大小。

技巧 3:使用知识蒸馏。将大模型(如 BERT-Large)的知识蒸馏到小模型(如 DistilBERT)中,在保证性能损失较小的前提下,显著提升模型的推理速度。

BERT 的改进模型

BERT 提出后,研究者们提出了许多改进版本,进一步提升了模型性能:

  • RoBERTa:取消了 NSP 任务,使用更大的批次大小和更多的训练数据,性能全面超越 BERT。
  • ALBERT:通过参数共享技术,大幅减少模型参数量,提升训练效率。
  • ERNIE:百度提出的中文增强版 BERT,通过引入实体级和短语级的掩码策略,提升了模型对中文语义的理解能力。
  • SpanBERT:将掩码单位从单个 token 改为连续的 token span,提升了模型对短语和实体的建模能力。

本章总结

预训练语言模型采用'预训练 + 微调'的范式,先在大规模无标注语料上学习通用语言知识,再针对具体任务进行微调。

BERT 是基于双向 Transformer 编码器的预训练模型,通过掩码语言模型和下一句预测任务,实现了双向上下文理解。

使用 Hugging Face Transformers 库可以快速调用 BERT 模型,只需少量代码即可完成中文文本分类等任务的微调。

BERT 模型的微调需要注意学习率的设置,通常使用 2e-5 或 5e-5 的小学习率,避免破坏预训练权重。

目录

  1. 核心目标与重点
  2. 预训练模型的发展脉络
  3. 预训练 + 微调的核心流程
  4. BERT 模型架构与预训练任务详解
  5. BERT 的核心架构
  6. 加载 BERT-Base 配置
  7. 初始化 BERT 模型
  8. 模拟输入:batchsize=2,sequencelength=10
  9. 获取 BERT 输出
  10. BERT 的预训练任务
  11. Hugging Face Transformers 库快速上手
  12. 安装与环境配置
  13. 核心组件介绍
  14. 实战:基于 BERT 的中文文本分类任务
  15. 任务介绍与数据集准备
  16. 加载数据集(这里使用本地的 THUCNews 子集,也可以使用 Hugging Face Hub 上的公开数据集)
  17. 加载中文 BERT 分词器
  18. 定义文本预处理函数
  19. 对数据集进行预处理
  20. 重命名标签列,适配模型输入
  21. 设置数据集格式为 TensorFlow 格式
  22. 生成训练集和验证集的 tf.data.Dataset
  23. 搭建 BERT 文本分类模型
  24. 加载 BERT 中文预训练模型,指定分类类别数为 10
  25. 编译模型
  26. 查看模型结构
  27. 模型微调与评估
  28. 定义回调函数
  29. 开始微调模型
  30. 加载测试集
  31. 在测试集上评估模型
  32. 模型预测与推理
  33. 定义预测函数
  34. 测试预测
  35. BERT 模型的优化与改进方向
  36. 模型优化技巧
  37. BERT 的改进模型
  38. 本章总结
  • 💰 8折买阿里云服务器限时8折了解详情
  • GPT-5.5 超高智商模型1元抵1刀ChatGPT中转购买
  • 代充Chatgpt Plus/pro 帐号了解详情
  • 🤖 一键搭建Deepseek满血版了解详情
  • 一键打造专属AI 智能体了解详情
极客日志微信公众号二维码

微信扫一扫,关注极客日志

微信公众号「极客日志V2」,在微信中扫描左侧二维码关注。展示文案:极客日志V2 zeeklog

更多推荐文章

查看全部
  • Linux 命名管道(FIFO)通信:原理与跨进程实战
  • 原生 JavaScript 跨浏览器事件处理兼容方案
  • faster-whisper 快速部署与核心功能实战
  • Ollama 免费接入 Gemini-3-Pro 模型并配置 AI Coding 工具
  • Google AI Studio 区域限制及年龄验证解决方法,Three.js 简介
  • AIGC 时代的医学统计学:Python 数据分析实战
  • Raphael AI:基于 Flux 模型的免费图像生成工具评测
  • Kubernetes 多版本同步更新:修复 Go 安全漏洞并升级编译器
  • DeerFlow 2.0 开源:字节跳动超级智能体框架技术解析
  • Linux 多线程互斥与同步机制解析
  • Git 疑难问题诊疗指南
  • 无人机航拍小目标检测:YOLO11 实战与 PyQt6 桌面应用
  • 数据结构:二叉树基础与链式存储实现
  • 数据结构:二叉树基础与 C 语言实现
  • iRobotCAM 机器人离线编程软件在激光加工中的应用优势
  • OpenClaw 搭建 QQ AI 办公机器人:关键词触发与邮件集成
  • 双足机器人 2-RSS-1U 并联踝关节设计与运动学解析
  • 深度学习模型优化策略与实战调参
  • MinHash 大规模文本近似去重策略详解
  • STL 容器适配器 stack 与 queue 底层模拟及算法实战

相关免费在线工具

  • 加密/解密文本

    使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online

  • RSA密钥对生成器

    生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online

  • Mermaid 预览与可视化编辑

    基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online

  • 随机西班牙地址生成器

    随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online

  • Gemini 图片去水印

    基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online

  • curl 转代码

    解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online