BERT 文本分类实战:代码逐行注释与原理详解
本文详细介绍了如何使用 BERT 模型进行文本分类任务。内容涵盖文本分类概念、BERT 原理、环境配置、数据预处理、模型构建、训练循环设计及推理预测全流程。基于 PyTorch 和 Hugging Face Transformers 库,以 IMDB 电影评论数据集为例,展示了从数据加载到模型评估的完整代码实现,并对关键参数和超调优策略进行了说明,旨在帮助开发者快速掌握 BERT 微调技术。

本文详细介绍了如何使用 BERT 模型进行文本分类任务。内容涵盖文本分类概念、BERT 原理、环境配置、数据预处理、模型构建、训练循环设计及推理预测全流程。基于 PyTorch 和 Hugging Face Transformers 库,以 IMDB 电影评论数据集为例,展示了从数据加载到模型评估的完整代码实现,并对关键参数和超调优策略进行了说明,旨在帮助开发者快速掌握 BERT 微调技术。

文本分类是自然语言处理(NLP)和机器学习领域的基础任务之一。其核心目标是教会计算机根据文本内容将其划分到预定义的类别中。
这通常被视为一种监督学习任务,意味着我们需要使用已经标注好类别的文本数据对算法进行训练。在模型训练完成后,它便能根据所学到的规律对新的、未标注的文本进行分类。算法通过分析单词、短语及其上下文中的特征来决定文本归属,类似于人类通过观察花的特征来辨别花种。
常见应用场景包括:
BERT(Bidirectional Encoder Representations from Transformers,基于双向编码器表示的变换器)是由 Google 开发的一种强大的 NLP 模型。它基于 Transformer 的深度神经网络架构,能够生成高质量的文本表示。
BERT 的核心优势:
本文将以计算效率与性能平衡较好的 bert-base-uncased 版本为例,在 IMDB 电影评论数据集上实现情感分类任务。
在开始之前,请确保已安装以下 Python 库:
torch: PyTorch 深度学习框架transformers: Hugging Face 提供的 BERT 模型库scikit-learn: 数据处理与评估工具pandas: 数据处理与分析import os
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import pandas as pd
IMDB 数据集包含 50,000 条评论,其中一半为正面评价,一半为负面评价。我们将首先加载数据并将其转换为数值标签。
def load_imdb_data(data_file):
"""
加载 IMDB 数据集并转换为文本列表和标签列表
:param data_file: CSV 文件路径
:return: texts (列表), labels (列表)
"""
df = pd.read_csv(data_file) # 读取 CSV 文件
texts = df['review'].tolist() # 提取评论文本
# 将情感转化为数值:positive -> 1, negative -> 0
labels = [1 if sentiment == "positive" else 0 for sentiment in df['sentiment'].tolist()]
return texts, labels
# 指定数据集路径并加载数据
# 注意:实际使用时请确保 'IMDB Dataset.csv' 存在于当前目录
try:
data_file = "IMDB Dataset.csv"
texts, labels = load_imdb_data(data_file)
except FileNotFoundError:
print("未找到数据集文件,请检查路径。")
texts, labels = [], []
PyTorch 的 Dataset 类允许我们自定义数据加载逻辑。我们需要将文本编码为 BERT 可接受的输入格式(Input IDs 和 Attention Mask)。
class TextClassificationDataset(Dataset):
def __init__(self, texts, labels, tokenizer, max_length):
self.texts = texts
self.labels = labels
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
text = self.texts[idx]
label = self.labels[idx]
# 分词与编码
encoding = self.tokenizer(
text,
return_tensors='pt',
max_length=self.max_length,
padding='max_length', # 填充至最大长度
truncation=True # 超过长度则截断
)
return {
'input_ids': encoding['input_ids'].squeeze(0), # 移除 batch 维度
'attention_mask': encoding['attention_mask'].squeeze(0),
'label': torch.tensor(label)
}
我们需要在预训练的 BERT 模型之上添加一个分类头(Classifier Head)。BERT 的输出经过池化层后,接入全连接层进行二分类。
class BERTClassifier(nn.Module):
def __init__(self, bert_model_name, num_classes):
super(BERTClassifier, self).__init__()
# 加载预训练 BERT 模型
self.bert = BertModel.from_pretrained(bert_model_name)
# Dropout 层防止过拟合
self.dropout = nn.Dropout(0.1)
# 全连接层:从 BERT 隐藏层大小映射到类别数
self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)
def forward(self, input_ids, attention_mask):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
pooled_output = outputs.pooler_output # 获取 [CLS] 标记的输出作为句子表示
x = self.dropout(pooled_output)
logits = self.fc(x) # 输出分类得分
return logits
训练过程包括前向传播、损失计算、反向传播和优化器更新。
def train(model, data_loader, optimizer, scheduler, device):
model.train()
total_loss = 0
for batch in data_loader:
optimizer.zero_grad() # 清除上一次的梯度
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['label'].to(device)
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
loss = nn.CrossEntropyLoss()(outputs, labels) # 交叉熵损失
loss.backward() # 反向传播
optimizer.step() # 更新参数
scheduler.step() # 调整学习率
total_loss += loss.item()
return total_loss / len(data_loader)
评估阶段禁用梯度计算以提升效率,并计算准确率及分类报告。
def evaluate(model, data_loader, device):
model.eval()
all_labels = []
all_preds = []
with torch.no_grad(): # 禁用梯度计算
for batch in data_loader:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['label'].to(device)
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
preds = torch.argmax(outputs, dim=1) # 获取预测标签
all_labels.extend(labels.cpu().numpy())
all_preds.extend(preds.cpu().numpy())
acc = accuracy_score(all_labels, all_preds)
report = classification_report(
all_labels, all_preds,
target_names=["负面", "正面"],
digits=4
)
return acc, report
合理的超参数选择对模型收敛至关重要。
# 定义参数
bert_model_name = "bert-base-uncased"
max_length = 128 # 文本最大长度,BERT 默认支持 512,但 128 通常足够且更快
batch_size = 16 # 批次大小,取决于 GPU 显存
num_epochs = 4 # 训练轮数
learning_rate = 2e-5 # 学习率,BERT 微调常用值
# 设备选择
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备:{device}")
# 初始化模型
model = BERTClassifier(bert_model_name=bert_model_name, num_classes=2)
model.to(device)
# 优化器和学习率调度器
optimizer = AdamW(model.parameters(), lr=learning_rate, eps=1e-8)
total_steps = len(train_loader) * num_epochs
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=0,
num_training_steps=total_steps
)
将数据划分为训练集和验证集,并使用 DataLoader 进行批量加载。
# 划分训练集和验证集 (80% 训练,20% 验证)
train_texts, val_texts, train_labels, val_labels = train_test_split(
texts, labels, test_size=0.2, random_state=42
)
# 初始化分词器
tokenizer = BertTokenizer.from_pretrained(bert_model_name)
# 创建自定义数据集
train_dataset = TextClassificationDataset(train_texts, train_labels, tokenizer, max_length)
val_dataset = TextClassificationDataset(val_texts, val_labels, tokenizer, max_length)
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
执行多轮迭代训练,并在每个 Epoch 结束后验证模型性能。
for epoch in range(num_epochs):
print(f"\n训练第 {epoch+1}/{num_epochs} 轮...")
train(model, train_loader, optimizer, scheduler, device)
acc, report = evaluate(model, val_loader, device)
print(f"验证集准确率:{acc:.4f}")
print(f"分类报告:\n{report}")
# 可选:保存最佳模型
# torch.save(model.state_dict(), f"best_model_epoch_{epoch}.pth")
训练完成后,可以使用模型对新数据进行情感预测。
def predict(model, text, tokenizer, max_length, device):
model.eval()
# 编码输入文本
encoding = tokenizer(
text,
return_tensors='pt',
max_length=max_length,
padding='max_length',
truncation=True
)
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
# 获取预测
with torch.no_grad():
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
pred = torch.argmax(outputs, dim=1).cpu().numpy()[0]
return "正面" if pred == 1 else "负面"
# 示例预测
example_text = "The movie was absolutely amazing, I loved it!"
prediction = predict(model, example_text, tokenizer, max_length, device)
print(f"评论:{example_text}")
print(f"预测结果:{prediction}")
通过上述步骤,我们成功利用 BERT 实现了高效的文本分类器,并在 IMDB 数据集上完成了情感分析任务。BERT 的强大语言理解能力使得我们无需从零开始训练,大大降低了实现难度。
后续优化方向:
state_dict 以便后续部署。本教程提供了完整的代码框架与原理讲解,可作为 NLP 入门及 BERT 应用的基础参考。

微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online