
Python 开源 AI 模型引入与测试全流程
Python 开源 AI 模型引入与测试全流程。涵盖环境配置、BERT 模型加载、IMDB 数据集预处理、模型微调训练、性能评估指标计算、Pytest 单元测试与集成测试、FastAPI 服务实现、Docker 容器化部署及监控日志配置。提供完整可执行代码示例与工程化最佳实践,助力开发者构建生产级 AI 应用系统。


本文详细介绍了在 Python 环境中引入开源 AI 模型并进行全面测试的完整技术流程。我们将以 Hugging Face Transformers 库中的 BERT 模型为例,从环境配置、模型加载、数据处理、模型训练与微调、性能评估到部署测试,提供一套完整的可执行方案。文章包含详细的原理解析、代码实现和命令操作,帮助开发者掌握开源 AI 模型集成的最佳实践。
开源 AI 模型已成为现代人工智能应用的核心组成部分。从 Google 的 BERT 到 OpenAI 开源的 GPT-2,再到 Meta 的 Llama,开源模型推动了 AI 技术的民主化进程。Hugging Face 是目前最流行的开源 AI 模型社区之一,托管了数以十万计的预训练模型和数据集,且数量仍在快速增长。
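本文选择以下技术栈:深度学习框架 PyTorch;模型与数据集库 Hugging Face Transformers 与 Datasets;评估指标库 scikit-learn;服务框架 FastAPI;测试框架 Pytest 与 Hypothesis;容器化工具 Docker。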
本文的目标是实现一个完整的 BERT 文本分类模型引入流程,包括:环境配置、数据加载与预处理、模型微调训练、性能评估、单元测试与集成测试,以及 API 服务化与容器化部署。
# Check the system environment
python --version # Python 3.8+
nvidia-smi # GPU support (optional but recommended)
# Create the project directory
mkdir openai-introduction && cd openai-introduction
# Create a virtual environment
python -m venv venv
# Activate the virtual environment -- run ONLY the line for your OS
# Linux/Mac
source venv/bin/activate
# Windows (cmd / PowerShell)
venv\Scripts\activate
创建 requirements.txt:
# 核心 AI 库
torch>=2.0.0
transformers>=4.30.0
datasets>=2.12.0
accelerate>=0.20.0
# 数据处理
numpy>=1.24.0
pandas>=2.0.0
scikit-learn>=1.3.0
# 实验跟踪
wandb>=0.15.0
tensorboard>=2.13.0
# API 服务
fastapi>=0.100.0
uvicorn[standard]>=0.23.0
pydantic>=2.0.0
# 测试工具
pytest>=7.4.0
hypothesis>=6.82.0
pytest-benchmark>=4.0.0
# 开发工具
black>=23.0.0
flake8>=6.0.0
mypy>=1.5.0
pre-commit>=3.3.0
# 模型优化
optimum>=1.12.0
onnxruntime>=1.15.0
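# 其他工具
jupyter>=1.0.0
ipython>=8.14.0
matplotlib>=3.7.0
seaborn>=0.12.0
# 数据增强、配置加载与性能监控(DataAugmenter / TrainingConfig / CustomTrainer / PerformanceBenchmark 使用)
nlpaug>=1.1.11
pyyaml>=6.0
tqdm>=4.65.0
psutil>=5.9.0
gputil>=1.4.0
memory-profiler>=0.61.0
# FastAPI TestClient 依赖
httpx>=0.24.0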
# Install the CUDA build of PyTorch FIRST (GPU users only; cu118 = CUDA 11.8 wheels).
# If requirements.txt were installed first, pip would consider the already-installed
# torch as satisfying this request and would not replace it with the cu118 wheel.
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# Then install the remaining dependencies (torch>=2.0.0 is now already satisfied)
pip install -r requirements.txt
创建项目目录结构:
openai-introduction/
├── src/
│ ├── __init__.py
│ ├── data/
│ │ ├── __init__.py
│ │ ├── processor.py
│ │ └── dataset.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── bert_classifier.py
│ │ └── model_utils.py
│ ├── training/
│ │ ├── __init__.py
│ │ ├── trainer.py
│ │ └── optimizer.py
│ ├── evaluation/
│ │ ├── __init__.py
│ │ ├── metrics.py
│ │ └── visualization.py
│ └── api/
│ ├── __init__.py
│ ├── app.py
│ └── schemas.py
├── tests/
│ ├── __init__.py
│ ├── test_data.py
│ ├── test_model.py
│ ├── test_training.py
│ └── test_api.py
├── notebooks/
│ ├── 01_exploratory_analysis.ipynb
│ └── 02_model_experiments.ipynb
├── configs/
│ ├── base_config.yaml
│ └── train_config.yaml
├── scripts/
│ ├── train.py
│ ├── evaluate.py
│ └── deploy.py
├── .pre-commit-config.yaml
├── Dockerfile
├── docker-compose.yml
├── pyproject.toml
├── README.md
└── requirements.txt
BERT(Bidirectional Encoder Representations from Transformers)是基于 Transformer 编码器的预训练语言模型。其核心创新在于双向上下文理解,通过 Masked Language Model(MLM)和 Next Sentence Prediction(NSP)任务进行预训练。下面先用 PyTorch 从零实现多头注意力与 Transformer 编码器层以说明其原理;后续的实际微调直接使用 Transformers 库提供的 BertModel。
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention (from-scratch illustration).

    Args:
        embed_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention probabilities.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        # 1/sqrt(head_dim) keeps dot-product magnitudes independent of head size.
        self.scaling = self.head_dim ** -0.5

    def _split_heads(self, x: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Reshape (batch, seq, embed_dim) to (batch, num_heads, seq, head_dim)."""
        return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    @staticmethod
    def _expand_mask(attention_mask: torch.Tensor) -> torch.Tensor:
        """Make ``attention_mask`` broadcastable against (batch, heads, q_len, k_len).

        2-D masks are read as (batch, k_len) padding masks (the Hugging Face
        tokenizer convention), 3-D masks as (batch, q_len, k_len); 4-D masks are
        used as given. Pass a 4-D mask for other patterns (e.g. causal masks).
        """
        if attention_mask.dim() == 2:
            return attention_mask[:, None, None, :]
        if attention_mask.dim() == 3:
            return attention_mask[:, None, :, :]
        return attention_mask

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend from ``query`` over ``key``/``value``.

        Positions where the (expanded) mask equals 0 receive zero attention.

        Returns:
            ``(attn_output, attn_probs)`` where ``attn_output`` is
            (batch, q_len, embed_dim) and ``attn_probs`` is
            (batch, num_heads, q_len, k_len).
        """
        batch_size = query.size(0)
        # Project and split into heads
        q = self._split_heads(self.q_proj(query), batch_size)
        k = self._split_heads(self.k_proj(key), batch_size)
        v = self._split_heads(self.v_proj(value), batch_size)
        # Scaled attention scores
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scaling
        if attention_mask is not None:
            mask = self._expand_mask(attention_mask)
            # Use the dtype's most negative finite value: -1e9 is outside the
            # float16 range and masked_fill rejects it under mixed precision.
            attn_scores = attn_scores.masked_fill(
                mask == 0, torch.finfo(attn_scores.dtype).min
            )
        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_probs = self.dropout(attn_probs)
        # Weighted sum of values, then merge heads back to embed_dim
        attn_output = torch.matmul(attn_probs, v)
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.embed_dim
        )
        attn_output = self.out_proj(attn_output)
        return attn_output, attn_probs
class TransformerEncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer.

    Two sub-layers, each wrapped as ``LayerNorm(x + sublayer(x))``:
    multi-head self-attention (with dropout on its output) followed by a
    GELU feed-forward network of width ``ff_dim``.
    """

    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.attn_norm = nn.LayerNorm(embed_dim)
        feed_forward_layers = [
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*feed_forward_layers)
        self.ffn_norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        """Return ``(hidden_states, attention_weights)`` for input ``x``."""
        attended, weights = self.self_attn(x, x, x, attention_mask)
        hidden = self.attn_norm(x + self.dropout(attended))
        hidden = self.ffn_norm(hidden + self.ffn(hidden))
        return hidden, weights
Hugging Face Transformers 库提供了统一的 API 接口,支持多种预训练模型。其核心设计模式基于 PreTrainedModel 和 PreTrainedTokenizer 基类。
from transformers import (
BertConfig,
BertModel,
BertTokenizer,
PreTrainedModel,
PretrainedConfig
)
from transformers.modeling_outputs import SequenceClassifierOutput
import torch.nn as nn
class BertForSequenceClassification(PreTrainedModel):
    """BERT encoder with a two-layer MLP classification head on the pooled output.

    ``config_class`` and ``base_model_prefix`` are what ``PreTrainedModel``
    uses to build the config in ``from_pretrained`` and to map checkpoint
    ``bert.*`` weights onto ``self.bert``; without them this class cannot be
    loaded from a pretrained checkpoint.
    """

    config_class = BertConfig
    base_model_prefix = "bert"

    def __init__(self, config: BertConfig):
        super().__init__(config)  # also stores `self.config`
        self.num_labels = config.num_labels
        # BERT encoder (randomly initialized here; from_pretrained loads weights)
        self.bert = BertModel(config)
        # Classification head
        self.classifier = nn.Sequential(
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, config.num_labels)
        )
        # Initialize weights / apply final processing
        self.post_init()

    def _init_weights(self, module):
        """Initialize weights the same way the reference BERT implementation does."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs
    ) -> SequenceClassifierOutput:
        """Classify sequences; returns logits and, when ``labels`` is given, a loss.

        With ``num_labels == 1`` the head is treated as regression (MSE loss);
        otherwise cross-entropy over ``num_labels`` classes is used.
        """
        # This method reads attributes of the encoder output, so always request
        # a ModelOutput object regardless of a caller-supplied `return_dict`.
        kwargs.pop("return_dict", None)
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
            **kwargs
        )
        # Pooled [CLS] representation
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # Cross-entropy over a single class is identically zero; use MSE.
                loss = nn.MSELoss()(logits.squeeze(-1), labels.float())
            else:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions
        )
我们使用 IMDB 电影评论数据集进行情感分类任务。
from datasets import load_dataset, DatasetDict
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import torch
class DataProcessor:
    """Load, tokenize and batch the IMDB sentiment dataset for BERT fine-tuning."""

    def __init__(self, model_name: str = "bert-base-uncased", max_length: int = 512):
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.max_length = max_length

    def load_imdb_dataset(self, cache_dir: str = "./data"):
        """Load IMDB and carve a 10% validation split (seed 42) out of "train".

        Returns:
            A ``DatasetDict`` with "train", "validation" and "test" splits.
        """
        dataset = load_dataset("imdb", cache_dir=cache_dir)
        # Named `split` so it does not shadow sklearn's `train_test_split`.
        split = dataset["train"].train_test_split(test_size=0.1, seed=42)
        return DatasetDict({
            "train": split["train"],
            "validation": split["test"],
            "test": dataset["test"]
        })

    def preprocess_function(self, examples):
        """Tokenize a batch (``examples["text"]`` is a list of strings).

        Every sequence is truncated/padded to ``self.max_length``. Returns plain
        Python lists, as expected by ``Dataset.map(batched=True)``.
        """
        # Without `return_tensors` the tokenizer already returns lists, so there
        # is no need to build tensors and convert them back with `.tolist()`.
        tokenized_inputs = self.tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=self.max_length
        )
        return {
            "input_ids": tokenized_inputs["input_ids"],
            "attention_mask": tokenized_inputs["attention_mask"],
            "labels": examples["label"]
        }

    def prepare_dataset(self, dataset_dict, batch_size: int = 32):
        """Tokenize all splits and wrap them in DataLoaders.

        Only the training loader is shuffled.

        Returns:
            ``(train_dataloader, val_dataloader, test_dataloader)``.
        """
        tokenized_datasets = dataset_dict.map(
            self.preprocess_function,
            batched=True,
            remove_columns=["text", "label"]
        )
        tokenized_datasets.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

        train_dataloader = DataLoader(
            tokenized_datasets["train"],
            shuffle=True,
            batch_size=batch_size,
            collate_fn=self.collate_fn
        )
        val_dataloader = DataLoader(
            tokenized_datasets["validation"],
            batch_size=batch_size,
            collate_fn=self.collate_fn
        )
        test_dataloader = DataLoader(
            tokenized_datasets["test"],
            batch_size=batch_size,
            collate_fn=self.collate_fn
        )
        return train_dataloader, val_dataloader, test_dataloader

    def collate_fn(self, batch):
        """Stack per-example tensors into a dict of batched tensors."""
        input_ids = torch.stack([item["input_ids"] for item in batch])
        attention_mask = torch.stack([item["attention_mask"] for item in batch])
        labels = torch.tensor([item["labels"] for item in batch])
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

    def analyze_dataset(self, dataset_dict):
        """Summarize each split: sample count, label balance, mean word count.

        Labels are assumed binary (1 = positive). An empty split reports an
        average length of 0.0 instead of raising ZeroDivisionError.

        Returns:
            A DataFrame indexed by split name.
        """
        stats = {}
        for split in ["train", "validation", "test"]:
            dataset = dataset_dict[split]
            num_samples = len(dataset)
            # Materialize the label column once instead of twice.
            num_positive = sum(dataset["label"])
            total_words = sum(len(text.split()) for text in dataset["text"])
            stats[split] = {
                "samples": num_samples,
                "positive": num_positive,
                "negative": num_samples - num_positive,
                "avg_length": total_words / num_samples if num_samples else 0.0
            }
        return pd.DataFrame(stats).T
import nlpaug.augmenter.word as naw
from typing import List
class DataAugmenter:
    """Text augmentation built on nlpaug word-level augmenters.

    Args:
        aug_method: "synonym" (WordNet), "contextual" (BERT substitution) or
            "back_translation" (en->de->en with WMT19 models).

    Raises:
        ValueError: For any other ``aug_method``.
    """

    def __init__(self, aug_method: str = "synonym"):
        self.aug_method = aug_method
        if aug_method == "synonym":
            self.augmenter = naw.SynonymAug(aug_src="wordnet")
        elif aug_method == "contextual":
            self.augmenter = naw.ContextualWordEmbsAug(
                model_path='bert-base-uncased',
                action="substitute"
            )
        elif aug_method == "back_translation":
            self.augmenter = naw.BackTranslationAug(
                from_model_name='facebook/wmt19-en-de',
                to_model_name='facebook/wmt19-de-en'
            )
        else:
            raise ValueError(f"Unsupported augmentation method: {aug_method}")

    def _augment_once(self, text: str) -> str:
        """Return exactly one augmented string for ``text``.

        nlpaug >= 1.1.11 returns a list from ``augment()``, older versions a
        plain string; normalize both so callers always get ``str``.
        """
        result = self.augmenter.augment(text)
        if isinstance(result, list):
            # Fall back to the original text if the augmenter produced nothing.
            return result[0] if result else text
        return result

    def augment_text(self, text: str, num_aug: int = 3) -> List[str]:
        """Generate ``num_aug`` independently augmented copies of ``text``."""
        return [self._augment_once(text) for _ in range(num_aug)]

    def augment_dataset(self, dataset, num_aug_per_sample: int = 2):
        """Return original samples plus ``num_aug_per_sample`` augmentations each.

        ``dataset`` must expose "text" and "label" columns. Each original sample
        is immediately followed by its augmented copies, all with the same label.

        Returns:
            ``{"text": [...], "label": [...]}`` with plain string texts.
        """
        augmented_texts = []
        augmented_labels = []
        for text, label in zip(dataset["text"], dataset["label"]):
            # Original sample
            augmented_texts.append(text)
            augmented_labels.append(label)
            # Augmented samples
            for _ in range(num_aug_per_sample):
                augmented_texts.append(self._augment_once(text))
                augmented_labels.append(label)
        return {"text": augmented_texts, "label": augmented_labels}
from dataclasses import dataclass
from typing import Optional, Dict, Any
import yaml
from transformers import TrainingArguments
@dataclass
class TrainingConfig:
    """Fine-tuning hyper-parameters; field names double as YAML keys."""

    # Model
    model_name: str = "bert-base-uncased"
    num_labels: int = 2
    dropout_rate: float = 0.1
    # Training loop
    batch_size: int = 32
    gradient_accumulation_steps: int = 1
    num_epochs: int = 3
    learning_rate: float = 2e-5
    weight_decay: float = 0.01
    warmup_steps: int = 500
    # Optimizer / scheduler names
    optimizer: str = "adamw"
    scheduler: str = "linear"
    # Logging / evaluation / checkpoint cadence, in training steps
    logging_steps: int = 100
    eval_steps: int = 500
    save_steps: int = 1000
    # Hardware; the `device` default is evaluated once, when the class is defined
    fp16: bool = True
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    @classmethod
    def from_yaml(cls, yaml_path: str):
        """Build a config from a YAML mapping of field names to values.

        An empty YAML file yields the defaults.

        Raises:
            ValueError: If the YAML document is not a mapping.
            TypeError: If it contains keys that are not fields of this class.
        """
        with open(yaml_path, 'r', encoding="utf-8") as f:
            # safe_load returns None for an empty document.
            config_dict = yaml.safe_load(f) or {}
        if not isinstance(config_dict, dict):
            raise ValueError(
                f"Expected a mapping in {yaml_path}, got {type(config_dict).__name__}"
            )
        return cls(**config_dict)

    def to_training_arguments(self, output_dir: str) -> TrainingArguments:
        """Convert to Hugging Face ``TrainingArguments``.

        The arguments evaluate and checkpoint every ``eval_steps``/``save_steps``
        and keep the best model by "accuracy" (the Trainer's ``compute_metrics``
        must report it). ``fp16`` is only forwarded on CUDA devices, because
        TrainingArguments rejects fp16 without a GPU.
        """
        import inspect

        # transformers renamed `evaluation_strategy` to `eval_strategy` (v4.41) and
        # later removed the old name; pick whichever the installed version accepts.
        accepted = inspect.signature(TrainingArguments).parameters
        strategy_key = "eval_strategy" if "eval_strategy" in accepted else "evaluation_strategy"
        use_fp16 = self.fp16 and str(self.device).startswith("cuda")
        return TrainingArguments(
            output_dir=output_dir,
            overwrite_output_dir=True,
            num_train_epochs=self.num_epochs,
            per_device_train_batch_size=self.batch_size,
            per_device_eval_batch_size=self.batch_size,
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            learning_rate=self.learning_rate,
            weight_decay=self.weight_decay,
            warmup_steps=self.warmup_steps,
            logging_dir=f"{output_dir}/logs",
            logging_steps=self.logging_steps,
            eval_steps=self.eval_steps,
            save_steps=self.save_steps,
            save_strategy="steps",
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            greater_is_better=True,
            fp16=use_fp16,
            report_to=["wandb"],
            run_name=f"bert-imdb-{self.model_name}",
            **{strategy_key: "steps"}
        )
import torch
from torch.utils.data import DataLoader
from transformers import Trainer, AdamW, get_linear_schedule_with_warmup
from typing import Dict, List, Optional, Tuple
import numpy as np
from tqdm.auto import tqdm
class CustomTrainer:
    """Hand-written training loop for Hugging Face-style sequence classifiers.

    Each batch must be a dict of tensors whose keys match the model's forward
    keyword arguments and include "labels". The model must return an object
    with ``.loss`` and ``.logits`` (e.g. ``SequenceClassifierOutput``).

    Honors ``TrainingConfig.gradient_accumulation_steps``. On CUDA devices it
    also honors ``TrainingConfig.fp16`` (float16 autocast plus gradient scaling).
    """

    def __init__(
        self,
        model,
        train_config: TrainingConfig,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        test_dataloader: Optional[DataLoader] = None
    ):
        self.model = model
        self.config = train_config
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader

        # Device placement
        self.device = torch.device(train_config.device)
        self.model.to(self.device)

        # Micro-batches whose gradients are summed before one optimizer update.
        self.accumulation_steps = max(1, int(train_config.gradient_accumulation_steps))
        # fp16 autocast/GradScaler are CUDA-only; on CPU the flag is ignored.
        self.use_amp = bool(train_config.fp16) and self.device.type == "cuda"
        self.scaler = self._create_grad_scaler()

        # Optimizer and LR scheduler
        self.optimizer = self._create_optimizer()
        self.scheduler = self._create_scheduler()

        # Training state
        self.global_step = 0
        self.best_metric = 0.0
        self.history = {"train_loss": [], "val_loss": [], "val_accuracy": [], "learning_rate": []}

    def _create_grad_scaler(self):
        """Return a GradScaler; it is a transparent pass-through when AMP is off."""
        scaler_cls = getattr(torch.amp, "GradScaler", None)
        if scaler_cls is not None:  # torch >= 2.3
            return scaler_cls("cuda", enabled=self.use_amp)
        return torch.cuda.amp.GradScaler(enabled=self.use_amp)

    def _autocast(self):
        """Return a float16 autocast context when AMP is on, else a null context."""
        import contextlib

        if self.use_amp:
            return torch.autocast(device_type="cuda", dtype=torch.float16)
        return contextlib.nullcontext()

    def _move_to_device(self, batch):
        """Move every tensor in a dict batch to the training device."""
        return {k: v.to(self.device) for k, v in batch.items()}

    def _create_optimizer(self) -> torch.optim.Optimizer:
        """Create AdamW, excluding biases and LayerNorm weights from weight decay."""
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": self.config.weight_decay
            },
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0
            }
        ]
        if self.config.optimizer == "adamw":
            # torch.optim.AdamW replaces the deprecated (and later removed)
            # transformers.AdamW; per-group weight_decay overrides its default.
            return torch.optim.AdamW(
                optimizer_grouped_parameters,
                lr=self.config.learning_rate,
                eps=1e-8
            )
        raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

    def _create_scheduler(self):
        """Create the LR scheduler, sized by optimizer updates rather than batches."""
        num_batches = len(self.train_dataloader)
        # Ceil division: a trailing partial accumulation window still triggers an update.
        updates_per_epoch = (num_batches + self.accumulation_steps - 1) // self.accumulation_steps
        total_steps = updates_per_epoch * self.config.num_epochs
        if self.config.scheduler == "linear":
            return get_linear_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=self.config.warmup_steps,
                num_training_steps=total_steps
            )
        raise ValueError(f"Unsupported scheduler: {self.config.scheduler}")

    def _optimizer_step(self):
        """Clip, step the optimizer/scheduler, and reset gradients."""
        # Unscale first so clipping sees the true gradient magnitudes (no-op without AMP).
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.scheduler.step()
        self.optimizer.zero_grad()

    def train_epoch(self, epoch: int) -> Dict[str, float]:
        """Run one pass over the training data.

        Every ``eval_steps`` optimizer updates, the model is validated and a
        checkpoint is saved when validation accuracy improves.

        Returns:
            ``{"train_loss": mean unscaled batch loss}``.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = len(self.train_dataloader)
        progress_bar = tqdm(
            self.train_dataloader,
            desc=f"Epoch {epoch}",
            leave=False
        )
        self.optimizer.zero_grad()
        for step, batch in enumerate(progress_bar, start=1):
            batch = self._move_to_device(batch)
            with self._autocast():
                outputs = self.model(**batch)
                loss = outputs.loss
            # Divide so accumulated gradients average over the micro-batches.
            self.scaler.scale(loss / self.accumulation_steps).backward()
            batch_loss = loss.item()
            total_loss += batch_loss

            if step % self.accumulation_steps == 0 or step == num_batches:
                self._optimizer_step()
                self.global_step += 1
                self.history["learning_rate"].append(self.scheduler.get_last_lr()[0])

                # Periodic validation
                if self.global_step % self.config.eval_steps == 0:
                    val_metrics = self.evaluate()
                    self.history["val_loss"].append(val_metrics["loss"])
                    self.history["val_accuracy"].append(val_metrics["accuracy"])
                    # Keep the best checkpoint
                    if val_metrics["accuracy"] > self.best_metric:
                        self.best_metric = val_metrics["accuracy"]
                        self.save_model(f"best_model_step_{self.global_step}")
                    # evaluate() switched to eval mode
                    self.model.train()

            progress_bar.set_postfix({
                "loss": batch_loss,
                "lr": self.scheduler.get_last_lr()[0]
            })

        avg_loss = total_loss / max(num_batches, 1)
        self.history["train_loss"].append(avg_loss)
        return {"train_loss": avg_loss}

    def evaluate(self, dataloader: Optional[DataLoader] = None) -> Dict[str, float]:
        """Compute loss, accuracy, precision, recall and F1 (binary) on a loader.

        Defaults to the validation loader. All values are plain Python floats
        so they can be stored in checkpoints loaded with ``weights_only=True``.

        Raises:
            ValueError: If the dataloader yields no samples.
        """
        from sklearn.metrics import accuracy_score, precision_recall_fscore_support

        if dataloader is None:
            dataloader = self.val_dataloader
        self.model.eval()
        total_loss = 0.0
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating", leave=False):
                batch = self._move_to_device(batch)
                with self._autocast():
                    outputs = self.model(**batch)
                total_loss += outputs.loss.item()
                preds = torch.argmax(outputs.logits, dim=-1)
                all_preds.extend(preds.cpu().tolist())
                all_labels.extend(batch["labels"].cpu().tolist())

        if not all_labels:
            raise ValueError("Cannot evaluate on an empty dataloader")

        accuracy = accuracy_score(all_labels, all_preds)
        # zero_division=0 avoids warnings when a class is never predicted.
        precision, recall, f1, _ = precision_recall_fscore_support(
            all_labels, all_preds,
            average="binary",
            zero_division=0
        )
        return {
            "loss": total_loss / len(dataloader),
            "accuracy": float(accuracy),
            "precision": float(precision),
            "recall": float(recall),
            "f1": float(f1)
        }

    def train(self):
        """Train for ``num_epochs``, validating after each epoch, then test once.

        Returns:
            The accumulated ``history`` dict.
        """
        print(f"Starting training with config: {self.config}")
        print(f"Training samples: {len(self.train_dataloader.dataset)}")
        print(f"Validation samples: {len(self.val_dataloader.dataset)}")
        for epoch in range(self.config.num_epochs):
            print(f"\n{'='*50}")
            print(f"Epoch {epoch + 1}/{self.config.num_epochs}")
            print(f"{'='*50}")
            train_metrics = self.train_epoch(epoch)
            val_metrics = self.evaluate()
            print(f"\nEpoch {epoch + 1} Results:")
            print(f"Train Loss: {train_metrics['train_loss']:.4f}")
            print(f"Val Loss: {val_metrics['loss']:.4f}")
            print(f"Val Accuracy: {val_metrics['accuracy']:.4f}")
            print(f"Val F1: {val_metrics['f1']:.4f}")
        # Final test pass
        if self.test_dataloader is not None:
            test_metrics = self.evaluate(self.test_dataloader)
            print(f"\n{'='*50}")
            print("Final Test Results:")
            print(f"Test Accuracy: {test_metrics['accuracy']:.4f}")
            print(f"Test F1: {test_metrics['f1']:.4f}")
        return self.history

    def save_model(self, save_path: str):
        """Write ``{save_path}.pt`` (full training state) and ``{save_path}_hf``.

        The config is stored as a plain dict (not the dataclass object) so the
        checkpoint can be read with ``torch.load(..., weights_only=True)``.
        """
        from dataclasses import asdict, is_dataclass

        config_payload = asdict(self.config) if is_dataclass(self.config) else self.config
        torch.save({
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "scaler_state_dict": self.scaler.state_dict(),
            "config": config_payload,
            "history": self.history,
            "global_step": self.global_step,
            "best_metric": self.best_metric
        }, f"{save_path}.pt")
        # Also save in Hugging Face format
        self.model.save_pretrained(f"{save_path}_hf")

    def load_model(self, load_path: str, weights_only: bool = True):
        """Restore training state written by :meth:`save_model`.

        Args:
            load_path: Path to the ``.pt`` checkpoint.
            weights_only: Passed to ``torch.load``. Keep True for untrusted
                files; older checkpoints that pickled the TrainingConfig object
                need False and must only be loaded from trusted sources.
        """
        checkpoint = torch.load(load_path, map_location=self.device, weights_only=weights_only)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        scaler_state = checkpoint.get("scaler_state_dict")
        if scaler_state:
            self.scaler.load_state_dict(scaler_state)
        self.history = checkpoint["history"]
        self.global_step = checkpoint["global_step"]
        self.best_metric = checkpoint["best_metric"]
#!/usr/bin/env python3
""" 训练脚本 """
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import wandb
from transformers import BertForSequenceClassification
from src.data.processor import DataProcessor
from src.training.trainer import CustomTrainer, TrainingConfig
from src.models.model_utils import set_seed
def main():
    """Run the end-to-end IMDB fine-tuning pipeline.

    The steps are:
    1. Seed the RNGs and start a W&B run.
    2. Load, analyze and tokenize IMDB.
    3. Fine-tune ``config.model_name`` with ``CustomTrainer``.
    4. Save ``final_model`` and log the best validation accuracy.

    Returns:
        The history dict produced by ``CustomTrainer.train()``.
    """
    from dataclasses import asdict

    # Set the random seed
    set_seed(42)

    # Load the configuration first so W&B records the values actually used,
    # rather than a hand-copied subset that can drift from TrainingConfig.
    config = TrainingConfig()
    wandb.init(
        project="bert-imdb-classification",
        config={"dataset": "imdb", **asdict(config)}
    )

    exit_code = 1
    try:
        # Data preparation
        print("Loading and preprocessing data...")
        processor = DataProcessor(config.model_name)
        dataset_dict = processor.load_imdb_dataset()

        # Dataset statistics
        stats_df = processor.analyze_dataset(dataset_dict)
        print("\nDataset Statistics:")
        print(stats_df)

        # DataLoaders
        train_dataloader, val_dataloader, test_dataloader = processor.prepare_dataset(
            dataset_dict,
            batch_size=config.batch_size
        )

        # Model
        print(f"\nLoading model: {config.model_name}")
        model = BertForSequenceClassification.from_pretrained(
            config.model_name,
            num_labels=config.num_labels,
            hidden_dropout_prob=config.dropout_rate,
            attention_probs_dropout_prob=config.dropout_rate
        )

        trainer = CustomTrainer(
            model=model,
            train_config=config,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            test_dataloader=test_dataloader
        )

        print("\nStarting training...")
        history = trainer.train()

        trainer.save_model("final_model")

        # best_metric is the best validation accuracy seen at eval checkpoints.
        wandb.log({"final_accuracy": trainer.best_metric})
        exit_code = 0
    finally:
        # Always close the run; a non-zero exit code marks it as failed.
        wandb.finish(exit_code=exit_code)

    print("\nTraining completed successfully!")
    return history


if __name__ == "__main__":
    history = main()
import numpy as np
from sklearn.metrics import (
accuracy_score,
precision_score,
recall_score,
f1_score,
roc_auc_score,
confusion_matrix,
classification_report
)
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Any
class ModelEvaluator:
    """Batch inference, metrics, plots and calibration analysis for a binary classifier.

    ``model`` is called with the tokenizer's output and must return an object
    with ``.logits``; class index 1 is treated as "Positive".
    """

    def __init__(self, model, tokenizer, device: str = "cuda"):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.to(device)
        self.model.eval()

    def predict(self, texts: List[str], batch_size: int = 32) -> Tuple[np.ndarray, np.ndarray]:
        """Run batched inference.

        Returns:
            ``(logits, probs)`` as NumPy arrays of shape (len(texts), num_labels);
            ``probs`` is the softmax of ``logits``.
        """
        all_logits = []
        all_probs = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            inputs = self.tokenizer(
                batch_texts,
                truncation=True,
                padding=True,
                max_length=512,
                return_tensors="pt"
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)
                logits = outputs.logits
                probs = torch.softmax(logits, dim=-1)
            all_logits.append(logits.cpu().numpy())
            all_probs.append(probs.cpu().numpy())
        return np.vstack(all_logits), np.vstack(all_probs)

    def evaluate_classification(
        self,
        texts: List[str],
        labels: List[int],
        threshold: float = 0.5
    ) -> Dict[str, Any]:
        """Compute binary metrics, confusion matrix, report and confidences.

        Args:
            texts: Input texts.
            labels: Ground-truth labels (0 = Negative, 1 = Positive).
            threshold: A sample is predicted Positive when P(positive) > threshold.
                For two-class outputs at the default 0.5 this is equivalent to
                argmax (up to float rounding). Ignored when the model has more
                than two outputs (argmax is used).
        """
        _, probs = self.predict(texts)
        if probs.shape[1] == 2:
            preds = (probs[:, 1] > threshold).astype(np.int64)
        else:
            preds = np.argmax(probs, axis=1)

        metrics = {
            "accuracy": accuracy_score(labels, preds),
            "precision": precision_score(labels, preds, average="binary"),
            "recall": recall_score(labels, preds, average="binary"),
            "f1": f1_score(labels, preds, average="binary"),
            "roc_auc": roc_auc_score(labels, probs[:, 1])
        }
        cm = confusion_matrix(labels, preds)
        report = classification_report(
            labels, preds,
            target_names=["Negative", "Positive"],
            output_dict=True
        )
        # Confidence = probability of the most likely class
        confidence_scores = np.max(probs, axis=1)
        return {
            "metrics": metrics,
            "confusion_matrix": cm,
            "classification_report": report,
            "predictions": preds,
            "probabilities": probs,
            "confidence_scores": confidence_scores
        }

    def analyze_errors(self, texts: List[str], labels: List[int], preds: List[int]):
        """List misclassified samples, with text truncated to 200 characters."""
        errors = []
        for text, label, pred in zip(texts, labels, preds):
            if label != pred:
                errors.append({
                    "text": text[:200] + "..." if len(text) > 200 else text,
                    "true_label": "Positive" if label == 1 else "Negative",
                    "predicted_label": "Positive" if pred == 1 else "Negative",
                    "text_length": len(text.split())
                })
        return errors

    def plot_confusion_matrix(self, cm, save_path: str = None):
        """Plot a confusion-matrix heatmap, optionally saving it to ``save_path``."""
        plt.figure(figsize=(8, 6))
        sns.heatmap(
            cm,
            annot=True,
            fmt="d",
            cmap="Blues",
            xticklabels=["Negative", "Positive"],
            yticklabels=["Negative", "Positive"]
        )
        plt.xlabel("Predicted")
        plt.ylabel("True")
        plt.title("Confusion Matrix")
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight")
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close()

    def plot_roc_curve(self, labels: List[int], probs: np.ndarray, save_path: str = None):
        """Plot the ROC curve of P(positive), optionally saving it to ``save_path``."""
        from sklearn.metrics import roc_curve

        fpr, tpr, _ = roc_curve(labels, probs[:, 1])
        plt.figure(figsize=(8, 6))
        plt.plot(
            fpr, tpr,
            label=f"ROC Curve (AUC = {roc_auc_score(labels, probs[:, 1]):.3f})"
        )
        plt.plot([0, 1], [0, 1], "k--", label="Random")
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.title("ROC Curve")
        plt.legend()
        plt.grid(True)
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight")
        plt.show()
        plt.close()

    @staticmethod
    def _expected_calibration_error(labels, probs: np.ndarray, n_bins: int = 10) -> float:
        """Expected Calibration Error of P(positive) over ``n_bins`` equal-width bins.

        ECE = sum_b (|B_b| / N) * |mean_prob(B_b) - fraction_positive(B_b)|.
        """
        labels = np.asarray(labels)
        positive_probs = np.asarray(probs)[:, 1]
        if labels.size == 0:
            return 0.0
        bin_edges = np.linspace(0.0, 1.0, n_bins + 1)
        # np.digitize puts p == 1.0 past the last edge; clip so it lands in the top bin.
        bin_indices = np.clip(np.digitize(positive_probs, bin_edges) - 1, 0, n_bins - 1)
        ece = 0.0
        for b in range(n_bins):
            mask = bin_indices == b
            count = int(np.sum(mask))
            if count:
                bin_prob_mean = np.mean(positive_probs[mask])
                bin_accuracy = np.mean(labels[mask] == 1)
                ece += abs(bin_prob_mean - bin_accuracy) * count
        return float(ece / labels.size)

    def calibration_analysis(self, labels: List[int], probs: np.ndarray, n_bins: int = 10):
        """Plot the reliability diagram and compute the ECE.

        Returns:
            ``{"ece": float, "calibration_curve": (prob_true, prob_pred)}``.
        """
        from sklearn.calibration import calibration_curve

        # Boolean-mask indexing below needs an array; a plain list would raise TypeError.
        labels = np.asarray(labels)
        prob_true, prob_pred = calibration_curve(labels, probs[:, 1], n_bins=n_bins)
        plt.figure(figsize=(8, 6))
        plt.plot(prob_pred, prob_true, "s-", label="Model")
        plt.plot([0, 1], [0, 1], "k--", label="Perfectly Calibrated")
        plt.xlabel("Mean Predicted Probability")
        plt.ylabel("Fraction of Positives")
        plt.title("Calibration Curve")
        plt.legend()
        plt.grid(True)
        plt.show()
        plt.close()
        ece = self._expected_calibration_error(labels, probs, n_bins)
        return {"ece": ece, "calibration_curve": (prob_true, prob_pred)}
import time
from typing import Dict, List
import psutil
import GPUtil
from memory_profiler import memory_usage
import numpy as np
class PerformanceBenchmark:
    """Latency, throughput, memory and GPU-utilization benchmarks for a classifier.

    Like ``ModelEvaluator``, the model is moved to ``device`` and switched to
    eval mode, so dropout does not affect timings and inputs and weights share
    a device.
    """

    def __init__(self, model, tokenizer, device: str = "cuda"):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.to(device)
        self.model.eval()

    def _synchronize(self):
        """Wait for queued CUDA kernels so wall-clock timings include GPU work."""
        if str(self.device).startswith("cuda") and torch.cuda.is_available():
            torch.cuda.synchronize()

    def measure_inference_time(
        self,
        texts: List[str],
        batch_sizes: Optional[List[int]] = None
    ) -> Dict[int, Dict[str, float]]:
        """Time inference over ``texts`` for each batch size.

        Args:
            texts: Benchmark corpus.
            batch_sizes: Batch sizes to try; defaults to [1, 4, 8, 16, 32, 64].

        Returns:
            Per batch size: average per-batch time, throughput (samples/s) and
            total time. Empty ``texts`` yield zeros instead of NaN/inf.
        """
        if batch_sizes is None:
            batch_sizes = [1, 4, 8, 16, 32, 64]
        results = {}
        for batch_size in batch_sizes:
            print(f"\nTesting batch size: {batch_size}")
            # Warm-up (kernel selection, allocator caches)
            self.predict_batch(["This is a warmup sentence."] * batch_size)
            self._synchronize()

            times = []
            for i in range(0, len(texts), batch_size):
                batch_texts = texts[i:i + batch_size]
                start_time = time.perf_counter()
                self.predict_batch(batch_texts)
                # CUDA launches are asynchronous; without this we would only
                # time the kernel launches, not the computation.
                self._synchronize()
                times.append(time.perf_counter() - start_time)

            total_time = float(np.sum(times))
            avg_time = float(np.mean(times)) if times else 0.0
            throughput = len(texts) / total_time if total_time > 0 else 0.0
            results[batch_size] = {
                "avg_inference_time": avg_time,
                "throughput": throughput,
                "samples_per_second": throughput,
                "total_time": total_time
            }
            print(f" Average inference time: {avg_time:.4f}s")
            print(f" Throughput: {throughput:.2f} samples/s")
        return results

    def predict_batch(self, texts: List[str]):
        """Tokenize ``texts`` (max 512 tokens) and run one no-grad forward pass."""
        inputs = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=512,
            return_tensors="pt"
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs

    def measure_memory_usage(self, text_lengths: Optional[List[int]] = None):
        """Sample process memory (memory_profiler, MiB) during a 32-text batch.

        Args:
            text_lengths: Word counts to test; defaults to [50, 100, 200, 400].
                Note this measures host-process memory, not GPU memory.
        """
        if text_lengths is None:
            text_lengths = [50, 100, 200, 400]
        memory_results = {}
        for length in text_lengths:
            test_text = " ".join(["word"] * length)
            texts = [test_text] * 32  # fixed batch size
            mem_usage = memory_usage((self.predict_batch, (texts,)), interval=0.1)
            memory_results[length] = {
                "max_memory_mb": max(mem_usage),
                "avg_memory_mb": np.mean(mem_usage),
                "text_length": length
            }
        return memory_results

    def measure_gpu_utilization(self, texts: List[str], duration: int = 30):
        """Run inference for ``duration`` seconds while sampling GPUtil every 0.5 s.

        Returns:
            A list of per-GPU samples (time, memory used/total, load %, temperature).
        """
        import threading

        gpu_stats = []
        stop_event = threading.Event()

        def monitor_gpu():
            while not stop_event.is_set():
                for gpu in GPUtil.getGPUs():
                    gpu_stats.append({
                        "time": time.time(),
                        "memory_used": gpu.memoryUsed,
                        "memory_total": gpu.memoryTotal,
                        "load": gpu.load * 100,
                        "temperature": gpu.temperature
                    })
                # Like sleep(0.5), but returns immediately once stop is requested.
                stop_event.wait(0.5)

        # Daemon thread so a crash cannot leave the process hanging on exit.
        monitor_thread = threading.Thread(target=monitor_gpu, daemon=True)
        monitor_thread.start()
        try:
            start_time = time.time()
            batch_size = 32
            while time.time() - start_time < duration:
                for i in range(0, min(len(texts), 1000), batch_size):
                    self.predict_batch(texts[i:i + batch_size])
        finally:
            # Stop the monitor even if inference raised.
            stop_event.set()
            monitor_thread.join()
        return gpu_stats

    def generate_report(self, benchmark_results: Dict) -> str:
        """Format inference-time and memory results plus current system load.

        ``benchmark_results`` must contain "inference_times" and "memory_usage"
        as returned by the corresponding ``measure_*`` methods.
        """
        report = []
        report.append("=" * 60)
        report.append("PERFORMANCE BENCHMARK REPORT")
        report.append("=" * 60)
        # Inference times
        report.append("\n1. INFERENCE TIMES")
        report.append("-" * 40)
        for batch_size, metrics in benchmark_results["inference_times"].items():
            report.append(
                f"Batch Size {batch_size:3d}: "
                f"{metrics['avg_inference_time']:.4f}s avg, "
                f"{metrics['throughput']:.2f} samples/s"
            )
        # Memory usage
        report.append("\n2. MEMORY USAGE")
        report.append("-" * 40)
        for length, metrics in benchmark_results["memory_usage"].items():
            report.append(
                f"Text Length {length:3d}: "
                f"{metrics['max_memory_mb']:.1f}MB max, "
                f"{metrics['avg_memory_mb']:.1f}MB avg"
            )
        # System resources (cpu_percent blocks for 1 s to sample)
        report.append("\n3. SYSTEM RESOURCES")
        report.append("-" * 40)
        cpu_percent = psutil.cpu_percent(interval=1)
        memory = psutil.virtual_memory()
        report.append(f"CPU Usage: {cpu_percent:.1f}%")
        report.append(f"Memory Usage: {memory.percent:.1f}%")
        report.append(f"Available Memory: {memory.available / 1024**3:.1f} GB")
        return "\n".join(report)
import tempfile

import numpy as np
import pytest
import torch
from datasets import Dataset, DatasetDict
from hypothesis import given, settings, strategies as st
from hypothesis.extra.numpy import arrays
from torch.utils.data import TensorDataset, DataLoader
from transformers import BertForSequenceClassification

from src.data.processor import DataProcessor
from src.training.trainer import CustomTrainer, TrainingConfig
class TestDataProcessor:
    """Unit tests for DataProcessor tokenization, preprocessing and batching."""

    def setup_method(self):
        self.processor = DataProcessor("bert-base-uncased")

    def test_tokenization(self):
        """Tokenizer pads/truncates to the requested max_length."""
        text = "This is a test sentence."
        tokenized = self.processor.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=128
        )
        assert "input_ids" in tokenized
        assert "attention_mask" in tokenized
        assert len(tokenized["input_ids"]) == 128

    # Tokenizer calls can exceed Hypothesis' default 200 ms deadline on a cold
    # cache, which would make this test flaky; also cap the example count.
    @settings(max_examples=25, deadline=None)
    @given(
        st.text(min_size=1, max_size=1000),
        st.integers(min_value=0, max_value=1)
    )
    def test_preprocess_function(self, text, label):
        """Property test: any text yields one fixed-length example with its label."""
        examples = {"text": [text], "label": [label]}
        result = self.processor.preprocess_function(examples)
        assert "input_ids" in result
        assert "attention_mask" in result
        assert "labels" in result
        assert len(result["labels"]) == 1
        assert result["labels"][0] == label
        assert len(result["input_ids"][0]) == self.processor.max_length

    def test_dataset_loading(self):
        """prepare_dataset turns a tiny in-memory DatasetDict into dict batches.

        Uses in-memory data instead of downloading IMDB, so the test is fast
        and works offline once the tokenizer is cached.
        """
        processor = DataProcessor("bert-base-uncased", max_length=16)
        toy = Dataset.from_dict({
            "text": ["great film", "awful plot", "just fine"],
            "label": [1, 0, 1]
        })
        dataset_dict = DatasetDict({"train": toy, "validation": toy, "test": toy})
        train_loader, val_loader, test_loader = processor.prepare_dataset(
            dataset_dict, batch_size=2
        )
        # The validation loader is not shuffled, so the first batch is deterministic.
        batch = next(iter(val_loader))
        assert batch["input_ids"].shape == (2, 16)
        assert batch["attention_mask"].shape == (2, 16)
        assert batch["labels"].tolist() == [1, 0]
        assert len(train_loader.dataset) == 3
        assert len(test_loader) == 2
class TestModel:
    """Tests for the pretrained BERT sequence classifier."""

    def setup_method(self):
        self.model = BertForSequenceClassification.from_pretrained(
            "bert-base-uncased",
            num_labels=2
        )
        self.model.eval()

    def test_model_forward(self):
        """Forward pass yields (batch, num_labels) logits without autograd history."""
        batch_size = 2
        seq_length = 128
        input_ids = torch.randint(0, 1000, (batch_size, seq_length))
        attention_mask = torch.ones((batch_size, seq_length))
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
        assert outputs.logits.shape == (batch_size, 2)
        assert not outputs.logits.requires_grad

    def test_model_save_load(self, tmp_path):
        """save_pretrained/from_pretrained round-trips every parameter."""
        save_path = tmp_path / "test_model"
        self.model.save_pretrained(save_path)
        loaded_model = BertForSequenceClassification.from_pretrained(save_path)

        original_params = dict(self.model.named_parameters())
        loaded_params = dict(loaded_model.named_parameters())
        # Comparing key sets catches missing/extra parameters, which zip() would hide.
        assert original_params.keys() == loaded_params.keys()
        for name, param in original_params.items():
            assert torch.allclose(param, loaded_params[name])

    # Each example is a full BERT forward pass: keep the count small and
    # disable the 200 ms deadline, which CPU inference routinely exceeds.
    @settings(max_examples=10, deadline=None)
    @given(
        arrays(
            dtype=np.int64,
            shape=(2, 128),
            elements=st.integers(min_value=0, max_value=1000)
        ),
        arrays(
            dtype=np.int64,
            shape=(2, 128),
            elements=st.integers(min_value=0, max_value=1)
        )
    )
    def test_batch_processing(self, input_ids, attention_mask):
        """Arbitrary token ids and masks (including all-zero masks) keep the batch size."""
        input_ids = torch.from_numpy(input_ids)
        attention_mask = torch.from_numpy(attention_mask)
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
        assert outputs.logits.shape[0] == input_ids.shape[0]
class TestTraining:
    """Smoke tests for CustomTrainer and TrainingConfig."""

    @staticmethod
    def _dict_collate(samples):
        """Collate TensorDataset tuples into the dict batches CustomTrainer expects.

        CustomTrainer calls ``batch.items()``, so the default tuple-style
        collation of a TensorDataset would fail.
        """
        input_ids, attention_mask, labels = zip(*samples)
        return {
            "input_ids": torch.stack(input_ids),
            "attention_mask": torch.stack(attention_mask),
            "labels": torch.stack(labels)
        }

    def test_training_step(self):
        """One epoch on a tiny random dataset runs and does not blow up the loss."""
        import math

        torch.manual_seed(0)
        config = TrainingConfig(
            batch_size=2,
            num_epochs=1,
            learning_rate=1e-5
        )
        # Synthetic data: input_ids, attention_mask, labels
        train_dataset = TensorDataset(
            torch.randint(0, 1000, (10, 128)),
            torch.ones((10, 128)),
            torch.randint(0, 2, (10,))
        )
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            collate_fn=self._dict_collate
        )
        val_dataloader = DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            collate_fn=self._dict_collate
        )
        model = BertForSequenceClassification.from_pretrained(
            "bert-base-uncased",
            num_labels=2
        )
        trainer = CustomTrainer(
            model=model,
            train_config=config,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader
        )

        initial_loss = trainer.evaluate()["loss"]
        trainer.train_epoch(0)
        final_loss = trainer.evaluate()["loss"]

        assert math.isfinite(final_loss)
        # The loss should go down, or at least stay (almost) unchanged.
        assert final_loss < initial_loss or math.isclose(
            final_loss, initial_loss, rel_tol=1e-3
        )

    def test_gradient_accumulation(self):
        """TrainingConfig stores the accumulation setting (config-level check only)."""
        config = TrainingConfig(
            batch_size=2,
            gradient_accumulation_steps=2,
            learning_rate=1e-5
        )
        assert config.gradient_accumulation_steps == 2


if __name__ == "__main__":
    pytest.main([__file__, "-v", "--tb=short"])
import asyncio
from fastapi.testclient import TestClient
import json
class TestAPI:
    """Integration tests for the FastAPI service in ``src.api.app``."""

    def setup_method(self):
        from src.api.app import app
        self.client = TestClient(app)

    def test_health_endpoint(self):
        """GET /health reports a healthy status."""
        response = self.client.get("/health")
        assert response.status_code == 200
        assert response.json() == {"status": "healthy"}

    def test_predict_endpoint(self):
        """POST /predict returns a prediction, label and a confidence in [0, 1]."""
        test_data = {
            "text": "This movie was absolutely fantastic!",
            "model_version": "latest"
        }
        response = self.client.post("/predict", json=test_data)
        assert response.status_code == 200
        result = response.json()
        assert "prediction" in result
        assert "confidence" in result
        assert "label" in result
        assert 0 <= result["confidence"] <= 1

    def test_batch_predict_endpoint(self):
        """POST /predict/batch returns one result per input text."""
        test_data = {
            "texts": [
                "I loved this movie!",
                "It was terrible.",
                "The acting was amazing."
            ]
        }
        response = self.client.post("/predict/batch", json=test_data)
        assert response.status_code == 200
        results = response.json()
        assert len(results) == 3
        for result in results:
            assert "prediction" in result
            assert "confidence" in result

    def test_invalid_input(self):
        """Empty text is rejected as a client error."""
        test_data = {
            "text": "",  # empty text
        }
        response = self.client.post("/predict", json=test_data)
        # FastAPI answers request-validation failures with 422 by default; 400 is
        # accepted in case the app maps them itself. NOTE(review): pin to the
        # app's actual behavior once src/api/app.py is settled.
        assert response.status_code in (400, 422)

    def test_concurrent_requests(self):
        """Ten /predict requests issued from worker threads all succeed.

        A plain (sync) test using threads: pytest does not run ``async def``
        tests without an asyncio plugin, and awaiting the synchronous
        TestClient would not overlap the requests anyway.
        """
        from concurrent.futures import ThreadPoolExecutor

        test_data = {
            "text": "Test concurrent request",
            "model_version": "latest"
        }

        def make_request(_):
            return self.client.post("/predict", json=test_data).status_code

        with ThreadPoolExecutor(max_workers=10) as pool:
            results = list(pool.map(make_request, range(10)))
        assert all(status == 200 for status in results)

    def test_model_management(self):
        """GET /models lists available models and the active one."""
        response = self.client.get("/models")
        assert response.status_code == 200
        models = response.json()
        assert "available_models" in models
        assert "active_model" in models

    def test_metrics_endpoint(self):
        """GET /metrics exposes request counters and latency."""
        response = self.client.get("/metrics")
        assert response.status_code == 200
        metrics = response.json()
        assert "total_predictions" in metrics
        assert "average_response_time" in metrics
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, validator
from typing import List, Optional, Dict, Any
import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging
from datetime import datetime
import pickle
import hashlib
import torch
from transformers import BertForSequenceClassification, BertTokenizer
# Configure logging: root handler via basicConfig at INFO, plus this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 定义数据模型
class PredictionRequest(BaseModel):
    """Request body for a single prediction.

    ``text`` must be 1-5000 characters long (checked on the raw value) and
    contain at least one non-whitespace character. It is stored with
    surrounding whitespace removed. ``model_version`` defaults to "latest".
    """

    text: str = Field(..., min_length=1, max_length=5000)
    model_version: Optional[str] = "latest"

    @validator('text')
    def text_not_empty(cls, v):
        trimmed = v.strip()
        if trimmed:
            return trimmed
        raise ValueError('Text cannot be empty')
class BatchPredictionRequest(BaseModel):
    """Request body for a batch prediction.

    Holds 1-100 texts. Each must contain a non-whitespace character and,
    like ``PredictionRequest.text``, be at most 5000 characters long. Texts
    are stored stripped. ``model_version`` defaults to "latest".
    """

    # pydantic >= 2 (pinned in requirements.txt) constrains list sizes with
    # min_length/max_length; min_items/max_items are deprecated aliases.
    texts: List[str] = Field(..., min_length=1, max_length=100)
    model_version: Optional[str] = "latest"

    @validator('texts')
    def validate_texts(cls, v):
        if not all(text.strip() for text in v):
            raise ValueError('All texts must be non-empty')
        # Same per-text cap as the single-prediction request.
        if any(len(text) > 5000 for text in v):
            raise ValueError('Each text must be at most 5000 characters')
        return [text.strip() for text in v]
class PredictionResponse(BaseModel):
    """Response body for ``POST /predict``."""
    prediction: int  # predicted class index
    label: str  # human-readable label for ``prediction``
    confidence: float  # softmax probability of the predicted class
    model_version: str  # key of the model that produced the prediction
    request_id: str  # ID produced by generate_request_id
    processing_time: float  # seconds spent in the endpoint
class BatchPredictionResponse(BaseModel):
    """Response body for ``POST /predict/batch``."""
    # One dict per input text; the endpoint fills text_preview, prediction,
    # label, confidence and request_id.
    predictions: List[Dict[str, Any]]
    batch_id: str  # short hash derived from the batch's texts
    total_processed: int  # number of results returned
    processing_time: float  # seconds spent in the endpoint
class ModelManager:
    """Load BERT sequence-classification models and run predictions on a thread pool.

    Every sub-directory of ``model_path`` is loaded with ``from_pretrained``.
    The first one that loads (in sorted name order) becomes the active model.
    If none load, ``bert-base-uncased`` with a freshly initialised 2-label
    head is used instead. Inference runs on a 4-worker ``ThreadPoolExecutor``
    so async callers do not block the event loop.
    """

    def __init__(self, model_path: str = "./models"):
        import threading

        self.model_path = model_path
        self.models = {}          # name -> {"model", "tokenizer", "loaded_at", "stats"}
        self.active_model = None  # name used when model_version is "latest"/None
        self.model_versions = []  # names in load order
        self.executor = ThreadPoolExecutor(max_workers=4)
        # predict_sync runs concurrently on the executor's worker threads and
        # does read-modify-write updates of the per-model stats; guard them.
        self._stats_lock = threading.Lock()
        # Load models
        self.load_models()

    @staticmethod
    def _make_entry(model, tokenizer) -> Dict[str, Any]:
        """Build the registry entry stored in ``self.models`` for one model."""
        return {
            "model": model,
            "tokenizer": tokenizer,
            "loaded_at": datetime.now(),
            "stats": {
                "total_predictions": 0,
                "avg_response_time": 0
            }
        }

    def load_models(self):
        """Load every model directory under ``self.model_path``.

        Directories that fail to load are logged and skipped. Falls back to
        ``load_default_model`` when nothing could be loaded.
        """
        import os
        import glob

        # Sorted so the default active model does not depend on filesystem order.
        model_dirs = sorted(glob.glob(os.path.join(self.model_path, "*")))
        for model_dir in model_dirs:
            if not os.path.isdir(model_dir):
                continue
            try:
                model_name = os.path.basename(model_dir)
                model = BertForSequenceClassification.from_pretrained(model_dir)
                tokenizer = BertTokenizer.from_pretrained(model_dir)
                self.models[model_name] = self._make_entry(model, tokenizer)
                self.model_versions.append(model_name)
                if self.active_model is None:
                    self.active_model = model_name
                logger.info(f"Loaded model: {model_name}")
            except Exception as e:
                # Best effort: one broken directory must not prevent serving the others.
                logger.error(f"Failed to load model {model_dir}: {e}")
        if not self.models:
            logger.warning("No models found. Loading default model...")
            self.load_default_model()

    def load_default_model(self):
        """Load ``bert-base-uncased`` with a new 2-label head and make it active.

        Raises:
            Exception: whatever ``from_pretrained`` raises, re-raised after logging.
        """
        try:
            model_name = "bert-base-uncased"
            model = BertForSequenceClassification.from_pretrained(
                model_name,
                num_labels=2
            )
            tokenizer = BertTokenizer.from_pretrained(model_name)
            self.models[model_name] = self._make_entry(model, tokenizer)
            self.active_model = model_name
            self.model_versions.append(model_name)
            logger.info(f"Loaded default model: {model_name}")
        except Exception as e:
            logger.error(f"Failed to load default model: {e}")
            raise

    async def predict_async(
        self,
        text: str,
        model_version: str = "latest"
    ) -> Dict[str, Any]:
        """Run ``predict_sync`` on the thread pool and await its result."""
        # get_running_loop is the supported way to reach the current loop from
        # inside a coroutine; get_event_loop is deprecated for this use.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self.executor,
            self.predict_sync,
            text,
            model_version
        )

    def _resolve_model_key(self, model_version: Optional[str]) -> Optional[str]:
        """Map a requested version to a key of ``self.models``.

        ``"latest"`` and ``None`` (the request models allow null) both select
        the active model.
        """
        if model_version is None or model_version == "latest":
            return self.active_model
        return model_version

    def predict_sync(
        self,
        text: str,
        model_version: str = "latest"
    ) -> Dict[str, Any]:
        """Classify one text synchronously.

        Returns:
            dict with ``prediction`` (class index), ``label`` ("Positive" for
            index 1, otherwise "Negative"), ``confidence`` (softmax probability
            of the predicted class), ``model_version`` (model key actually
            used) and ``processing_time`` (seconds).

        Raises:
            ValueError: if the requested model is not loaded.
        """
        start_time = datetime.now()
        model_key = self._resolve_model_key(model_version)
        if model_key not in self.models:
            raise ValueError(f"Model {model_key} not found")
        model_info = self.models[model_key]
        model = model_info["model"]
        tokenizer = model_info["tokenizer"]
        # Tokenize (inputs longer than 512 tokens are truncated)
        inputs = tokenizer(
            text,
            truncation=True,
            padding=True,
            max_length=512,
            return_tensors="pt"
        )
        with torch.no_grad():
            outputs = model(**inputs)
        probs = torch.softmax(outputs.logits, dim=-1)
        prediction = torch.argmax(probs, dim=-1).item()
        confidence = probs[0][prediction].item()
        processing_time = (datetime.now() - start_time).total_seconds()
        with self._stats_lock:
            stats = model_info["stats"]
            stats["total_predictions"] += 1
            total_preds = stats["total_predictions"]
            # Running mean: new_avg = (old_avg * (n - 1) + x) / n
            stats["avg_response_time"] = (
                (stats["avg_response_time"] * (total_preds - 1) + processing_time) / total_preds
            )
        return {
            "prediction": prediction,
            "label": "Positive" if prediction == 1 else "Negative",
            "confidence": confidence,
            "model_version": model_key,
            "processing_time": processing_time
        }

    async def predict_batch_async(
        self,
        texts: List[str],
        model_version: str = "latest"
    ) -> List[Dict[str, Any]]:
        """Classify each text concurrently on the thread pool; results keep input order.

        Raises:
            ValueError: if the requested model is not loaded.
        """
        loop = asyncio.get_running_loop()
        futures = [
            loop.run_in_executor(
                self.executor,
                self.predict_sync,
                text,
                model_version
            )
            for text in texts
        ]
        return await asyncio.gather(*futures)

    def get_model_stats(self) -> Dict[str, Any]:
        """Return per-model load time, prediction count, mean latency and active flag."""
        return {
            model_name: {
                "loaded_at": model_info["loaded_at"].isoformat(),
                "total_predictions": model_info["stats"]["total_predictions"],
                "avg_response_time": model_info["stats"]["avg_response_time"],
                "is_active": model_name == self.active_model
            }
            for model_name, model_info in self.models.items()
        }

    def switch_active_model(self, model_version: str) -> bool:
        """Make ``model_version`` the active model; return False if it is not loaded."""
        if model_version not in self.models:
            return False
        self.active_model = model_version
        logger.info(f"Switched active model to: {model_version}")
        return True
# Create the FastAPI application
app = FastAPI(
    title="BERT Text Classification API",
    description="API for sentiment analysis using BERT model",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)
# CORS: any origin may call the API, but credentials stay disabled.
# - None of the endpoints defined here use cookies or auth.
# - With a wildcard origin plus credentials, Starlette echoes back whatever
#   origin sent a credentialed request, so any website could make
#   authenticated cross-origin calls.
# - The CORS spec also disallows "*" together with credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Module-level state shared by the endpoints
model_manager = None  # ModelManager instance; created by the startup event
request_counter = 0  # incremented by generate_request_id; reported by /metrics as total_requests
cache = {}  # prediction cache used by cache_predictions: key -> (stored_at, result)
startup_time = datetime.now()  # module import time; /metrics reports uptime relative to it
# Request ID generator
def generate_request_id(text: str) -> str:
    """Build a request ID shaped like ``<YYYYmmddHHMMSS>_<counter>_<md5(text)[:8]>``.

    Side effect: bumps the module-level ``request_counter``.
    """
    global request_counter
    request_counter = request_counter + 1
    digest = hashlib.md5(text.encode()).hexdigest()
    stamp = datetime.now().strftime("%Y%m%d%H%M%S")
    return "{}_{}_{}".format(stamp, request_counter, digest[:8])
# Prediction cache decorator
def cache_predictions(func):
    """Cache the result of an async prediction coroutine for 5 minutes.

    Two call shapes are supported:

    * plain coroutines ``f(text, model_version="latest", ...)``;
    * FastAPI endpoints whose first parameter is a request model exposing
      ``.text`` / ``.model_version``, e.g. ``predict(request, background_tasks)``.

    ``functools.wraps`` is essential here. FastAPI derives the endpoint's
    parameters from ``inspect.signature(wrapper)``, which follows
    ``__wrapped__`` to the decorated function. Without it, FastAPI would
    expose ``text`` / ``model_version`` as query parameters and call the
    endpoint with the wrong arguments.

    Entries live in the module-level ``cache`` dict (also reported by
    ``/metrics``) as ``(stored_at, result)``. When it grows past 1000
    entries, the oldest entry is evicted. A cache hit returns the stored
    result object unchanged, including any request ID inside it.
    """
    import functools

    def _key_parts(args, kwargs):
        # Request-model style: FastAPI calls the endpoint as f(request=..., ...).
        request = kwargs.get("request", args[0] if args else None)
        if hasattr(request, "text"):
            return request.text, getattr(request, "model_version", "latest")
        # Plain style: f(text, model_version="latest", ...)
        text = args[0] if args else kwargs["text"]
        if len(args) > 1:
            model_version = args[1]
        else:
            model_version = kwargs.get("model_version", "latest")
        return text, model_version

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        text, model_version = _key_parts(args, kwargs)
        cache_key = f"{model_version}:{hashlib.md5(text.encode()).hexdigest()}"
        entry = cache.get(cache_key)
        if entry is not None:
            cached_time, result = entry
            # Entries expire after 5 minutes
            if (datetime.now() - cached_time).total_seconds() < 300:
                logger.info(f"Cache hit for key: {cache_key[:20]}...")
                return result
        result = await func(*args, **kwargs)
        cache[cache_key] = (datetime.now(), result)
        # Bound the cache size by evicting the oldest entry
        if len(cache) > 1000:
            oldest_key = min(cache, key=lambda k: cache[k][0])
            del cache[oldest_key]
        return result

    return wrapper
@app.on_event("startup")
async def startup_event():
    """Create the global ModelManager when the application starts.

    The model directory comes from the ``MODEL_PATH`` environment variable
    (docker-compose sets it) and defaults to ``./models`` as before.
    """
    import os

    global model_manager
    logger.info("Starting up BERT Classification API...")
    model_path = os.environ.get("MODEL_PATH", "./models")
    model_manager = ModelManager(model_path)
    logger.info(f"Loaded {len(model_manager.models)} models")
    logger.info(f"Active model: {model_manager.active_model}")
@app.on_event("shutdown")
async def shutdown_event():
    """On application shutdown, stop the model manager's inference thread pool."""
    logger.info("Shutting down BERT Classification API...")
    if model_manager is None:
        return
    model_manager.executor.shutdown()
@app.get("/")
async def root():
    """Describe the API and point to its documentation, health and model endpoints."""
    return dict(
        message="BERT Text Classification API",
        version="1.0.0",
        docs="/docs",
        health="/health",
        models="/models",
    )
@app.get("/health")
async def health_check():
    """Liveness probe: always "healthy", plus current time and model-loading info."""
    manager = model_manager
    loaded_count = len(manager.models) if manager else 0
    active = manager.active_model if manager else None
    return {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "models_loaded": loaded_count,
        "active_model": active,
    }
@app.post("/predict", response_model=PredictionResponse)
@cache_predictions
async def predict(
    request: PredictionRequest,
    background_tasks: BackgroundTasks
):
    """Classify a single text.

    Status codes:
        400: the requested ``model_version`` is not loaded
             (``ModelManager.predict_sync`` raises ValueError); this matches
             ``/models/switch``.
        500: the model manager is not initialized, or inference fails.
        422: request validation fails (raised by FastAPI itself).

    After the response is sent, a background task logs a summary of the
    prediction.
    """
    if model_manager is None:
        raise HTTPException(status_code=500, detail="Model manager not initialized")
    start_time = datetime.now()
    request_id = generate_request_id(request.text)
    try:
        result = await model_manager.predict_async(
            request.text,
            request.model_version
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception(f"Prediction error: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {str(e)}"
        ) from e
    processing_time = (datetime.now() - start_time).total_seconds()
    # Log the prediction as a background task
    background_tasks.add_task(
        log_prediction,
        request_id=request_id,
        text_length=len(request.text),
        prediction=result["prediction"],
        confidence=result["confidence"],
        processing_time=processing_time
    )
    return PredictionResponse(
        prediction=result["prediction"],
        label=result["label"],
        confidence=result["confidence"],
        model_version=result["model_version"],
        request_id=request_id,
        processing_time=processing_time
    )
@app.post("/predict/batch", response_model=BatchPredictionResponse)
async def predict_batch(request: BatchPredictionRequest):
    """Classify up to 100 texts in one call.

    Status codes: 400 if the requested ``model_version`` is not loaded, and
    500 if the model manager is not initialized or inference fails. This is
    the same mapping as ``/predict``.
    """
    if model_manager is None:
        raise HTTPException(status_code=500, detail="Model manager not initialized")
    start_time = datetime.now()
    # Batch ID: hash of the texts. Each text is length-prefixed so distinct
    # batches cannot collide by concatenation (["ab", "c"] vs ["a", "bc"]).
    digest = hashlib.md5()
    for text in request.texts:
        encoded = text.encode()
        digest.update(len(encoded).to_bytes(8, "big"))
        digest.update(encoded)
    batch_id = digest.hexdigest()[:12]
    try:
        results = await model_manager.predict_batch_async(
            request.texts,
            request.model_version
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception(f"Batch prediction error: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Batch prediction failed: {str(e)}"
        ) from e
    processing_time = (datetime.now() - start_time).total_seconds()
    predictions = []
    for text, result in zip(request.texts, results):
        predictions.append({
            # First 100 characters, with "..." appended when truncated
            "text_preview": (text[:100] + "...") if len(text) > 100 else text,
            "prediction": result["prediction"],
            "label": result["label"],
            "confidence": result["confidence"],
            "request_id": generate_request_id(text)
        })
    return BatchPredictionResponse(
        predictions=predictions,
        batch_id=batch_id,
        total_processed=len(results),
        processing_time=processing_time
    )
@app.get("/models")
async def get_models():
    """List the loaded models, the active one and per-model statistics.

    Responds 500 until the model manager has been created.
    """
    manager = model_manager
    if not manager:
        raise HTTPException(status_code=500, detail="Model manager not initialized")
    model_stats = manager.get_model_stats()
    return dict(
        available_models=[name for name in manager.models],
        active_model=manager.active_model,
        model_stats=model_stats,
    )
@app.post("/models/switch")
async def switch_model(model_version: str):
    """Make ``model_version`` the active model.

    Responds 400 if that model is not loaded, and 500 before the model
    manager exists.
    """
    if not model_manager:
        raise HTTPException(status_code=500, detail="Model manager not initialized")
    if not model_manager.switch_active_model(model_version):
        raise HTTPException(
            status_code=400,
            detail=f"Model {model_version} not found"
        )
    return {
        "message": f"Switched active model to {model_version}",
        "active_model": model_manager.active_model
    }
@app.get("/metrics")
async def get_metrics():
    """Return aggregate service metrics as JSON.

    ``average_response_time`` is the mean latency over all predictions
    served: each model's running mean weighted by its prediction count. An
    unweighted mean of per-model means would over-weight rarely used or
    idle models. Responds 500 before the model manager has been created.
    """
    if not model_manager:
        raise HTTPException(status_code=500, detail="Model manager not initialized")
    per_model_stats = [info["stats"] for info in model_manager.models.values()]
    total_predictions = sum(s["total_predictions"] for s in per_model_stats)
    weighted_time = sum(
        s["avg_response_time"] * s["total_predictions"] for s in per_model_stats
    )
    avg_response_time = weighted_time / total_predictions if total_predictions else 0
    return {
        "total_predictions": total_predictions,
        "total_requests": request_counter,
        "cache_size": len(cache),
        "cache_hit_rate": 0.0,  # TODO: hit/miss tracking is not implemented yet
        "average_response_time": avg_response_time,
        "models_loaded": len(model_manager.models),
        "uptime": (datetime.now() - startup_time).total_seconds()
    }
@app.get("/predictions/history")
async def get_prediction_history(
    limit: int = 100,
    model_version: Optional[str] = None
):
    """Placeholder for prediction history: echoes the query parameters back.

    No history is stored yet. A real implementation would read from a
    database or from the prediction logs.
    """
    response = dict(message="Prediction history endpoint")
    response["limit"] = limit
    response["model_version"] = model_version
    return response
async def log_prediction(
    request_id: str,
    text_length: int,
    prediction: int,
    confidence: float,
    processing_time: float
):
    """Background task: emit one INFO log line describing a finished prediction.

    Only the logger is written to; persisting to a file or database is left
    for later.
    """
    entry = dict(
        timestamp=datetime.now().isoformat(),
        request_id=request_id,
        text_length=text_length,
        prediction=prediction,
        confidence=confidence,
        processing_time=processing_time,
    )
    logger.info(f"Prediction logged: {entry}")
# Run the development server (auto-reload enabled)
if __name__ == "__main__":
    import os
    import sys
    import uvicorn

    # uvicorn can only reload when it is given the app as an import string,
    # because it re-imports the module in a fresh process. Passing the app
    # object together with reload=True makes uvicorn log an error and exit.
    # The import path matches the Dockerfile CMD. The project root (three
    # levels above src/api/app.py) is put on sys.path so the import also
    # works when this file is run directly.
    project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
    uvicorn.run(
        "src.api.app:app",
        host="0.0.0.0",
        port=8000,
        log_level="info",
        reload=True
    )
# Dockerfile
FROM python:3.9-slim

# Working directory inside the image
WORKDIR /app

# System packages: build tools for source-only wheels, curl for the HEALTHCHECK
RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies first so this layer is cached across code changes
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Application code
COPY . .

# Directory scanned for fine-tuned models at startup
RUN mkdir -p models

# Optional: bake the default model into the image so startup needs no download
# RUN python -c "from transformers import BertForSequenceClassification, BertTokenizer; \
# model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2); \
# model.save_pretrained('./models/bert-base-uncased'); \
# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased'); \
# tokenizer.save_pretrained('./models/bert-base-uncased')"

EXPOSE 8000

# The server accepts connections only after the startup hook has loaded the
# model(s), which can take a while. The start period matches the
# docker-compose healthcheck's start_period.
HEALTHCHECK --interval=30s --timeout=30s --start-period=40s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

CMD ["uvicorn", "src.api.app:app", "--host", "0.0.0.0", "--port", "8000"]
# docker-compose.yml
version: '3.8'
services:
  api:
    build: .
    ports:
      - "8000:8000"
    volumes:
      - ./models:/app/models
      - ./logs:/app/logs
    environment:
      - CUDA_VISIBLE_DEVICES=0  # only relevant when a GPU is available
      - MODEL_PATH=/app/models
      - LOG_LEVEL=INFO
    # Requires the NVIDIA container runtime; remove this section on CPU-only hosts.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
    restart: unless-stopped
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 40s
  # Optional monitoring services
  prometheus:
    image: prom/prometheus:latest
    ports:
      - "9090:9090"
    volumes:
      - ./prometheus.yml:/etc/prometheus/prometheus.yml
      - prometheus_data:/prometheus
    command:
      - '--config.file=/etc/prometheus/prometheus.yml'
      - '--storage.tsdb.path=/prometheus'
      - '--web.console.libraries=/etc/prometheus/console_libraries'
      - '--web.console.templates=/etc/prometheus/console_templates'
      - '--storage.tsdb.retention.time=200h'
      - '--web.enable-lifecycle'
    restart: unless-stopped
  grafana:
    image: grafana/grafana:latest
    ports:
      - "3000:3000"
    volumes:
      - grafana_data:/var/lib/grafana
      - ./grafana/provisioning:/etc/grafana/provisioning
    environment:
      # Set GRAFANA_ADMIN_PASSWORD in the environment or a .env file; "admin"
      # is only a local-development fallback.
      - GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_ADMIN_PASSWORD:-admin}
    restart: unless-stopped
volumes:
  prometheus_data:
  grafana_data:
# logging_config.py
import logging
import logging.config
import json
import sys
from datetime import datetime
from pythonjsonlogger import jsonlogger
class CustomJsonFormatter(jsonlogger.JsonFormatter):
    """JSON log formatter that adds timestamp, level, service and source-location fields."""

    def add_fields(self, log_record, record, message_dict):
        """Populate ``log_record``, the dict serialized as the JSON log line.

        - ``timestamp``: kept if already set; otherwise the record's creation
          time as a timezone-aware UTC ISO-8601 string.
        - ``level``: upper-cased if present; otherwise ``record.levelname``.
        - ``service``, ``module``, ``function``, ``line``: always set.
        """
        from datetime import timezone

        super().add_fields(log_record, record, message_dict)
        if not log_record.get('timestamp'):
            # Use record.created, i.e. when the event was logged, rather than
            # when it is formatted. datetime.utcnow() is deprecated since
            # Python 3.12 and returns a naive value.
            log_record['timestamp'] = datetime.fromtimestamp(
                record.created, tz=timezone.utc
            ).isoformat()
        if log_record.get('level'):
            log_record['level'] = log_record['level'].upper()
        else:
            log_record['level'] = record.levelname
        # Extra fields
        log_record['service'] = 'bert-classification-api'
        log_record['module'] = record.module
        log_record['function'] = record.funcName
        log_record['line'] = record.lineno
# dictConfig schema: JSON console output plus two size-rotated JSON files
# under logs/ (all INFO+ records, and ERROR+ records only).
LOGGING_CONFIG = {
    'version': 1,
    'disable_existing_loggers': False,
    'formatters': {
        'json': {
            '()': CustomJsonFormatter,
            # Field names to emit; timestamp/level are filled in by CustomJsonFormatter.add_fields
            'format': '%(timestamp)s %(level)s %(name)s %(message)s'
        },
        'simple': {
            'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        }
    },
    'handlers': {
        'console': {
            'class': 'logging.StreamHandler',
            'formatter': 'json',
            'level': 'INFO'
        },
        'file': {
            'class': 'logging.handlers.RotatingFileHandler',
            # NOTE(review): relative path; the logs/ directory must exist before dictConfig runs
            'filename': 'logs/application.log',
            'formatter': 'json',
            'maxBytes': 10485760,  # 10MB
            'backupCount': 5,
            'level': 'INFO'
        },
        'error_file': {
            'class': 'logging.handlers.RotatingFileHandler',
            'filename': 'logs/error.log',
            'formatter': 'json',
            'maxBytes': 10485760,  # 10MB
            'backupCount': 5,
            'level': 'ERROR'
        }
    },
    'loggers': {
        # Root logger
        '': {
            'handlers': ['console', 'file', 'error_file'],
            'level': 'INFO',
            'propagate': True
        },
        'uvicorn': {
            'handlers': ['console', 'file'],
            'level': 'INFO',
            'propagate': False
        },
        # NOTE(review): uvicorn also sends its general server messages
        # (including startup info) to "uvicorn.error"; at level ERROR those
        # are dropped. Confirm this is intended.
        'uvicorn.error': {
            'handlers': ['error_file'],
            'level': 'ERROR',
            'propagate': False
        }
    }
}
def setup_logging():
    """Apply ``LOGGING_CONFIG`` and install a ``sys.excepthook`` that logs uncaught exceptions.

    ``RotatingFileHandler`` opens its file as soon as ``dictConfig`` builds
    it, so the parent directory of every file handler (``logs/`` by default)
    is created first. Otherwise ``dictConfig`` raises ValueError on a fresh
    checkout. KeyboardInterrupt still goes to the default hook.
    """
    import os

    for handler in LOGGING_CONFIG.get('handlers', {}).values():
        filename = handler.get('filename')
        if filename:
            directory = os.path.dirname(filename)
            if directory:
                os.makedirs(directory, exist_ok=True)
    logging.config.dictConfig(LOGGING_CONFIG)

    def handle_exception(exc_type, exc_value, exc_traceback):
        # Ctrl-C should keep its normal behavior instead of being logged as a crash.
        if issubclass(exc_type, KeyboardInterrupt):
            sys.__excepthook__(exc_type, exc_value, exc_traceback)
            return
        logger = logging.getLogger(__name__)
        logger.critical("Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback))

    sys.excepthook = handle_exception
# monitoring.py
from prometheus_client import Counter, Histogram, Gauge, generate_latest
from prometheus_client.core import CollectorRegistry
import time
from functools import wraps
from fastapi.responses import Response
# Dedicated metrics registry (separate from prometheus_client's global
# default registry); metrics_endpoint serializes exactly this registry.
registry = CollectorRegistry()
# Metric definitions
# Requests per model version and endpoint (incremented by monitor_predictions)
PREDICTION_REQUESTS = Counter(
    'prediction_requests_total',
    'Total number of prediction requests',
    ['model_version', 'endpoint'],
    registry=registry
)
# Latency of successful calls wrapped by monitor_predictions, in seconds
PREDICTION_LATENCY = Histogram(
    'prediction_latency_seconds',
    'Prediction latency in seconds',
    ['model_version', 'endpoint'],
    buckets=(0.01, 0.05, 0.1, 0.5, 1.0, 5.0),
    registry=registry
)
# Number of loaded models (set by update_model_metrics)
ACTIVE_MODELS = Gauge(
    'active_models_total',
    'Number of active models',
    registry=registry
)
# NOTE(review): MODEL_LOAD_TIME, CACHE_HITS and CACHE_MISSES are not updated
# anywhere in this module; they need instrumentation at the call sites.
MODEL_LOAD_TIME = Histogram(
    'model_load_time_seconds',
    'Model loading time in seconds',
    ['model_name'],
    registry=registry
)
CACHE_HITS = Counter(
    'cache_hits_total',
    'Total number of cache hits',
    registry=registry
)
CACHE_MISSES = Counter(
    'cache_misses_total',
    'Total number of cache misses',
    registry=registry
)
# Failures of calls wrapped by monitor_predictions, labelled by exception class name
ERROR_COUNT = Counter(
    'prediction_errors_total',
    'Total number of prediction errors',
    ['error_type', 'model_version'],
    registry=registry
)
def monitor_predictions(func):
    """Decorate an async prediction function with Prometheus request, latency and error metrics.

    Labels:
        endpoint: the wrapped function's ``__name__``.
        model_version: taken from a ``model_version`` keyword argument, else
            from the ``model_version`` attribute of a ``request`` keyword
            argument (how the request-model endpoints receive it), and
            defaulting to "latest".

    Latency is observed for successful calls only. Failures increment
    ``prediction_errors_total`` (labelled by exception class name) and are
    re-raised unchanged.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs):
        # perf_counter is monotonic; time.time() can jump when the clock is adjusted.
        start_time = time.perf_counter()
        model_version = kwargs.get('model_version')
        if model_version is None:
            model_version = getattr(kwargs.get('request'), 'model_version', None)
        model_version = model_version or 'latest'
        endpoint = func.__name__
        PREDICTION_REQUESTS.labels(
            model_version=model_version,
            endpoint=endpoint
        ).inc()
        try:
            result = await func(*args, **kwargs)
        except Exception as e:
            ERROR_COUNT.labels(
                error_type=type(e).__name__,
                model_version=model_version
            ).inc()
            raise
        PREDICTION_LATENCY.labels(
            model_version=model_version,
            endpoint=endpoint
        ).observe(time.perf_counter() - start_time)
        return result
    return wrapper
def update_model_metrics(model_manager):
    """Refresh gauges derived from ``model_manager``; currently only the loaded-model count."""
    # Per-model metrics (e.g. load time, prediction counts) belong here when added.
    ACTIVE_MODELS.set(len(model_manager.models))
@app.get("/metrics")
async def metrics_endpoint():
    """Expose this module's registry in the Prometheus text exposition format.

    NOTE(review): the API module also registers a JSON ``/metrics`` route
    (``get_metrics``). Only the first route registered for a path handles
    it, so one of the two should move (e.g. to ``/metrics/prometheus``).
    Confirm which path the Prometheus scrape config targets.
    """
    from prometheus_client import CONTENT_TYPE_LATEST

    if model_manager:
        update_model_metrics(model_manager)
    # CONTENT_TYPE_LATEST carries the exposition-format version and charset
    # that Prometheus expects; a bare "text/plain" omits them.
    return Response(
        generate_latest(registry),
        media_type=CONTENT_TYPE_LATEST
    )
#!/usr/bin/env python3
""" 端到端测试脚本 """
import sys
import os
import time
import requests
import json
from typing import Dict, List, Any
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.data.processor import DataProcessor
from src.training.trainer import CustomTrainer
from src.evaluation.metrics import ModelEvaluator
from src.api.app import app
from fastapi.testclient import TestClient
class EndToEndTest:
    """End-to-end checks from environment setup to the HTTP API, with a JSON report.

    Each ``test_*`` method raises on failure and returns a dict of details
    on success. ``run_all_tests`` runs them all and records each outcome in
    ``self.results``. It then writes ``test_report.json`` and returns True
    only if every test passed.
    """

    def __init__(self, api_url: str = "http://localhost:8000"):
        # api_url is informational only; all requests go through the
        # in-process TestClient below.
        self.api_url = api_url
        self.client = TestClient(app)
        self.results = {}

    def run_all_tests(self) -> bool:
        """Run every test and return whether all of them passed.

        The TestClient is entered as a context manager so the app's startup
        and shutdown events run. The startup event creates the model
        manager; without it every prediction endpoint fails.
        """
        print("=" * 60)
        print("端到端测试开始")
        print("=" * 60)
        tests = [
            self.test_environment,
            self.test_data_pipeline,
            self.test_model_training,
            self.test_model_evaluation,
            self.test_api_endpoints,
            self.test_performance,
            self.test_error_handling,
            self.test_concurrent_requests
        ]
        with self.client:
            for test in tests:
                test_name = test.__name__
                print(f"\n执行测试:{test_name}")
                print("-" * 40)
                try:
                    result = test()
                except Exception as e:
                    # A bare `assert` failure has an empty str(); keep the exception type.
                    error = f"{type(e).__name__}: {e}"
                    self.results[test_name] = {"status": "FAILED", "error": error}
                    print(f"✗ {test_name}: FAILED - {error}")
                else:
                    self.results[test_name] = {"status": "PASSED", "result": result}
                    print(f"✓ {test_name}: PASSED")
        # Return the verdict; main() turns it into the process exit code.
        return self.generate_report()

    def test_environment(self) -> Dict[str, Any]:
        """Require Python >= 3.8 and importable torch/transformers; report their versions."""
        import transformers

        python_version = sys.version_info
        assert python_version.major == 3 and python_version.minor >= 8
        return {
            "python_version": f"{python_version.major}.{python_version.minor}.{python_version.micro}",
            "torch_version": torch.__version__,
            "transformers_version": transformers.__version__,
            "cuda_available": torch.cuda.is_available()
        }

    def test_data_pipeline(self) -> Dict[str, Any]:
        """Check tokenization to a fixed length and batching via ``DataProcessor.collate_fn``."""
        processor = DataProcessor("bert-base-uncased")
        test_text = "This is a test sentence for tokenization."
        tokenized = processor.tokenizer(
            test_text,
            truncation=True,
            padding="max_length",
            max_length=128
        )
        assert "input_ids" in tokenized
        assert len(tokenized["input_ids"]) == 128
        batch_labels = [0, 1, 0]
        batch = processor.collate_fn([
            {"input_ids": torch.tensor([101] * 128),
             "attention_mask": torch.tensor([1] * 128),
             "labels": label}
            for label in batch_labels
        ])
        assert batch["input_ids"].shape[0] == 3
        assert batch["labels"].shape[0] == 3
        return {
            "tokenization_test": "PASSED",
            "batch_processing_test": "PASSED"
        }

    def test_model_training(self) -> Dict[str, Any]:
        """Train one epoch on 100 synthetic samples; the loss must not blow up."""
        from transformers import BertForSequenceClassification
        from src.training.train_config import TrainingConfig

        torch.manual_seed(0)  # reproducible synthetic data and classifier-head init
        model = BertForSequenceClassification.from_pretrained(
            "bert-base-uncased",
            num_labels=2
        )
        train_data = torch.utils.data.TensorDataset(
            torch.randint(0, 1000, (100, 128)),            # input_ids
            torch.ones((100, 128), dtype=torch.long),      # attention_mask
            torch.randint(0, 2, (100,))                    # labels
        )
        train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=16)
        val_dataloader = torch.utils.data.DataLoader(train_data, batch_size=16)
        config = TrainingConfig(
            batch_size=16,
            num_epochs=1,
            learning_rate=1e-5
        )
        trainer = CustomTrainer(
            model=model,
            train_config=config,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader
        )
        initial_metrics = trainer.evaluate()
        trainer.train_epoch(0)
        final_metrics = trainer.evaluate()
        # Labels are random, so only require that training did not diverge.
        assert final_metrics["loss"] <= initial_metrics["loss"] * 1.5
        return {
            "initial_loss": initial_metrics["loss"],
            "final_loss": final_metrics["loss"],
            "training_completed": True
        }

    def test_model_evaluation(self) -> Dict[str, Any]:
        """Check that ``ModelEvaluator`` reports accuracy and a confusion matrix."""
        from transformers import BertForSequenceClassification, BertTokenizer

        model = BertForSequenceClassification.from_pretrained(
            "bert-base-uncased",
            num_labels=2
        )
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        evaluator = ModelEvaluator(model, tokenizer, device="cpu")
        test_texts = [
            "This is excellent!",
            "I don't like this at all.",
            "It's okay, nothing special."
        ]
        test_labels = [1, 0, 0]
        results = evaluator.evaluate_classification(test_texts, test_labels)
        assert "accuracy" in results["metrics"]
        assert "confusion_matrix" in results
        return {
            "evaluation_metrics": results["metrics"],
            "predictions_made": len(results["predictions"])
        }

    def test_api_endpoints(self) -> Dict[str, Any]:
        """Exercise /health, /predict, /predict/batch and /models."""
        response = self.client.get("/health")
        assert response.status_code == 200
        assert response.json()["status"] == "healthy"
        test_data = {
            "text": "This movie was absolutely fantastic! I loved every minute of it.",
            "model_version": "latest"
        }
        response = self.client.post("/predict", json=test_data)
        assert response.status_code == 200
        prediction_data = response.json()
        assert "prediction" in prediction_data
        assert "confidence" in prediction_data
        assert 0 <= prediction_data["confidence"] <= 1
        batch_data = {
            "texts": [
                "Amazing film, highly recommended!",
                "Not my cup of tea, unfortunately.",
                "The acting was superb."
            ]
        }
        response = self.client.post("/predict/batch", json=batch_data)
        assert response.status_code == 200
        assert len(response.json()["predictions"]) == 3
        response = self.client.get("/models")
        assert response.status_code == 200
        assert "available_models" in response.json()
        return {
            "health_check": "PASSED",
            "single_prediction": "PASSED",
            "batch_prediction": "PASSED",
            "model_management": "PASSED"
        }

    def test_performance(self) -> Dict[str, Any]:
        """Single requests must finish in < 5 s; a 10-text batch in < 10 s."""
        test_cases = [
            "Short text",
            "Medium length text " * 10,
            "Very long text " * 100
        ]
        latencies = []
        for text in test_cases:
            start_time = time.perf_counter()
            response = self.client.post("/predict", json={"text": text})
            latency = time.perf_counter() - start_time
            latencies.append(latency)
            assert response.status_code == 200
            assert latency < 5.0
        batch_texts = [f"Test text {i}" for i in range(10)]
        start_time = time.perf_counter()
        response = self.client.post("/predict/batch", json={"texts": batch_texts})
        batch_latency = time.perf_counter() - start_time
        assert response.status_code == 200
        assert batch_latency < 10.0
        return {
            "single_request_latencies": latencies,
            "batch_latency": batch_latency,
            "avg_latency_per_request": batch_latency / len(batch_texts),
            "throughput": len(batch_texts) / batch_latency
        }

    def test_error_handling(self) -> Dict[str, Any]:
        """Check rejection of an empty text, an unknown model and an empty batch."""
        response = self.client.post("/predict", json={"text": ""})
        assert response.status_code == 422  # validation error
        response = self.client.post("/predict", json={
            "text": "Test text",
            "model_version": "non-existent-model"
        })
        # Depending on the API version an unknown model yields 400 or 500
        assert response.status_code in [400, 500]
        response = self.client.post("/predict/batch", json={"texts": []})
        assert response.status_code == 422
        return {
            "empty_text_handling": "PASSED",
            "invalid_model_handling": "PASSED",
            "empty_batch_handling": "PASSED"
        }

    def test_concurrent_requests(self) -> Dict[str, Any]:
        """Send 20 /predict calls from 10 threads; all must succeed with mean latency < 2 s."""
        import concurrent.futures

        test_texts = [f"Concurrent test {i}" for i in range(20)]

        def make_request(text):
            start_time = time.perf_counter()
            response = self.client.post("/predict", json={"text": text})
            return {
                "status_code": response.status_code,
                "latency": time.perf_counter() - start_time
            }

        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            futures = [executor.submit(make_request, text) for text in test_texts]
            results = [future.result() for future in concurrent.futures.as_completed(futures)]
        status_codes = [r["status_code"] for r in results]
        latencies = [r["latency"] for r in results]
        assert all(code == 200 for code in status_codes)
        avg_latency = float(np.mean(latencies))
        assert avg_latency < 2.0
        return {
            "total_requests": len(results),
            "success_rate": sum(1 for code in status_codes if code == 200) / len(results),
            "avg_latency": avg_latency,
            "max_latency": max(latencies)
        }

    def generate_report(self) -> bool:
        """Print a summary, write ``test_report.json`` and return True if nothing failed."""
        print("\n" + "=" * 60)
        print("测试报告")
        print("=" * 60)
        total_tests = len(self.results)
        passed_tests = sum(1 for result in self.results.values() if result["status"] == "PASSED")
        failed_tests = total_tests - passed_tests
        print(f"总测试数:{total_tests}")
        print(f"通过测试:{passed_tests}")
        print(f"失败测试:{failed_tests}")
        if failed_tests > 0:
            print("\n失败详情:")
            for test_name, result in self.results.items():
                if result["status"] == "FAILED":
                    print(f" {test_name}: {result['error']}")
        report_data = {
            "summary": {
                "total_tests": total_tests,
                "passed_tests": passed_tests,
                "failed_tests": failed_tests,
                "success_rate": passed_tests / total_tests if total_tests > 0 else 0
            },
            "details": self.results,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }
        with open("test_report.json", "w") as f:
            json.dump(report_data, f, indent=2, default=str)
        print(f"\n详细报告已保存到:test_report.json")
        if passed_tests == total_tests:
            print("\n✓ 所有测试通过!")
            return True
        print("\n✗ 部分测试失败!")
        return False
def main():
    """Run the end-to-end suite and exit with status 0 if every test passed, else 1.

    Unexpected errors while running the suite are printed with their
    traceback and also yield exit status 1.
    """
    tester = EndToEndTest()
    try:
        all_passed = tester.run_all_tests()
        banner = "=" * 60
        if all_passed:
            verdict, exit_code = "端到端测试完成 - 所有测试通过!", 0
        else:
            verdict, exit_code = "端到端测试完成 - 存在失败的测试!", 1
        print("\n" + banner)
        print(verdict)
        print(banner)
        sys.exit(exit_code)
    except Exception as e:
        print(f"\n测试执行出错:{e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
# Script entry point: run the end-to-end suite when executed directly.
if __name__ == "__main__":
    main()
# .github/workflows/test.yml
name: CI/CD Pipeline
on:
  push:
    branches: [main, develop]
  pull_request:
    branches: [main]
jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Quoted on purpose: an unquoted 3.10 is parsed by YAML as the float 3.1.
        python-version: ["3.8", "3.9", "3.10"]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
          pip install pytest pytest-cov flake8 mypy
      - name: Lint with flake8
        run: |
          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
      - name: Type check with mypy
        run: |
          mypy src --ignore-missing-imports
      - name: Test with pytest
        run: |
          pytest tests/ -v --cov=src --cov-report=xml
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3
        with:
          file: ./coverage.xml
          flags: unittests
          name: codecov-umbrella
  docker:
    needs: test
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Build Docker image
        run: |
          docker build -t bert-classification-api:latest .
      - name: Run Docker container
        run: |
          docker run -d -p 8000:8000 --name test-api bert-classification-api:latest
      - name: Test Docker container
        run: |
          # The container loads (and may first download) BERT before it
          # serves requests, which can take far longer than a fixed sleep.
          # Poll /health for up to ~5 minutes.
          for i in $(seq 1 60); do
            if curl -fsS http://localhost:8000/health; then
              exit 0
            fi
            sleep 5
          done
          docker logs test-api
          exit 1
      - name: Cleanup
        if: always()
        run: |
          docker stop test-api || true
          docker rm test-api || true
import time
from typing import Any, Callable, Dict, List, Optional

import onnxruntime as ort
from optimum.onnxruntime import ORTModelForSequenceClassification
from optimum.onnxruntime import ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig
class ModelOptimizer:
    """Export, quantize and benchmark a Hugging Face sequence-classification model with ONNX Runtime."""

    # AutoQuantizationConfig factories that accept (is_static, per_channel).
    _QUANTIZATION_TARGETS = ("arm64", "avx2", "avx512", "avx512_vnni")

    def __init__(self, model_path: str):
        self.model_path = model_path

    def convert_to_onnx(self, output_path: str = "./models/onnx"):
        """Export ``self.model_path`` to ONNX, save it in ``output_path`` and return that path."""
        # export=True converts the PyTorch checkpoint on the fly.
        # from_transformers is its deprecated predecessor, so it is not passed.
        model = ORTModelForSequenceClassification.from_pretrained(
            self.model_path,
            export=True
        )
        model.save_pretrained(output_path)
        return output_path

    def quantize_model(
        self,
        model_path: str,
        quantization_config: str = "avx512_vnni"
    ):
        """Dynamically quantize the ONNX model stored in ``model_path``.

        Args:
            model_path: directory containing the exported ONNX model.
            quantization_config: name of the ``AutoQuantizationConfig`` target,
                one of "arm64", "avx2", "avx512", "avx512_vnni". This argument
                used to be ignored; the default keeps the old behavior.

        Returns:
            The output directory ``f"{model_path}_quantized"``.

        Raises:
            ValueError: for an unsupported ``quantization_config``.
        """
        if quantization_config not in self._QUANTIZATION_TARGETS:
            raise ValueError(
                f"Unsupported quantization_config {quantization_config!r}; "
                f"expected one of {self._QUANTIZATION_TARGETS}"
            )
        qconfig = getattr(AutoQuantizationConfig, quantization_config)(
            is_static=False,
            per_channel=False
        )
        quantizer = ORTQuantizer.from_pretrained(model_path)
        save_dir = f"{model_path}_quantized"
        quantizer.quantize(
            save_dir=save_dir,
            quantization_config=qconfig
        )
        return save_dir

    def optimize_with_onnxruntime(self, model_path: str):
        """Create a CPU ``InferenceSession`` for ``<model_path>/model.onnx`` with full graph optimizations."""
        sess_options = ort.SessionOptions()
        # Threading: 4 threads within an operator, 2 across independent operators
        sess_options.intra_op_num_threads = 4
        sess_options.inter_op_num_threads = 2
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        # Memory arena and memory-pattern optimizations
        sess_options.enable_cpu_mem_arena = True
        sess_options.enable_mem_pattern = True
        return ort.InferenceSession(
            f"{model_path}/model.onnx",
            sess_options=sess_options,
            providers=['CPUExecutionProvider']
        )

    def benchmark_optimizations(
        self,
        model_variants: Dict[str, Any],
        test_data: List[str],
        predict_fn: Optional[Callable[[Any, str], Any]] = None
    ):
        """Time each model variant over ``test_data``.

        Args:
            model_variants: variant name -> model object.
            test_data: input texts.
            predict_fn: ``predict_fn(model, text)`` runs one inference. Without
                it no inference is performed (the original behavior), so the
                numbers only reflect loop overhead.

        Returns:
            ``{variant: {"total_time", "avg_time_per_sample", "throughput"}}``
            in seconds / samples per second. Ratios are 0.0 when ``test_data``
            is empty or the elapsed time is zero.
        """
        results = {}
        sample_count = len(test_data)
        for variant_name, model in model_variants.items():
            start_time = time.perf_counter()
            if predict_fn is not None:
                for text in test_data:
                    predict_fn(model, text)
            elapsed = time.perf_counter() - start_time
            results[variant_name] = {
                "total_time": elapsed,
                "avg_time_per_sample": elapsed / sample_count if sample_count else 0.0,
                "throughput": sample_count / elapsed if elapsed > 0 else 0.0
            }
        return results
from functools import lru_cache
import hashlib
import pickle
from datetime import datetime, timedelta
from typing import Dict, Any
class PredictionCache:
    """In-memory prediction cache with per-entry TTL and LRU eviction.

    Entries are keyed on ``(model_version, text)``. ``get`` returns ``None``
    on a miss, so a cached value of ``None`` is indistinguishable from a miss.
    """

    def __init__(self, max_size: int = 10000, ttl: int = 3600):
        """
        Args:
            max_size: Maximum number of entries; ``<= 0`` disables caching.
            ttl: Entry lifetime in seconds, measured from when it was stored.
        """
        self.max_size = max_size
        self.ttl = timedelta(seconds=ttl)
        # key -> (stored_at, value). Dict insertion order doubles as LRU order:
        # every access re-inserts the key, so the first key is the least
        # recently used one and eviction is O(1) instead of a full scan.
        self.cache: Dict[str, Any] = {}
        # key -> time of last access; kept in sync with ``cache``.
        self.access_times: Dict[str, datetime] = {}
        self.hits = 0
        self.misses = 0

    def _generate_key(self, text: str, model_version: str) -> str:
        """Build the cache key for ``text`` under ``model_version``."""
        # SHA-256 rather than MD5: MD5 collisions are cheap to craft and would
        # let two different inputs share a cached prediction. The length prefix
        # keeps ("a:b", "c") and ("a", "b:c") from hashing identical content.
        content = f"{len(model_version)}:{model_version}:{text}"
        return hashlib.sha256(content.encode("utf-8")).hexdigest()

    def _is_expired(self, stored_at: datetime, now: datetime) -> bool:
        """Return True once an entry is at least ``ttl`` old."""
        return now - stored_at >= self.ttl

    def _remove(self, key: str) -> None:
        """Drop ``key`` from both bookkeeping dicts, if present."""
        self.cache.pop(key, None)
        self.access_times.pop(key, None)

    def get(self, text: str, model_version: str):
        """Return the cached value, or ``None`` on a miss or expired entry."""
        key = self._generate_key(text, model_version)
        now = datetime.now()
        entry = self.cache.get(key)
        if entry is not None:
            stored_at, value = entry
            if not self._is_expired(stored_at, now):
                # Move to the most-recently-used end.
                self.cache[key] = self.cache.pop(key)
                self.access_times[key] = now
                self.hits += 1
                return value
            # Stale entries must not keep occupying capacity.
            self._remove(key)
        self.misses += 1
        return None

    def set(self, text: str, model_version: str, value) -> None:
        """Store ``value``, evicting least-recently-used entries when full."""
        if self.max_size <= 0:
            return
        key = self._generate_key(text, model_version)
        now = datetime.now()
        # Removing the key first means overwriting an existing entry never
        # evicts an unrelated one, and re-inserts it as most recently used.
        self._remove(key)
        while len(self.cache) >= self.max_size:
            self._remove(next(iter(self.cache)))
        self.cache[key] = (now, value)
        self.access_times[key] = now

    def clear_expired(self) -> None:
        """Remove every entry whose age has reached the TTL."""
        now = datetime.now()
        expired_keys = [
            key
            for key, (stored_at, _) in self.cache.items()
            if self._is_expired(stored_at, now)
        ]
        for key in expired_keys:
            self._remove(key)

    def get_stats(self) -> Dict[str, Any]:
        """Return size, hit/miss counters, hit rate and configuration."""
        total = self.hits + self.misses
        hit_rate = self.hits / total if total > 0 else 0
        return {
            "size": len(self.cache),
            "hits": self.hits,
            "misses": self.misses,
            "hit_rate": hit_rate,
            "max_size": self.max_size,
            "ttl_seconds": self.ttl.total_seconds()
        }
通过本项目的完整实施,我们获得了以下关键收获:
本项目可以进一步扩展为:
在将系统部署到生产环境前,请确保:
本文详细展示了开源 AI 模型从引入到测试的完整技术流程。通过这个实战项目,我们不仅学习了如何集成和使用先进的 AI 模型,更重要的是掌握了构建生产级 AI 应用的系统工程方法。这套方法论和代码框架可以应用于各种 AI 项目,为 AI 应用开发提供坚实基础。
衡量一个 AI 项目是否成功,不能只看模型准确率,更要看系统工程、可维护性、可扩展性与可靠性的综合表现。希望本文能成为你在 AI 工程化道路上的有力参考。

微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online