跳到主要内容
BERT 文本分类实战:代码逐行注释与原理详解 | 极客日志
Python AI 算法
BERT 文本分类实战:代码逐行注释与原理详解 综述由AI生成 如何使用 BERT 模型进行文本分类任务。内容涵盖文本分类概念、BERT 原理、环境配置、数据预处理、模型构建、训练循环设计及推理预测全流程。基于 PyTorch 和 Hugging Face Transformers 库,以 IMDB 电影评论数据集为例,展示了从数据加载到模型评估的完整代码实现,并对关键参数和超调优策略进行了说明,旨在帮助开发者快速掌握 BERT 微调技术。
修罗 发布于 2025/2/7 更新于 2026/6/3 18 浏览BERT 文本分类实战:基于 PyTorch 的 IMDB 情感分析详解
1. 引言
1.1 什么是文本分类?
文本分类是自然语言处理(NLP)和机器学习领域的基础任务之一。其核心目标是教会计算机根据文本内容将其划分到预定义的类别中。
这通常被视为一种监督学习任务,意味着我们需要使用已经标注好类别的文本数据对算法进行训练。在模型训练完成后,它便能根据所学到的规律对新的、未标注的文本进行分类。算法通过分析单词、短语及其上下文中的特征来决定文本归属,类似于人类通过观察花的特征来辨别花种。
常见应用场景包括:
垃圾邮件过滤: 将电子邮件自动分为垃圾邮件或非垃圾邮件,基于特定关键词模式(如'中奖'、'点击链接')。
情感分析: 分析社交媒体帖子的情感倾向,检测仇恨言论或负面情绪。
新闻分类: 将新闻或视频归类为科技、体育、娱乐等主题,帮助用户快速筛选感兴趣的内容。
客服意图识别: 自动识别用户咨询的意图,如退款、查询物流或技术支持。
1.2 什么是 BERT?
BERT(Bidirectional Encoder Representations from Transformers,基于双向编码器表示的变换器)是由 Google 开发的一种强大的 NLP 模型。它基于 Transformer 的深度神经网络架构,能够生成高质量的文本表示。
BERT 的核心优势:
双向上下文理解: 与传统 NLP 模型按顺序处理文本不同,BERT 能够一次性处理整个输入序列,同时捕捉左侧和右侧的上下文信息。
预训练机制: BERT 已经在大规模语料库(如书籍、文章和网页)上进行了预训练,掌握了丰富的语言结构和语义知识。
微调适应性: 通过微调(Fine-tuning),BERT 可以快速适应特定的下游任务,如文本分类、问答系统等,无需从零开始训练。
本文将以计算效率与性能平衡较好的 bert-base-uncased 版本为例,在 IMDB 电影评论数据集上实现情感分类任务。
2. 环境准备与依赖导入
在开始之前,请确保已安装以下 Python 库:
torch: PyTorch 深度学习框架
transformers: Hugging Face 提供的 BERT 模型库
scikit-learn: 数据处理与评估工具
pandas: 数据处理与分析
import os
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import pandas pd
as
3. 数据加载与预处理 IMDB 数据集包含 50,000 条评论,其中一半为正面评价,一半为负面评价。我们将首先加载数据并将其转换为数值标签。
3.1 加载函数 def load_imdb_data (data_file ):
"""
加载 IMDB 数据集并转换为文本列表和标签列表
:param data_file: CSV 文件路径
:return: texts (列表), labels (列表)
"""
df = pd.read_csv(data_file)
texts = df['review' ].tolist()
labels = [1 if sentiment == "positive" else 0 for sentiment in df['sentiment' ].tolist()]
return texts, labels
try :
data_file = "IMDB Dataset.csv"
texts, labels = load_imdb_data(data_file)
except FileNotFoundError:
print ("未找到数据集文件,请检查路径。" )
texts, labels = [], []
4. 自定义数据集类 PyTorch 的 Dataset 类允许我们自定义数据加载逻辑。我们需要将文本编码为 BERT 可接受的输入格式(Input IDs 和 Attention Mask)。
class TextClassificationDataset (Dataset ):
def __init__ (self, texts, labels, tokenizer, max_length ):
self .texts = texts
self .labels = labels
self .tokenizer = tokenizer
self .max_length = max_length
def __len__ (self ):
return len (self .texts)
def __getitem__ (self, idx ):
text = self .texts[idx]
label = self .labels[idx]
encoding = self .tokenizer(
text,
return_tensors='pt' ,
max_length=self .max_length,
padding='max_length' ,
truncation=True
)
return {
'input_ids' : encoding['input_ids' ].squeeze(0 ),
'attention_mask' : encoding['attention_mask' ].squeeze(0 ),
'label' : torch.tensor(label)
}
5. 构建 BERT 分类器 我们需要在预训练的 BERT 模型之上添加一个分类头(Classifier Head)。BERT 的输出经过池化层后,接入全连接层进行二分类。
class BERTClassifier (nn.Module):
def __init__ (self, bert_model_name, num_classes ):
super (BERTClassifier, self ).__init__()
self .bert = BertModel.from_pretrained(bert_model_name)
self .dropout = nn.Dropout(0.1 )
self .fc = nn.Linear(self .bert.config.hidden_size, num_classes)
def forward (self, input_ids, attention_mask ):
outputs = self .bert(input_ids=input_ids, attention_mask=attention_mask)
pooled_output = outputs.pooler_output
x = self .dropout(pooled_output)
logits = self .fc(x)
return logits
6. 定义训练与评估函数
6.1 训练函数 训练过程包括前向传播、损失计算、反向传播和优化器更新。
def train (model, data_loader, optimizer, scheduler, device ):
model.train()
total_loss = 0
for batch in data_loader:
optimizer.zero_grad()
input_ids = batch['input_ids' ].to(device)
attention_mask = batch['attention_mask' ].to(device)
labels = batch['label' ].to(device)
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
loss = nn.CrossEntropyLoss()(outputs, labels)
loss.backward()
optimizer.step()
scheduler.step()
total_loss += loss.item()
return total_loss / len (data_loader)
6.2 评估函数 评估阶段禁用梯度计算以提升效率,并计算准确率及分类报告。
def evaluate (model, data_loader, device ):
model.eval ()
all_labels = []
all_preds = []
with torch.no_grad():
for batch in data_loader:
input_ids = batch['input_ids' ].to(device)
attention_mask = batch['attention_mask' ].to(device)
labels = batch['label' ].to(device)
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
preds = torch.argmax(outputs, dim=1 )
all_labels.extend(labels.cpu().numpy())
all_preds.extend(preds.cpu().numpy())
acc = accuracy_score(all_labels, all_preds)
report = classification_report(
all_labels, all_preds,
target_names=["负面" , "正面" ],
digits=4
)
return acc, report
7. 超参数设置与模型初始化
bert_model_name = "bert-base-uncased"
max_length = 128
batch_size = 16
num_epochs = 4
learning_rate = 2e-5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu" )
print (f"使用设备:{device} " )
model = BERTClassifier(bert_model_name=bert_model_name, num_classes=2 )
model.to(device)
optimizer = AdamW(model.parameters(), lr=learning_rate, eps=1e-8 )
total_steps = len (train_loader) * num_epochs
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=0 ,
num_training_steps=total_steps
)
8. 数据集划分与加载器创建 将数据划分为训练集和验证集,并使用 DataLoader 进行批量加载。
train_texts, val_texts, train_labels, val_labels = train_test_split(
texts, labels, test_size=0.2 , random_state=42
)
tokenizer = BertTokenizer.from_pretrained(bert_model_name)
train_dataset = TextClassificationDataset(train_texts, train_labels, tokenizer, max_length)
val_dataset = TextClassificationDataset(val_texts, val_labels, tokenizer, max_length)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True )
val_loader = DataLoader(val_dataset, batch_size=batch_size)
9. 模型训练流程 执行多轮迭代训练,并在每个 Epoch 结束后验证模型性能。
for epoch in range (num_epochs):
print (f"\n训练第 {epoch+1 } /{num_epochs} 轮..." )
train(model, train_loader, optimizer, scheduler, device)
acc, report = evaluate(model, val_loader, device)
print (f"验证集准确率:{acc:.4 f} " )
print (f"分类报告:\n{report} " )
10. 推理与新数据预测 def predict (model, text, tokenizer, max_length, device ):
model.eval ()
encoding = tokenizer(
text,
return_tensors='pt' ,
max_length=max_length,
padding='max_length' ,
truncation=True
)
input_ids = encoding['input_ids' ].to(device)
attention_mask = encoding['attention_mask' ].to(device)
with torch.no_grad():
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
pred = torch.argmax(outputs, dim=1 ).cpu().numpy()[0 ]
return "正面" if pred == 1 else "负面"
example_text = "The movie was absolutely amazing, I loved it!"
prediction = predict(model, example_text, tokenizer, max_length, device)
print (f"评论:{example_text} " )
print (f"预测结果:{prediction} " )
11. 总结与扩展建议 通过上述步骤,我们成功利用 BERT 实现了高效的文本分类器,并在 IMDB 数据集上完成了情感分析任务。BERT 的强大语言理解能力使得我们无需从零开始训练,大大降低了实现难度。
模型保存与加载: 生产环境中应保存 state_dict 以便后续部署。
超参数调优: 尝试不同的学习率、Batch Size 和 Max Length 以获得更高精度。
数据增强: 对于小样本场景,可使用回译等技术扩充训练数据。
分布式训练: 若数据量巨大,可考虑使用多 GPU 并行训练加速。
本教程提供了完整的代码框架与原理讲解,可作为 NLP 入门及 BERT 应用的基础参考。
相关免费在线工具 加密/解密文本 使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
RSA密钥对生成器 生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
Mermaid 预览与可视化编辑 基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
随机西班牙地址生成器 随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online
Gemini 图片去水印 基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online
curl 转代码 解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online