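Python Open-Source AI Models: End-to-End Integration and Testing in Practice
Using Python and Hugging Face Transformers, this article walks a BERT model through its full lifecycle: environment setup, data preprocessing, fine-tuning, and deployment with FastAPI. It also covers unit tests, performance benchmarks, and Docker containerization as a reference for production-grade AI engineering.
极客日志 (Geek Journal) · Tags: Python, AI, Algorithms · Author: 落日余晖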

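Introduction: An Overview of the Open-Source AI Ecosystem
Open-source AI models have become a core building block of modern applications. From BERT to Llama, the Hugging Face Hub hosts more than 100,000 pretrained models. Using BERT as the running example, this article walks through the complete engineering workflow: environment configuration, model loading, data processing, fine-tuning, deployment, and testing.
Tech stack:
- Model framework: Hugging Face Transformers
- Deep learning framework: PyTorch
- Data processing: Pandas, NumPy, Datasets
- Experiment tracking: Weights & Biases (WandB)
- Testing: Pytest, Hypothesis
- Deployment: FastAPI, Docker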
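Environment Setup and Project Initialization
System Requirements
Python 3.8 or later is required. An NVIDIA GPU is optional but makes fine-tuning much faster.
python --version   # should report Python 3.8 or later
nvidia-smi         # optional: shows the NVIDIA driver and the visible GPUs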
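The same checks can be scripted in Python, for example as a first CI step. The sketch below uses only the standard library; the function name check_environment is our own, and finding nvidia-smi on the PATH only suggests that an NVIDIA driver is installed, not that PyTorch can actually use the GPU (torch.cuda.is_available() answers that).

```python
import shutil
import sys


def check_environment(min_version=(3, 8)):
    """Report the interpreter version and whether NVIDIA tooling is on the PATH."""
    python_ok = sys.version_info[:2] >= min_version
    # nvidia-smi on the PATH is only a hint that a driver is installed
    has_nvidia_smi = shutil.which("nvidia-smi") is not None
    return {
        "python_version": ".".join(str(part) for part in sys.version_info[:3]),
        "python_ok": python_ok,
        "nvidia_smi_found": has_nvidia_smi,
    }


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```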
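Create a Virtual Environment
Give each project its own virtual environment so that its dependencies stay isolated.
mkdir openai-introduction && cd openai-introduction
python -m venv venv
source venv/bin/activate      # Linux / macOS
venv\Scripts\activate         # Windows (use instead of the line above)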
Dependency Management
Collect all dependencies in a single requirements.txt:
torch>=2.0.0
transformers>=4.30.0
datasets>=2.12.0
accelerate>=0.20.0
numpy>=1.24.0
pandas>=2.0.0
scikit-learn>=1.3.0
wandb>=0.15.0
tensorboard>=2.13.0
fastapi>=0.100.0
uvicorn[standard]>=0.23.0
pydantic>=2.0.0
pytest>=7.4.0
hypothesis>=6.82.0
black>=23.0.0
flake8>=6.0.0
mypy>=1.5.0
pre-commit>=3.3.0
optimum>=1.12.0
onnxruntime>=1.15.0
jupyter>=1.0.0
ipython>=8.14.0
matplotlib>=3.7.0
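seaborn>=0.12.0
# Also imported by the code in this article
pyyaml>=6.0
tqdm>=4.65.0
nlpaug>=1.1.11
python-json-logger>=2.0.0
httpx>=0.24.0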
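Install the dependencies. For GPU training, install the PyTorch build that matches your CUDA version (CUDA 11.8 here) first: if requirements.txt runs first, pip installs the default torch wheel from PyPI, and the CUDA-specific command then reports the requirement as already satisfied instead of switching builds.
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install -r requirements.txt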
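Project Structure
openai-introduction/
├── src/
│ ├── data/          # dataset loading and preprocessing
│ ├── models/        # model definitions and helpers
│ ├── training/      # training configuration and loop
│ ├── evaluation/    # metrics and benchmarks
│ └── api/           # FastAPI service
├── tests/           # unit and integration tests
├── notebooks/       # exploratory Jupyter notebooks
├── configs/         # YAML training configs
├── scripts/         # entry-point scripts such as the training script
├── .pre-commit-config.yaml
├── Dockerfile
├── docker-compose.yml
├── pyproject.toml
├── README.md
└── requirements.txt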
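Understanding the BERT Architecture
BERT (Bidirectional Encoder Representations from Transformers) learns token representations that condition on both the left and the right context, unlike left-to-right language models. Architecturally, it is a stack of Transformer encoder layers.
Multi-Head Attention in the Transformer
The following simplified multi-head attention implementation illustrates the mechanism underneath.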
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention."""

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Linear projections for queries, keys and values, plus the output projection
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.scaling = 1.0 / math.sqrt(self.head_dim)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = query.size(0)
        # (batch, seq, embed) -> (batch, heads, seq, head_dim)
        q = self.q_proj(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        # Attention scores: (batch, heads, query_len, key_len)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scaling
        if attention_mask is not None:
            # A (batch, key_len) padding mask (1 = keep, 0 = pad) has to be reshaped
            # to (batch, 1, 1, key_len) so that it broadcasts over heads and queries
            if attention_mask.dim() == 2:
                attention_mask = attention_mask[:, None, None, :]
            # The dtype's minimum (instead of -1e9) also works in fp16
            attn_scores = attn_scores.masked_fill(attention_mask == 0, torch.finfo(attn_scores.dtype).min)
        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_probs = self.dropout(attn_probs)
        attn_output = torch.matmul(attn_probs, v)
        # Merge the heads back: (batch, heads, seq, head_dim) -> (batch, seq, embed)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
        attn_output = self.out_proj(attn_output)
        return attn_output, attn_probs
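Each head computes softmax(QKᵀ / √d_k)·V. As a dependency-free illustration, the sketch below evaluates that formula for a single head on plain Python lists and applies a padding mask the same way (masked keys get zero weight). It is written for readability rather than speed, and the function names are our own.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    # Subtract the max for numerical stability; -inf scores become exactly 0
    # (assumes at least one entry in the row is finite)
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(q: Matrix, k: Matrix, v: Matrix,
                                 mask: Optional[List[int]] = None):
    """Single-head attention: softmax(Q K^T / sqrt(d_k)) V.

    q: query_len x d_k, k: key_len x d_k, v: key_len x d_v.
    mask: optional list of length key_len, 1 = attend, 0 = padding.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    weights = []
    for q_row in q:
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        if mask is not None:
            scores = [s if keep else float("-inf") for s, keep in zip(scores, mask)]
        weights.append(softmax(scores))
    # Each output row is the attention-weighted average of the value rows
    output = [[sum(w * v_row[j] for w, v_row in zip(w_row, v)) for j in range(len(v[0]))]
              for w_row in weights]
    return output, weights


q = [[1.0, 0.0]]
k = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
v = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
out, w = scaled_dot_product_attention(q, k, v, mask=[1, 1, 0])
print(w)
```

Because the third key is masked, its large value vector has no influence on the output. Using -inf gives an exact zero weight; the PyTorch module above uses a large negative constant instead, which avoids NaN when every key in a row is masked.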
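Wrapping It with Hugging Face Transformers
In practice you build on the classes the library provides, which already handle weight initialization, checkpoint loading, and saving.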
from typing import Optional

import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, BertPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput


class BertForSequenceClassification(BertPreTrainedModel):
    """BERT sequence classifier with a two-layer classification head.

    Inheriting from BertPreTrainedModel (rather than the generic PreTrainedModel)
    provides config_class, base_model_prefix and BERT's weight initialization,
    which from_pretrained() relies on. Note that this class shadows
    transformers.BertForSequenceClassification, so import only one of them per module.
    """

    def __init__(self, config: BertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.bert = BertModel(config)
        self.classifier = nn.Sequential(
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, config.num_labels),
        )
        # Initialize the new head with BERT's initialization scheme
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> SequenceClassifierOutput:
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs,
        )
        # pooler_output: the [CLS] representation passed through a dense + tanh layer
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
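When labels are passed, the head applies nn.CrossEntropyLoss, which by default is the batch mean of −log softmax(logits)[label]. A dependency-free version of that computation (our own helper, for illustration only) makes the formula concrete:

```python
import math
from typing import List


def cross_entropy(logits: List[List[float]], labels: List[int]) -> float:
    """Mean negative log-likelihood of the true class, as nn.CrossEntropyLoss computes by default."""
    losses = []
    for row, label in zip(logits, labels):
        # log-sum-exp with max subtraction for numerical stability
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        losses.append(log_sum_exp - row[label])
    return sum(losses) / len(losses)


# Equal logits for two classes: the loss is log(2) whatever the label
print(cross_entropy([[0.0, 0.0]], [0]))  # ≈ 0.6931
```

A freshly initialized binary head therefore starts near a loss of log 2 ≈ 0.69, which is a handy sanity check for the first logged training steps.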
Data Preparation and Preprocessing
We use the IMDB movie-review dataset for binary sentiment classification.
Loading and Splitting the Dataset
from typing import Dict, List, Tuple

import pandas as pd
import torch
from datasets import DatasetDict, load_dataset
from torch.utils.data import DataLoader
from transformers import BertTokenizer


class DataProcessor:
    def __init__(self, model_name: str = "bert-base-uncased", max_length: int = 512):
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.max_length = max_length

    def load_imdb_dataset(self, cache_dir: str = "./data") -> DatasetDict:
        dataset = load_dataset("imdb", cache_dir=cache_dir)
        # IMDB ships only train/test splits: hold out 10% of train for validation
        split = dataset["train"].train_test_split(test_size=0.1, seed=42)
        return DatasetDict({
            "train": split["train"],
            "validation": split["test"],
            "test": dataset["test"],
        })

    def analyze_dataset(self, dataset_dict: DatasetDict) -> pd.DataFrame:
        # Minimal per-split statistics (called by the training script; the chosen
        # statistics are our own)
        rows = []
        for split_name, split in dataset_dict.items():
            word_counts = [len(text.split()) for text in split["text"]]
            labels = split["label"]
            rows.append({
                "split": split_name,
                "num_samples": len(split),
                "positive_ratio": sum(labels) / len(labels),
                "avg_words": sum(word_counts) / len(word_counts),
            })
        return pd.DataFrame(rows)

    def preprocess_function(self, examples: Dict[str, List]) -> Dict[str, List]:
        # Plain Python lists are enough here: datasets.map() stores them in Arrow format
        tokenized_inputs = self.tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
        )
        return {
            "input_ids": tokenized_inputs["input_ids"],
            "attention_mask": tokenized_inputs["attention_mask"],
            "labels": examples["label"],
        }

    def prepare_dataset(
        self, dataset_dict: DatasetDict, batch_size: int = 32
    ) -> Tuple[DataLoader, DataLoader, DataLoader]:
        tokenized_datasets = dataset_dict.map(
            self.preprocess_function, batched=True, remove_columns=["text", "label"]
        )
        # Make indexing return torch tensors for these columns
        tokenized_datasets.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        train_dataloader = DataLoader(
            tokenized_datasets["train"], shuffle=True, batch_size=batch_size, collate_fn=self.collate_fn
        )
        val_dataloader = DataLoader(
            tokenized_datasets["validation"], batch_size=batch_size, collate_fn=self.collate_fn
        )
        test_dataloader = DataLoader(
            tokenized_datasets["test"], batch_size=batch_size, collate_fn=self.collate_fn
        )
        return train_dataloader, val_dataloader, test_dataloader

    @staticmethod
    def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        # Every sample is already padded to max_length, so tensors can be stacked directly
        return {
            "input_ids": torch.stack([item["input_ids"] for item in batch]),
            "attention_mask": torch.stack([item["attention_mask"] for item in batch]),
            "labels": torch.stack([item["labels"] for item in batch]),
        }
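Padding every review to 512 tokens wastes compute on short inputs. A common alternative is dynamic padding: pad each batch only to its own longest sequence. The pure-Python sketch below shows the idea; pad_id=0 matches the [PAD] id of bert-base-uncased but is an assumption for other tokenizers, the token ids in the example are illustrative, and in real code transformers.DataCollatorWithPadding does the same job on tensors.

```python
from typing import Dict, List


def dynamic_pad_collate(batch: List[Dict[str, List[int]]], pad_id: int = 0) -> Dict[str, List[List[int]]]:
    """Pad input_ids to the longest sequence in the batch and build attention masks."""
    max_len = max(len(item["input_ids"]) for item in batch)
    input_ids, attention_mask = [], []
    for item in batch:
        ids = item["input_ids"]
        pad = max_len - len(ids)
        input_ids.append(ids + [pad_id] * pad)
        # 1 marks real tokens, 0 marks padding the model should ignore
        attention_mask.append([1] * len(ids) + [0] * pad)
    labels = [item["label"] for item in batch]
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


batch = [
    {"input_ids": [101, 2023, 3185, 102], "label": 1},  # illustrative token ids
    {"input_ids": [101, 6659, 102], "label": 0},
]
collated = dynamic_pad_collate(batch)
print(collated["attention_mask"])  # [[1, 1, 1, 1], [1, 1, 1, 0]]
```

Combined with shuffling, batches contain reviews of mixed length, so the savings are largest when similar-length samples are grouped together.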
Data Augmentation Strategies
from typing import List

import nlpaug.augmenter.word as naw


class DataAugmenter:
    """Text augmentation built on nlpaug.

    The synonym method needs NLTK data such as the WordNet corpus; the contextual
    and back-translation methods download Hugging Face models on first use.
    """

    def __init__(self, aug_method: str = "synonym"):
        self.aug_method = aug_method
        if aug_method == "synonym":
            self.augmenter = naw.SynonymAug(aug_src="wordnet")
        elif aug_method == "contextual":
            self.augmenter = naw.ContextualWordEmbsAug(model_path="bert-base-uncased", action="substitute")
        elif aug_method == "back_translation":
            # English -> German -> English round trip
            self.augmenter = naw.BackTranslationAug(
                from_model_name="facebook/wmt19-en-de", to_model_name="facebook/wmt19-de-en"
            )
        else:
            raise ValueError(f"Unsupported augmentation method: {aug_method}")

    def augment_text(self, text: str, num_aug: int = 3) -> List[str]:
        augmented_texts: List[str] = []
        for _ in range(num_aug):
            result = self.augmenter.augment(text)
            # Recent nlpaug versions return a list of strings, older ones a single str
            if isinstance(result, list):
                augmented_texts.extend(result)
            else:
                augmented_texts.append(result)
        return augmented_texts
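For a quick baseline that needs no model downloads or NLTK data, random word deletion (one of the "EDA" operations) is easy to write by hand. This is a swapped-in technique, not one of the three nlpaug methods above; the function names and the 10% default deletion rate are our own choices.

```python
import random
from typing import List, Optional


def random_deletion(text: str, p: float = 0.1, seed: Optional[int] = None) -> str:
    """Drop each word independently with probability p, keeping at least one word."""
    rng = random.Random(seed)
    words = text.split()
    if len(words) <= 1:
        return text
    kept = [w for w in words if rng.random() >= p]
    # Never return an empty string: fall back to one randomly chosen word
    if not kept:
        kept = [rng.choice(words)]
    return " ".join(kept)


def augment_batch(text: str, num_aug: int = 3, p: float = 0.1, seed: int = 0) -> List[str]:
    # Derived seeds give different but reproducible variants
    return [random_deletion(text, p=p, seed=seed + i) for i in range(num_aug)]


print(augment_batch("this movie was absolutely fantastic and moving", num_aug=2, p=0.3))
```

Deleting words can remove the very word that carries the sentiment ("not"), so keep p small and check a sample of augmented reviews by eye.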
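Model Training and Fine-Tuning
Training Configuration
Hyperparameters live in a dataclass that can be loaded from YAML and converted to Hugging Face TrainingArguments.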
from dataclasses import dataclass

import torch
import yaml
from transformers import TrainingArguments


@dataclass
class TrainingConfig:
    model_name: str = "bert-base-uncased"
    num_labels: int = 2
    dropout_rate: float = 0.1
    batch_size: int = 32
    gradient_accumulation_steps: int = 1
    num_epochs: int = 3
    learning_rate: float = 2e-5
    weight_decay: float = 0.01
    warmup_steps: int = 500
    optimizer: str = "adamw"
    scheduler: str = "linear"
    logging_steps: int = 100
    eval_steps: int = 500
    save_steps: int = 1000  # must be a multiple of eval_steps when load_best_model_at_end=True
    # Mixed precision needs a GPU, so enable it only when CUDA is available
    fp16: bool = torch.cuda.is_available()
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    @classmethod
    def from_yaml(cls, yaml_path: str) -> "TrainingConfig":
        # YAML keys must match the field names above
        with open(yaml_path, "r", encoding="utf-8") as f:
            config_dict = yaml.safe_load(f)
        return cls(**config_dict)

    def to_training_arguments(self, output_dir: str) -> TrainingArguments:
        # For use with transformers.Trainer; metric_for_best_model="accuracy" requires
        # a compute_metrics function that returns an "accuracy" key
        return TrainingArguments(
            output_dir=output_dir,
            overwrite_output_dir=True,
            num_train_epochs=self.num_epochs,
            per_device_train_batch_size=self.batch_size,
            per_device_eval_batch_size=self.batch_size,
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            learning_rate=self.learning_rate,
            weight_decay=self.weight_decay,
            warmup_steps=self.warmup_steps,
            logging_dir=f"{output_dir}/logs",
            logging_steps=self.logging_steps,
            eval_steps=self.eval_steps,
            save_steps=self.save_steps,
            evaluation_strategy="steps",  # renamed to eval_strategy in newer transformers releases
            save_strategy="steps",
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            greater_is_better=True,
            fp16=self.fp16,
            report_to=["wandb"],
            run_name=f"bert-imdb-{self.model_name}",
        )
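The number of optimizer steps decides how warmup, evaluation, and checkpoint intervals relate to the whole run. With the split above (90% of IMDB's 25,000 training reviews, i.e. 22,500 examples), batch size 32 and no gradient accumulation, one epoch is ceil(22500 / 32) = 704 steps and three epochs are 2,112 steps, so warmup_steps=500 covers roughly the first quarter of training. The helper below does that arithmetic; it is our own sketch and assumes a single device and a DataLoader that keeps the last partial batch.

```python
import math


def training_schedule(num_examples: int, batch_size: int, grad_accum: int,
                      num_epochs: int, warmup_steps: int) -> dict:
    """Optimizer-step arithmetic for a single-device run (last partial batch kept)."""
    batches_per_epoch = math.ceil(num_examples / batch_size)
    # One optimizer update every grad_accum batches
    updates_per_epoch = max(batches_per_epoch // grad_accum, 1)
    total_steps = updates_per_epoch * num_epochs
    return {
        "effective_batch_size": batch_size * grad_accum,
        "updates_per_epoch": updates_per_epoch,
        "total_steps": total_steps,
        "warmup_fraction": warmup_steps / total_steps,
    }


print(training_schedule(22_500, 32, 1, 3, 500))
```

With eval_steps=500 the run is evaluated only four times in total, which is worth keeping in mind when picking the best checkpoint.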
Custom Trainer
For finer control over the training loop, we can write our own trainer class instead of using transformers.Trainer.
from typing import Dict, Optional

import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import get_linear_schedule_with_warmup

# TrainingConfig is the dataclass from the previous listing


class CustomTrainer:
    """Plain PyTorch training loop with periodic evaluation and best-model checkpointing.

    Note: this loop does not use the fp16 and gradient_accumulation_steps settings.
    """

    def __init__(
        self,
        model,
        train_config: "TrainingConfig",
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        test_dataloader: Optional[DataLoader] = None,
    ):
        self.model = model
        self.config = train_config
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader
        self.device = torch.device(train_config.device)
        self.model.to(self.device)
        self.optimizer = self._create_optimizer()
        self.scheduler = self._create_scheduler()
        self.global_step = 0
        self.best_metric = 0.0
        self.history = {"train_loss": [], "val_loss": [], "val_accuracy": [], "learning_rate": []}

    def _create_optimizer(self) -> torch.optim.Optimizer:
        # Biases and LayerNorm weights are excluded from weight decay, as in the BERT recipe
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.config.weight_decay,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        if self.config.optimizer == "adamw":
            # torch.optim.AdamW: transformers.AdamW is deprecated and removed in recent releases
            return AdamW(optimizer_grouped_parameters, lr=self.config.learning_rate, eps=1e-8)
        raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

    def _create_scheduler(self):
        # One scheduler step per batch: total steps = batches per epoch x epochs
        total_steps = len(self.train_dataloader) * self.config.num_epochs
        if self.config.scheduler == "linear":
            return get_linear_schedule_with_warmup(
                self.optimizer, num_warmup_steps=self.config.warmup_steps, num_training_steps=total_steps
            )
        raise ValueError(f"Unsupported scheduler: {self.config.scheduler}")

    def train_epoch(self, epoch: int) -> Dict[str, float]:
        self.model.train()
        total_loss = 0.0
        progress_bar = tqdm(self.train_dataloader, desc=f"Epoch {epoch}", leave=False)
        for batch in progress_bar:
            batch = {k: v.to(self.device) for k, v in batch.items()}
            outputs = self.model(**batch)
            loss = outputs.loss
            loss.backward()
            # Gradient clipping stabilizes fine-tuning
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()
            total_loss += loss.item()
            self.global_step += 1
            current_lr = self.scheduler.get_last_lr()[0]
            progress_bar.set_postfix({"loss": loss.item(), "lr": current_lr})
            self.history["learning_rate"].append(current_lr)
            # Periodic validation; keep the checkpoint with the best validation accuracy
            if self.global_step % self.config.eval_steps == 0:
                val_metrics = self.evaluate()
                self.history["val_loss"].append(val_metrics["loss"])
                self.history["val_accuracy"].append(val_metrics["accuracy"])
                if val_metrics["accuracy"] > self.best_metric:
                    self.best_metric = val_metrics["accuracy"]
                    self.save_model(f"best_model_step_{self.global_step}")
                self.model.train()  # evaluate() switched the model to eval mode
        avg_loss = total_loss / len(self.train_dataloader)
        self.history["train_loss"].append(avg_loss)
        return {"train_loss": avg_loss}

    def evaluate(self, dataloader: Optional[DataLoader] = None) -> Dict[str, float]:
        if dataloader is None:
            dataloader = self.val_dataloader
        self.model.eval()
        total_loss = 0.0
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating", leave=False):
                batch = {k: v.to(self.device) for k, v in batch.items()}
                outputs = self.model(**batch)
                total_loss += outputs.loss.item()
                preds = torch.argmax(outputs.logits, dim=-1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(batch["labels"].cpu().numpy())
        accuracy = accuracy_score(all_labels, all_preds)
        precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average="binary")
        avg_loss = total_loss / len(dataloader)
        return {"loss": avg_loss, "accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

    def train(self):
        print(f"Starting training with config: {self.config}")
        print(f"Training samples: {len(self.train_dataloader.dataset)}")
        print(f"Validation samples: {len(self.val_dataloader.dataset)}")
        for epoch in range(self.config.num_epochs):
            print(f"\n{'=' * 50}")
            print(f"Epoch {epoch + 1}/{self.config.num_epochs}")
            print(f"{'=' * 50}")
            train_metrics = self.train_epoch(epoch)
            val_metrics = self.evaluate()
            print(f"\nEpoch {epoch + 1} Results:")
            print(f"Train Loss: {train_metrics['train_loss']:.4f}")
            print(f"Val Loss: {val_metrics['loss']:.4f}")
            print(f"Val Accuracy: {val_metrics['accuracy']:.4f}")
            print(f"Val F1: {val_metrics['f1']:.4f}")
        if self.test_dataloader is not None:
            test_metrics = self.evaluate(self.test_dataloader)
            print(f"\n{'=' * 50}")
            print("Final Test Results:")
            print(f"Test Accuracy: {test_metrics['accuracy']:.4f}")
            print(f"Test F1: {test_metrics['f1']:.4f}")
        return self.history

    def save_model(self, save_path: str):
        # Full training state; recent PyTorch versions need torch.load(..., weights_only=False)
        # to read it back because it contains the Python config object
        torch.save({
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "config": self.config,
            "history": self.history,
            "global_step": self.global_step,
            "best_metric": self.best_metric,
        }, f"{save_path}.pt")
        # Hugging Face format, loadable with from_pretrained()
        self.model.save_pretrained(f"{save_path}_hf")
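get_linear_schedule_with_warmup multiplies the base learning rate by a factor that rises linearly from 0 to 1 during warmup and then falls linearly to 0 at the final step. The stdlib-only function below is our re-implementation of that multiplier (not the library code), useful for inspecting a schedule, e.g. a 2,000-step run with 500 warmup steps, without building an optimizer.

```python
def linear_warmup_decay(step: int, warmup_steps: int, total_steps: int) -> float:
    """Learning-rate multiplier: linear warmup to 1.0, then linear decay to 0.0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


base_lr = 2e-5
for step in (0, 250, 500, 1250, 2000):
    print(step, base_lr * linear_warmup_decay(step, warmup_steps=500, total_steps=2000))
```

Because the factor is exactly 0 at step 0, the very first optimizer update in the loop above is taken with a learning rate of 0.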
Training Script Entry Point
"""Training script"""
import dataclasses
import os
import sys

# Make the project root importable when this script is run directly
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import wandb
from transformers import BertForSequenceClassification, set_seed

from src.data.processor import DataProcessor
from src.training.trainer import CustomTrainer, TrainingConfig


def main():
    # transformers.set_seed seeds Python's random, NumPy and PyTorch in one call
    set_seed(42)
    config = TrainingConfig()
    # Log the hyperparameters actually used rather than a hand-copied dictionary
    wandb.init(project="bert-imdb-classification", config={**dataclasses.asdict(config), "dataset": "imdb"})

    processor = DataProcessor(config.model_name)
    dataset_dict = processor.load_imdb_dataset()
    stats_df = processor.analyze_dataset(dataset_dict)
    print("\nDataset Statistics:")
    print(stats_df)
    train_dataloader, val_dataloader, test_dataloader = processor.prepare_dataset(
        dataset_dict, batch_size=config.batch_size
    )

    print(f"\nLoading model: {config.model_name}")
    # Extra keyword arguments override the matching BertConfig fields
    model = BertForSequenceClassification.from_pretrained(
        config.model_name,
        num_labels=config.num_labels,
        hidden_dropout_prob=config.dropout_rate,
        attention_probs_dropout_prob=config.dropout_rate,
    )
    trainer = CustomTrainer(
        model=model,
        train_config=config,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        test_dataloader=test_dataloader,
    )
    print("\nStarting training...")
    history = trainer.train()
    trainer.save_model("final_model")
    # best_metric is the best validation accuracy seen during training
    wandb.log({"best_val_accuracy": trainer.best_metric})
    wandb.finish()
    print("\nTraining completed successfully!")
    return history


if __name__ == "__main__":
    history = main()
Model Evaluation and Testing
Comprehensive Evaluation Metrics
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
    confusion_matrix, classification_report,
)


class ModelEvaluator:
    def __init__(self, model, tokenizer, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.to(device)
        self.model.eval()

    def predict(self, texts: List[str], batch_size: int = 32) -> Tuple[np.ndarray, np.ndarray]:
        """Return (logits, probabilities), each of shape (len(texts), num_labels)."""
        all_logits = []
        all_probs = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            # padding=True pads only to the longest text in this batch
            inputs = self.tokenizer(batch_texts, truncation=True, padding=True, max_length=512, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                logits = self.model(**inputs).logits
            probs = torch.softmax(logits, dim=-1)
            all_logits.append(logits.cpu().numpy())
            all_probs.append(probs.cpu().numpy())
        return np.vstack(all_logits), np.vstack(all_probs)

    def evaluate_classification(self, texts: List[str], labels: List[int], threshold: float = 0.5) -> Dict[str, Any]:
        _, probs = self.predict(texts)
        # Positive when P(positive) >= threshold; 0.5 matches argmax except for exact ties
        preds = (probs[:, 1] >= threshold).astype(int)
        metrics = {
            "accuracy": accuracy_score(labels, preds),
            "precision": precision_score(labels, preds, average="binary"),
            "recall": recall_score(labels, preds, average="binary"),
            "f1": f1_score(labels, preds, average="binary"),
            "roc_auc": roc_auc_score(labels, probs[:, 1]),  # threshold-independent
        }
        cm = confusion_matrix(labels, preds)
        report = classification_report(labels, preds, target_names=["Negative", "Positive"], output_dict=True)
        # Confidence = probability assigned to the predicted class
        confidence_scores = probs[np.arange(len(preds)), preds]
        return {
            "metrics": metrics,
            "confusion_matrix": cm,
            "classification_report": report,
            "predictions": preds,
            "probabilities": probs,
            "confidence_scores": confidence_scores,
        }
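All of the binary metrics above derive from the four confusion-matrix counts, so computing them by hand on a toy example is a quick sanity check for the sklearn calls. The helper below is our own illustration (positive class = 1), using the same [[TN, FP], [FN, TP]] layout as sklearn's confusion_matrix.

```python
from typing import Dict, List


def binary_metrics(labels: List[int], preds: List[int]) -> Dict[str, object]:
    tp = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 1)
    tn = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 0)
    fp = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 1)
    fn = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 0)
    # Fall back to 0.0 when a denominator is zero
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "confusion_matrix": [[tn, fp], [fn, tp]],
        "accuracy": (tp + tn) / len(labels),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


print(binary_metrics([1, 0, 1, 1, 0], [1, 0, 0, 1, 1]))
```

In sklearn, passing zero_division=0 to precision_score and recall_score gives the same 0.0 fallback instead of a warning when a class is never predicted.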
Load Testing and Performance Benchmarks
import time
from typing import Dict, List, Optional

import numpy as np
import torch


class PerformanceBenchmark:
    def __init__(self, model, tokenizer, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        self.model = model.to(device)
        self.model.eval()
        self.tokenizer = tokenizer
        self.device = device

    def _synchronize(self):
        # CUDA kernels run asynchronously; wait for them so the timer measures real work
        if str(self.device).startswith("cuda"):
            torch.cuda.synchronize()

    def measure_inference_time(
        self, texts: List[str], batch_sizes: Optional[List[int]] = None
    ) -> Dict[int, Dict[str, float]]:
        if batch_sizes is None:  # avoid a mutable default argument
            batch_sizes = [1, 4, 8, 16, 32, 64]
        results = {}
        for batch_size in batch_sizes:
            print(f"\nTesting batch size: {batch_size}")
            # Warm-up pass so one-time setup costs are excluded from the timings
            warmup_texts = ["This is a warmup sentence."] * batch_size
            self.predict_batch(warmup_texts)
            self._synchronize()
            times = []
            for i in range(0, len(texts), batch_size):
                batch_texts = texts[i:i + batch_size]
                start_time = time.perf_counter()
                self.predict_batch(batch_texts)  # timing includes tokenization
                self._synchronize()
                times.append(time.perf_counter() - start_time)
            total_time = float(np.sum(times))
            avg_time = float(np.mean(times))  # average latency per batch
            throughput = len(texts) / total_time
            results[batch_size] = {
                "avg_inference_time": avg_time,
                "throughput": throughput,
                "samples_per_second": throughput,  # alias of throughput
                "total_time": total_time,
            }
            print(f"  Average inference time per batch: {avg_time:.4f}s")
            print(f"  Throughput: {throughput:.2f} samples/s")
        return results

    def predict_batch(self, texts: List[str]):
        inputs = self.tokenizer(texts, truncation=True, padding=True, max_length=512, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs
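An average hides tail latency, which is usually what a latency target is about. The stdlib-only sketch below times any callable and reports nearest-rank percentiles (p50/p95/p99); the function names are our own, and in practice you would pass something like a lambda that calls predict_batch on a fixed batch.

```python
import math
import time
from typing import Callable, Dict, List


def percentile(samples: List[float], pct: float) -> float:
    """Nearest-rank percentile, pct in (0, 100]."""
    ordered = sorted(samples)
    rank = math.ceil(pct * len(ordered) / 100)
    return ordered[max(rank, 1) - 1]


def latency_report(fn: Callable[[], object], runs: int = 50, warmup: int = 5) -> Dict[str, float]:
    for _ in range(warmup):
        fn()  # warm-up calls are excluded from the measurements
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return {
        "p50": percentile(timings, 50),
        "p95": percentile(timings, 95),
        "p99": percentile(timings, 99),
        "mean": sum(timings) / len(timings),
    }


print(latency_report(lambda: sum(range(10_000)), runs=20))
```

With only a few dozen runs, p99 is effectively the maximum; collect several hundred samples before reading much into it.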
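Testing and Quality Assurance
Unit Tests
We use pytest together with Hypothesis (property-based testing) to check that the code holds up against unexpected inputs.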
import torch
from hypothesis import given, settings, strategies as st
from transformers import BertForSequenceClassification

from src.data.processor import DataProcessor


class TestDataProcessor:
    def setup_method(self):
        self.processor = DataProcessor("bert-base-uncased")

    def test_tokenization(self):
        text = "This is a test sentence."
        tokenized = self.processor.tokenizer(text, truncation=True, padding="max_length", max_length=128)
        assert "input_ids" in tokenized
        assert "attention_mask" in tokenized
        assert len(tokenized["input_ids"]) == 128

    # Tokenizer calls can exceed Hypothesis' default per-example deadline, so disable it
    @settings(max_examples=25, deadline=None)
    @given(st.text(min_size=1, max_size=1000), st.integers(min_value=0, max_value=1))
    def test_preprocess_function(self, text, label):
        examples = {"text": [text], "label": [label]}
        result = self.processor.preprocess_function(examples)
        assert "input_ids" in result
        assert "attention_mask" in result
        assert "labels" in result
        assert len(result["labels"]) == 1
        assert result["labels"][0] == label
        # Every input is truncated or padded to exactly max_length tokens
        assert len(result["input_ids"][0]) == self.processor.max_length


class TestModel:
    def setup_method(self):
        self.model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
        self.model.eval()

    def test_model_forward(self):
        batch_size = 2
        seq_length = 128
        input_ids = torch.randint(0, 1000, (batch_size, seq_length))
        attention_mask = torch.ones((batch_size, seq_length))
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        assert outputs.logits.shape == (batch_size, 2)
        assert not outputs.logits.requires_grad  # no_grad() disables gradient tracking
Integration Tests
from concurrent.futures import ThreadPoolExecutor

import pytest
from fastapi.testclient import TestClient


@pytest.fixture(scope="module")
def client():
    from src.api.app import app
    # Using TestClient as a context manager runs the startup/shutdown events,
    # which is where the models are loaded
    with TestClient(app) as test_client:
        yield test_client


class TestAPI:
    def test_health_endpoint(self, client):
        response = client.get("/health")
        assert response.status_code == 200
        # /health also returns a timestamp and model info, so check only the status
        assert response.json()["status"] == "healthy"

    def test_predict_endpoint(self, client):
        test_data = {"text": "This movie was absolutely fantastic!", "model_version": "latest"}
        response = client.post("/predict", json=test_data)
        assert response.status_code == 200
        result = response.json()
        assert "prediction" in result
        assert "confidence" in result
        assert "label" in result
        assert 0 <= result["confidence"] <= 1

    def test_concurrent_requests(self, client):
        # TestClient is synchronous, so issue the requests from a thread pool
        def make_request(i: int) -> int:
            # Distinct texts so that the response cache does not answer every request
            test_data = {"text": f"Test concurrent request {i}", "model_version": "latest"}
            return client.post("/predict", json=test_data).status_code

        with ThreadPoolExecutor(max_workers=10) as pool:
            results = list(pool.map(make_request, range(10)))
        assert all(status == 200 for status in results)
Model Deployment and API Service
FastAPI Service Implementation
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, validator
from typing import List, Optional, Dict, Any
import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging
from datetime import datetime
import pickle
import hashlib
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class PredictionRequest(BaseModel):
text: str = Field(..., min_length=1, max_length=5000)
model_version: Optional[str] = "latest"
@validator('text')
def text_not_empty(cls, v):
if not v.strip():
raise ValueError('Text cannot be empty')
return v.strip()
class BatchPredictionRequest(BaseModel):
texts: List[str] = Field(..., min_items=1, max_items=100)
model_version: Optional[str] = "latest"
@validator('texts')
def validate_texts(cls, v):
if not all(text.strip() for text in v):
raise ValueError('All texts must be non-empty')
return [text.strip() for text in v]
class PredictionResponse(BaseModel):
prediction: int
label: str
confidence: float
model_version: str
request_id: str
processing_time: float
class BatchPredictionResponse(BaseModel):
predictions: List[Dict[str, Any]]
batch_id: str
total_processed: int
processing_time: float
class ModelManager:
def __init__(self, model_path: str = "./models"):
self.model_path = model_path
self.models = {}
self.active_model = None
self.model_versions = []
self.executor = ThreadPoolExecutor(max_workers=4)
self.load_models()
def load_models(self):
import os
import glob
model_dirs = glob.glob(os.path.join(self.model_path, "*"))
for model_dir in model_dirs:
if os.path.isdir(model_dir):
try:
model_name = os.path.basename(model_dir)
model = BertForSequenceClassification.from_pretrained(model_dir)
tokenizer = BertTokenizer.from_pretrained(model_dir)
self.models[model_name] = {
"model": model, "tokenizer": tokenizer,
"loaded_at": datetime.now(),
"stats": {"total_predictions": 0, "avg_response_time": 0}
}
self.model_versions.append(model_name)
if self.active_model is None:
self.active_model = model_name
logger.info(f"Loaded model: {model_name}")
except Exception as e:
logger.error(f"Failed to load model {model_dir}: {e}")
if not self.models:
logger.warning("No models found. Loading default model...")
self.load_default_model()
def load_default_model(self):
try:
model_name = "bert-base-uncased"
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
tokenizer = BertTokenizer.from_pretrained(model_name)
self.models[model_name] = {
"model": model, "tokenizer": tokenizer,
"loaded_at": datetime.now(),
"stats": {"total_predictions": 0, "avg_response_time": 0}
}
self.active_model = model_name
self.model_versions.append(model_name)
logger.info(f"Loaded default model: {model_name}")
except Exception as e:
logger.error(f"Failed to load default model: {e}")
raise
async def predict_async(self, text: str, model_version: str = "latest") -> Dict[str, Any]:
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(self.executor, self.predict_sync, text, model_version)
return result
def predict_sync(self, text: str, model_version: str = "latest") -> Dict[str, Any]:
start_time = datetime.now()
model_key = model_version if model_version != "latest" else self.active_model
if model_key not in self.models:
raise ValueError(f"Model {model_key} not found")
model_info = self.models[model_key]
model = model_info["model"]
tokenizer = model_info["tokenizer"]
inputs = tokenizer(text, truncation=True, padding=True, max_length=512, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
probs = torch.softmax(logits, dim=-1)
prediction = torch.argmax(probs, dim=-1).item()
confidence = probs[0][prediction].item()
end_time = datetime.now()
processing_time = (end_time - start_time).total_seconds()
model_info["stats"]["total_predictions"] += 1
current_avg = model_info["stats"]["avg_response_time"]
total_preds = model_info["stats"]["total_predictions"]
model_info["stats"]["avg_response_time"] = ((current_avg * (total_preds - 1) + processing_time) / total_preds)
return {
"prediction": prediction,
"label": "Positive" if prediction == 1 else "Negative",
"confidence": confidence,
"model_version": model_key,
"processing_time": processing_time
}
app = FastAPI(title="BERT Text Classification API", description="API for sentiment analysis using BERT model", version="1.0.0", docs_url="/docs", redoc_url="/redoc")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
model_manager = None
request_counter = 0
cache = {}
startup_time = datetime.now()
def generate_request_id(text: str) -> str:
global request_counter
request_counter += 1
text_hash = hashlib.md5(text.encode()).hexdigest()[:8]
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
return f"{timestamp}_{request_counter}_{text_hash}"
def cache_predictions(func):
async def wrapper(text: str, model_version: str = "latest", *args, **kwargs):
cache_key = f"{model_version}:{hashlib.md5(text.encode()).hexdigest()}"
if cache_key in cache:
cached_time, result = cache[cache_key]
if (datetime.now() - cached_time).total_seconds() < 300:
logger.info(f"Cache hit for key: {cache_key[:20]}...")
return result
result = await func(text, model_version, *args, **kwargs)
cache[cache_key] = (datetime.now(), result)
if len(cache) > 1000:
oldest_key = min(cache.keys(), key=lambda k: cache[k][0])
del cache[oldest_key]
return result
return wrapper
@app.on_event("startup")
async def startup_event():
global model_manager
logger.info("Starting up BERT Classification API...")
model_manager = ModelManager("./models")
logger.info(f"Loaded {len(model_manager.models)} models")
logger.info(f"Active model: {model_manager.active_model}")
@app.on_event("shutdown")
async def shutdown_event():
logger.info("Shutting down BERT Classification API...")
if model_manager:
model_manager.executor.shutdown()
@app.get("/")
async def root():
return {"message": "BERT Text Classification API", "version": "1.0.0", "docs": "/docs", "health": "/health", "models": "/models"}
@app.get("/health")
async def health_check():
return {"status": "healthy", "timestamp": datetime.now().isoformat(), "models_loaded": len(model_manager.models) if model_manager else 0, "active_model": model_manager.active_model if model_manager else None}
@app.post("/predict", response_model=PredictionResponse)
@cache_predictions
async def predict(request: PredictionRequest, background_tasks: BackgroundTasks):
try:
start_time = datetime.now()
request_id = generate_request_id(request.text)
result = await model_manager.predict_async(request.text, request.model_version)
processing_time = (datetime.now() - start_time).total_seconds()
background_tasks.add_task(log_prediction, request_id=request_id, text_length=len(request.text), prediction=result["prediction"], confidence=result["confidence"], processing_time=processing_time)
return PredictionResponse(
prediction=result["prediction"], label=result["label"], confidence=result["confidence"], model_version=result["model_version"], request_id=request_id, processing_time=processing_time
)
except Exception as e:
logger.error(f"Prediction error: {e}")
raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
@app.post("/predict/batch", response_model=BatchPredictionResponse)
async def predict_batch(request: BatchPredictionRequest):
try:
start_time = datetime.now()
batch_id = hashlib.md5("".join(request.texts).encode()).hexdigest()[:12]
results = await model_manager.predict_batch_async(request.texts, request.model_version)
processing_time = (datetime.now() - start_time).total_seconds()
predictions = []
for i, (text, result) in enumerate(zip(request.texts, results)):
request_id = generate_request_id(text)
predictions.append({"text_preview": text[:100] + "..." if len(text) > 100 else text, "prediction": result["prediction"], "label": result["label"], "confidence": result["confidence"], "request_id": request_id})
return BatchPredictionResponse(predictions=predictions, batch_id=batch_id, total_processed=len(results), processing_time=processing_time)
except Exception as e:
logger.error(f"Batch prediction error: {e}")
        raise HTTPException(status_code=500, detail=f"Batch prediction failed: {str(e)}")
@app.get("/models")
async def get_models():
    if not model_manager:
        raise HTTPException(status_code=500, detail="Model manager not initialized")
    stats = model_manager.get_model_stats()
    return {"available_models": list(model_manager.models.keys()), "active_model": model_manager.active_model, "model_stats": stats}
@app.post("/models/switch")
async def switch_model(model_version: str):
    # model_version is read from the query string, e.g. POST /models/switch?model_version=v2
    if not model_manager:
        raise HTTPException(status_code=500, detail="Model manager not initialized")
    success = model_manager.switch_active_model(model_version)
    if success:
        return {"message": f"Switched active model to {model_version}", "active_model": model_manager.active_model}
    raise HTTPException(status_code=400, detail=f"Model {model_version} not found")
@app.get("/metrics")
async def get_metrics():
    # Only reads module-level state, so no `global` statement is needed
    if not model_manager:
        raise HTTPException(status_code=500, detail="Model manager not initialized")
    models = model_manager.models.values()
    total_predictions = sum(info["stats"]["total_predictions"] for info in models)
    avg_response_times = [info["stats"]["avg_response_time"] for info in models]
    # Unweighted mean of the per-model averages
    avg_response_time = sum(avg_response_times) / len(avg_response_times) if avg_response_times else 0
    return {
        "total_predictions": total_predictions,
        "total_requests": request_counter,
        "cache_size": len(cache),
        "cache_hit_rate": 0.0,  # placeholder: hits and misses are not tracked in this snippet
        "average_response_time": avg_response_time,
        "models_loaded": len(model_manager.models),
        "uptime": (datetime.now() - startup_time).total_seconds(),
    }
@app.get("/predictions/history")
async def get_prediction_history(limit: int = 100, model_version: Optional[str] = None):
    # Stub: prediction history is not persisted in this example
    return {"message": "Prediction history endpoint", "limit": limit, "model_version": model_version}
async def log_prediction(request_id: str, text_length: int, prediction: int, confidence: float, processing_time: float):
    log_entry = {"timestamp": datetime.now().isoformat(), "request_id": request_id, "text_length": text_length, "prediction": prediction, "confidence": confidence, "processing_time": processing_time}
    logger.info(f"Prediction logged: {log_entry}")
if __name__ == "__main__":
    import uvicorn
    # reload=True only works when the application is passed as an import string
    uvicorn.run("src.api.app:app", host="0.0.0.0", port=8000, log_level="info", reload=True)
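The `/metrics` handler above averages the per-model `avg_response_time` values without weighting, so a model that served ten requests counts as much as one that served ten thousand. If a traffic-weighted figure is preferred, a minimal sketch follows; it assumes each entry in `model_manager.models` has the `{"stats": {"total_predictions": ..., "avg_response_time": ...}}` shape used above, and `weighted_avg_response_time` is a hypothetical helper name.

```python
from typing import Any, Dict, Mapping


def weighted_avg_response_time(models: Mapping[str, Dict[str, Any]]) -> float:
    """Average response time across models, weighted by each model's prediction count.

    Hypothetical helper: assumes every value looks like
    {"stats": {"total_predictions": int, "avg_response_time": float}}.
    """
    total_predictions = 0
    weighted_sum = 0.0
    for info in models.values():
        stats = info["stats"]
        count = stats["total_predictions"]
        total_predictions += count
        weighted_sum += stats["avg_response_time"] * count
    # Avoid division by zero before any prediction has been served
    return weighted_sum / total_predictions if total_predictions else 0.0


if __name__ == "__main__":
    example = {
        "v1": {"stats": {"total_predictions": 900, "avg_response_time": 0.02}},
        "v2": {"stats": {"total_predictions": 100, "avg_response_time": 0.20}},
    }
    print(round(weighted_avg_response_time(example), 4))  # 0.038
```

With the example data, the unweighted mean would be 0.11 s, while the weighted average of 0.038 s is closer to what most requests actually experienced.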
Docker Deployment Configuration
# Dockerfile
FROM python:3.9-slim
WORKDIR /app
RUN apt-get update && apt-get install -y build-essential curl && rm -rf /var/lib/apt/lists/*
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
RUN mkdir -p models
EXPOSE 8000
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 CMD curl -f http://localhost:8000/health || exit 1
CMD ["uvicorn", "src.api.app:app", "--host", "0.0.0.0", "--port", "8000"]
# docker-compose.yml
version: '3.8'
services:
  api:
    build: .
    ports:
      - "8000:8000"
    volumes:
      - ./models:/app/models
      - ./logs:/app/logs
    environment:
      - CUDA_VISIBLE_DEVICES=0
      - MODEL_PATH=/app/models
      - LOG_LEVEL=INFO
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
    restart: unless-stopped
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 40s
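The Dockerfile installs `curl` only so that `HEALTHCHECK` can probe `/health`. Because the image already ships Python, a standard-library probe is one alternative. The sketch below assumes the API listens on `http://localhost:8000/health` as configured above; `check_health` and `main` are hypothetical names.

```python
import urllib.error
import urllib.request

DEFAULT_URL = "http://localhost:8000/health"


def check_health(url: str = DEFAULT_URL, timeout: float = 5.0) -> bool:
    """Return True if the endpoint answers with HTTP 200 within `timeout` seconds."""
    # Ignore any HTTP(S)_PROXY settings: the probe targets the local container
    opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
    try:
        with opener.open(url, timeout=timeout) as response:
            return response.status == 200
    except (urllib.error.URLError, OSError):
        # Connection refused, timeout, or a non-2xx status (HTTPError is a URLError)
        return False


def main(url: str = DEFAULT_URL) -> int:
    """Exit code for Docker: 0 means healthy, 1 means unhealthy."""
    return 0 if check_health(url) else 1
```

Saved as a script (the file name `healthcheck.py` is hypothetical) that ends with `raise SystemExit(main())`, it could be used as `HEALTHCHECK CMD python healthcheck.py`, which would let the Dockerfile skip installing `curl`.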
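Monitoring and Logging
Structured Logging Configuration
In production, JSON-formatted logs are recommended so that log aggregation systems can parse and query individual fields.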
import logging
import logging.config
import os
from datetime import datetime, timezone
# Requires the third-party python-json-logger package (not listed in requirements.txt)
from pythonjsonlogger import jsonlogger
class CustomJsonFormatter(jsonlogger.JsonFormatter):
    def add_fields(self, log_record, record, message_dict):
        super().add_fields(log_record, record, message_dict)
        if not log_record.get('timestamp'):
            # Timezone-aware UTC timestamp (datetime.utcnow() is deprecated since Python 3.12)
            log_record['timestamp'] = datetime.now(timezone.utc).isoformat()
        if log_record.get('level'):
            log_record['level'] = log_record['level'].upper()
        else:
            log_record['level'] = record.levelname
        # Service name and call-site details attached to every record
        log_record['service'] = 'bert-classification-api'
        log_record['module'] = record.module
        log_record['function'] = record.funcName
        log_record['line'] = record.lineno
LOGGING_CONFIG = {
    'version': 1,
    'disable_existing_loggers': False,
    'formatters': {
        'json': {'()': CustomJsonFormatter, 'format': '%(timestamp)s %(level)s %(name)s %(message)s'},
        'simple': {'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s'}
    },
    'handlers': {
        'console': {'class': 'logging.StreamHandler', 'formatter': 'json', 'level': 'INFO'},
        # Rotate at 10 MiB, keeping 5 old files
        'file': {'class': 'logging.handlers.RotatingFileHandler', 'filename': 'logs/application.log', 'formatter': 'json', 'maxBytes': 10485760, 'backupCount': 5, 'level': 'INFO'},
        'error_file': {'class': 'logging.handlers.RotatingFileHandler', 'filename': 'logs/error.log', 'formatter': 'json', 'maxBytes': 10485760, 'backupCount': 5, 'level': 'ERROR'}
    },
    'loggers': {
        '': {'handlers': ['console', 'file', 'error_file'], 'level': 'INFO', 'propagate': True},
        'uvicorn': {'handlers': ['console', 'file'], 'level': 'INFO', 'propagate': False},
        'uvicorn.error': {'handlers': ['error_file'], 'level': 'ERROR', 'propagate': False}
    }
}
def setup_logging():
    # RotatingFileHandler does not create missing directories
    os.makedirs('logs', exist_ok=True)
    logging.config.dictConfig(LOGGING_CONFIG)
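`CustomJsonFormatter` depends on the third-party `python-json-logger` package. If that dependency is undesirable, a standard-library formatter can emit similar one-object-per-line JSON. The sketch below is an alternative rather than the configuration above; the class name `StdlibJsonFormatter` is hypothetical and its field set mirrors the one chosen in `add_fields`.

```python
import json
import logging
from datetime import datetime, timezone


class StdlibJsonFormatter(logging.Formatter):
    """Format each log record as a single JSON object (hypothetical helper)."""

    def __init__(self, service: str = "bert-classification-api"):
        super().__init__()
        self.service = service

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            # record.created is a POSIX timestamp; render it as UTC ISO 8601
            "timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
            "level": record.levelname,
            "name": record.name,
            "message": record.getMessage(),
            "service": self.service,
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno,
        }
        if record.exc_info:
            # Include the formatted traceback when logger.exception() is used
            payload["exc_info"] = self.formatException(record.exc_info)
        return json.dumps(payload, ensure_ascii=False)
```

It can be referenced from `LOGGING_CONFIG` as `{'()': StdlibJsonFormatter}`, since `dictConfig` calls the factory with any remaining keys as keyword arguments.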
Performance Monitoring
import time
from functools import wraps
from fastapi import Response
# Requires the prometheus_client package (not listed in requirements.txt)
from prometheus_client import CONTENT_TYPE_LATEST, CollectorRegistry, Counter, Gauge, Histogram, generate_latest
# Assumes this code lives in src/api/app.py, where `app` and `model_manager` are defined
registry = CollectorRegistry()
PREDICTION_REQUESTS = Counter('prediction_requests_total', 'Total number of prediction requests', ['model_version', 'endpoint'], registry=registry)
PREDICTION_LATENCY = Histogram('prediction_latency_seconds', 'Prediction latency in seconds', ['model_version', 'endpoint'], buckets=(0.01, 0.05, 0.1, 0.5, 1.0, 5.0), registry=registry)
ACTIVE_MODELS = Gauge('active_models_total', 'Number of active models', registry=registry)
MODEL_LOAD_TIME = Histogram('model_load_time_seconds', 'Model loading time in seconds', ['model_name'], registry=registry)
CACHE_HITS = Counter('cache_hits_total', 'Total number of cache hits', registry=registry)
CACHE_MISSES = Counter('cache_misses_total', 'Total number of cache misses', registry=registry)
ERROR_COUNT = Counter('prediction_errors_total', 'Total number of prediction errors', ['error_type', 'model_version'], registry=registry)
def monitor_predictions(func):
    """Count requests, record latency, and count errors for an async endpoint."""
    @wraps(func)
    async def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        # Falls back to 'latest' when model_version is not passed as a keyword argument
        model_version = kwargs.get('model_version', 'latest')
        endpoint = func.__name__
        PREDICTION_REQUESTS.labels(model_version=model_version, endpoint=endpoint).inc()
        try:
            result = await func(*args, **kwargs)
            latency = time.perf_counter() - start_time
            PREDICTION_LATENCY.labels(model_version=model_version, endpoint=endpoint).observe(latency)
            return result
        except Exception as e:
            ERROR_COUNT.labels(error_type=type(e).__name__, model_version=model_version).inc()
            raise
    return wrapper
def update_model_metrics(model_manager):
    ACTIVE_MODELS.set(len(model_manager.models))
    # Per-model metrics (e.g. prediction counts) could also be exported here
# The JSON endpoint defined earlier already serves GET /metrics, and two routes on the
# same path would shadow each other; expose the Prometheus text format on its own path
# (set metrics_path in the Prometheus scrape config accordingly)
@app.get("/metrics/prometheus")
async def metrics_endpoint():
    if model_manager:
        update_model_metrics(model_manager)
    return Response(generate_latest(registry), media_type=CONTENT_TYPE_LATEST)
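`PREDICTION_LATENCY` is a Prometheus histogram, which stores cumulative counts per upper bound (`le`, "less than or equal") rather than individual samples. To make the `buckets=(0.01, 0.05, 0.1, 0.5, 1.0, 5.0)` setting concrete, the standard-library sketch below reproduces that bookkeeping for a list of latencies. It illustrates the data model only and is not how `prometheus_client` is implemented internally; the function name is hypothetical.

```python
import math
from typing import Dict, Iterable, Sequence

BUCKETS = (0.01, 0.05, 0.1, 0.5, 1.0, 5.0)


def cumulative_bucket_counts(latencies: Iterable[float], buckets: Sequence[float] = BUCKETS) -> Dict[float, int]:
    """Count observations <= each bucket bound, plus the implicit +Inf bucket."""
    bounds = list(buckets) + [math.inf]
    counts = {bound: 0 for bound in bounds}
    for value in latencies:
        for bound in bounds:
            if value <= bound:
                # Cumulative: a 0.03 s request is counted in 0.05, 0.1, ..., +Inf
                counts[bound] += 1
    return counts
```

PromQL queries such as `histogram_quantile(0.95, rate(prediction_latency_seconds_bucket[5m]))` interpolate within these buckets, so the bounds should bracket the latency targets you care about.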
Complete Test Execution Workflow
End-to-End Test Script
"""End-to-end test script"""
import sys
import os
import time
import json
from typing import Any, Dict
import numpy as np
import torch
# Make the project root importable when this script lives in scripts/
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.data.processor import DataProcessor
from src.training.trainer import CustomTrainer
from src.evaluation.metrics import ModelEvaluator
from src.api.app import app
from fastapi.testclient import TestClient
class EndToEndTest:
    def __init__(self, api_url: str = "http://localhost:8000"):
        self.api_url = api_url
        self.client = None  # set in run_all_tests()
        self.results = {}
    def run_all_tests(self) -> bool:
        print("=" * 60)
        print("End-to-end tests starting")
        print("=" * 60)
        tests = [
            self.test_environment,
            self.test_data_pipeline,
            self.test_model_training,
            self.test_model_evaluation,
            self.test_api_endpoints,
            self.test_performance,
            self.test_error_handling,
            self.test_concurrent_requests,
        ]
        # Using TestClient as a context manager runs the app's startup and shutdown
        # handlers, so the model manager is initialized before the API tests run
        with TestClient(app) as self.client:
            for test in tests:
                test_name = test.__name__
                try:
                    print(f"\nRunning test: {test_name}")
                    print("-" * 40)
                    result = test()
                    self.results[test_name] = {"status": "PASSED", "result": result}
                    print(f"✓ {test_name}: PASSED")
                except Exception as e:
                    # Include the exception type: a bare AssertionError has an empty message
                    self.results[test_name] = {"status": "FAILED", "error": f"{type(e).__name__}: {e}"}
                    print(f"✗ {test_name}: FAILED - {type(e).__name__}: {e}")
        # Return the overall verdict so main() can set the exit code
        return self.generate_report()
    def test_environment(self) -> Dict[str, Any]:
        python_version = sys.version_info
        assert python_version.major == 3 and python_version.minor >= 8
        import transformers
        return {
            "python_version": f"{python_version.major}.{python_version.minor}.{python_version.micro}",
            "torch_version": torch.__version__,
            "transformers_version": transformers.__version__,
            "cuda_available": torch.cuda.is_available(),
        }
    def test_data_pipeline(self) -> Dict[str, Any]:
        processor = DataProcessor("bert-base-uncased")
        test_text = "This is a test sentence for tokenization."
        tokenized = processor.tokenizer(test_text, truncation=True, padding="max_length", max_length=128)
        assert "input_ids" in tokenized
        assert len(tokenized["input_ids"]) == 128
        # Three dummy fixed-length samples to exercise collate_fn
        batch_labels = [0, 1, 0]
        batch = processor.collate_fn([{"input_ids": torch.tensor([101] * 128), "attention_mask": torch.tensor([1] * 128), "labels": label} for label in batch_labels])
        assert batch["input_ids"].shape[0] == 3
        assert batch["labels"].shape[0] == 3
        return {"tokenization_test": "PASSED", "batch_processing_test": "PASSED"}
    def test_model_training(self) -> Dict[str, Any]:
        from transformers import BertForSequenceClassification
        from src.training.train_config import TrainingConfig
        model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
        # Synthetic data: 100 random token sequences of length 128 with binary labels
        train_data = torch.utils.data.TensorDataset(torch.randint(0, 1000, (100, 128)), torch.ones((100, 128), dtype=torch.long), torch.randint(0, 2, (100,)))
        train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=16)
        val_dataloader = torch.utils.data.DataLoader(train_data, batch_size=16)
        config = TrainingConfig(batch_size=16, num_epochs=1, learning_rate=1e-5)
        trainer = CustomTrainer(model=model, train_config=config, train_dataloader=train_dataloader, val_dataloader=val_dataloader)
        initial_metrics = trainer.evaluate()
        trainer.train_epoch(0)
        final_metrics = trainer.evaluate()
        # Sanity check only: one epoch on random data should not blow up the loss
        assert final_metrics["loss"] <= initial_metrics["loss"] * 1.5
        return {"initial_loss": initial_metrics["loss"], "final_loss": final_metrics["loss"], "training_completed": True}
    def test_model_evaluation(self) -> Dict[str, Any]:
        from transformers import BertForSequenceClassification, BertTokenizer
        model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        evaluator = ModelEvaluator(model, tokenizer, device="cpu")
        test_texts = ["This is excellent!", "I don't like this at all.", "It's okay, nothing special."]
        test_labels = [1, 0, 0]
        results = evaluator.evaluate_classification(test_texts, test_labels)
        assert "accuracy" in results["metrics"]
        assert "confusion_matrix" in results
        return {"evaluation_metrics": results["metrics"], "predictions_made": len(results["predictions"])}
    def test_api_endpoints(self) -> Dict[str, Any]:
        response = self.client.get("/health")
        assert response.status_code == 200
        health_data = response.json()
        assert health_data["status"] == "healthy"
        test_data = {"text": "This movie was absolutely fantastic! I loved every minute of it.", "model_version": "latest"}
        response = self.client.post("/predict", json=test_data)
        assert response.status_code == 200
        prediction_data = response.json()
        assert "prediction" in prediction_data
        assert "confidence" in prediction_data
        assert prediction_data["confidence"] >= 0
        batch_data = {"texts": ["Amazing film, highly recommended!", "Not my cup of tea, unfortunately.", "The acting was superb."]}
        response = self.client.post("/predict/batch", json=batch_data)
        assert response.status_code == 200
        batch_result = response.json()
        assert len(batch_result["predictions"]) == 3
        response = self.client.get("/models")
        assert response.status_code == 200
        models_data = response.json()
        assert "available_models" in models_data
        return {"health_check": "PASSED", "single_prediction": "PASSED", "batch_prediction": "PASSED", "model_management": "PASSED"}
    def test_performance(self) -> Dict[str, Any]:
        test_cases = ["Short text", "Medium length text " * 10, "Very long text " * 100]
        latencies = []
        for text in test_cases:
            start_time = time.perf_counter()
            response = self.client.post("/predict", json={"text": text})
            end_time = time.perf_counter()
            latency = end_time - start_time
            latencies.append(latency)
            assert response.status_code == 200
            assert latency < 5.0
        batch_texts = [f"Test text {i}" for i in range(10)]
        start_time = time.perf_counter()
        response = self.client.post("/predict/batch", json={"texts": batch_texts})
        end_time = time.perf_counter()
        batch_latency = end_time - start_time
        assert response.status_code == 200
        assert batch_latency < 10.0
        avg_latency_per_request = batch_latency / len(batch_texts)
        return {"single_request_latencies": latencies, "batch_latency": batch_latency, "avg_latency_per_request": avg_latency_per_request, "throughput": len(batch_texts) / batch_latency}
    def test_error_handling(self) -> Dict[str, Any]:
        response = self.client.post("/predict", json={"text": ""})
        assert response.status_code == 422
        response = self.client.post("/predict", json={"text": "Test text", "model_version": "non-existent-model"})
        assert response.status_code in [400, 500]
        response = self.client.post("/predict/batch", json={"texts": []})
        assert response.status_code == 422
        return {"empty_text_handling": "PASSED", "invalid_model_handling": "PASSED", "empty_batch_handling": "PASSED"}
    def test_concurrent_requests(self) -> Dict[str, Any]:
        import concurrent.futures
        test_texts = [f"Concurrent test {i}" for i in range(20)]
        def make_request(text):
            start_time = time.perf_counter()
            response = self.client.post("/predict", json={"text": text})
            end_time = time.perf_counter()
            return {"status_code": response.status_code, "latency": end_time - start_time}
        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            futures = [executor.submit(make_request, text) for text in test_texts]
            results = [future.result() for future in concurrent.futures.as_completed(futures)]
        status_codes = [r["status_code"] for r in results]
        latencies = [r["latency"] for r in results]
        assert all(code == 200 for code in status_codes)
        avg_latency = np.mean(latencies)
        assert avg_latency < 2.0
        return {"total_requests": len(results), "success_rate": sum(1 for code in status_codes if code == 200) / len(results), "avg_latency": avg_latency, "max_latency": max(latencies)}
    def generate_report(self) -> bool:
        print("\n" + "=" * 60)
        print("Test report")
        print("=" * 60)
        total_tests = len(self.results)
        passed_tests = sum(1 for result in self.results.values() if result["status"] == "PASSED")
        print(f"Total tests: {total_tests}")
        print(f"Passed: {passed_tests}")
        print(f"Failed: {total_tests - passed_tests}")
        if total_tests - passed_tests > 0:
            print("\nFailure details:")
            for test_name, result in self.results.items():
                if result["status"] == "FAILED":
                    print(f"  {test_name}: {result['error']}")
        report_data = {
            "summary": {"total_tests": total_tests, "passed_tests": passed_tests, "failed_tests": total_tests - passed_tests, "success_rate": passed_tests / total_tests if total_tests > 0 else 0},
            "details": self.results,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        }
        with open("test_report.json", "w") as f:
            json.dump(report_data, f, indent=2, default=str)
        print("\nDetailed report saved to: test_report.json")
        if passed_tests == total_tests:
            print("\n✓ All tests passed!")
            return True
        print("\n✗ Some tests failed!")
        return False
def main():
    tester = EndToEndTest()
    try:
        success = tester.run_all_tests()
        if success:
            print("\n" + "=" * 60)
            print("End-to-end tests complete - all tests passed!")
            print("=" * 60)
            sys.exit(0)
        else:
            print("\n" + "=" * 60)
            print("End-to-end tests complete - some tests failed!")
            print("=" * 60)
            sys.exit(1)
    except Exception as e:
        print(f"\nError while running tests: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
if __name__ == "__main__":
    main()
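The performance and concurrency tests gate on mean latency, whereas SLA targets are usually stated as percentiles such as p95. A minimal nearest-rank percentile helper that could summarize the collected `latencies` lists is sketched below; it is not part of the script above, and both function names are hypothetical.

```python
import math
from typing import Dict, Sequence


def percentile(values: Sequence[float], pct: float) -> float:
    """Nearest-rank percentile: smallest sample with at least pct% of samples at or below it."""
    if not values:
        raise ValueError("percentile() requires at least one value")
    if not 0 < pct <= 100:
        raise ValueError("pct must be in (0, 100]")
    ordered = sorted(values)
    # 1-based rank; computing pct * n / 100 keeps integer inputs exact
    rank = math.ceil(pct * len(ordered) / 100)
    return ordered[rank - 1]


def latency_summary(latencies: Sequence[float]) -> Dict[str, float]:
    """Percentile summary suitable for comparing against an SLA."""
    return {
        "p50": percentile(latencies, 50),
        "p95": percentile(latencies, 95),
        "p99": percentile(latencies, 99),
        "max": max(latencies),
    }
```

Nearest-rank always returns an observed sample, which keeps small test runs easy to reason about; with only 20 concurrent requests, p95 and p99 both land on the slowest one or two requests.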
Automated Test Pipeline (GitHub Actions)
# .github/workflows/ci.yml (example file name)
name: CI/CD Pipeline
on:
  push:
    branches: [main, develop]
  pull_request:
    branches: [main]
jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Quoted so that YAML does not parse 3.10 as the number 3.1
        python-version: ["3.8", "3.9", "3.10"]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
          pip install pytest pytest-cov flake8 mypy
      - name: Lint with flake8
        run: |
          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
      - name: Type check with mypy
        run: |
          mypy src --ignore-missing-imports
      - name: Test with pytest
        run: |
          pytest tests/ -v --cov=src --cov-report=xml
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3
        with:
          file: ./coverage.xml
          flags: unittests
          name: codecov-umbrella
  docker:
    needs: test
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Build Docker image
        run: |
          docker build -t bert-classification-api:latest .
      - name: Run Docker container
        run: |
          docker run -d -p 8000:8000 --name test-api bert-classification-api:latest
          # Poll /health instead of a fixed sleep; model loading can take longer than 10 s
          for i in $(seq 1 30); do
            curl -sf http://localhost:8000/health && break
            sleep 5
          done
      - name: Test Docker container
        run: |
          curl -f http://localhost:8000/health || exit 1
      - name: Cleanup
        # Remove the container even when an earlier step failed
        if: always()
        run: |
          docker stop test-api
          docker rm test-api
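Optimization and Best Practices
Model Optimization Techniques
In production, exporting the model to ONNX and quantizing it can substantially reduce inference latency and model size, especially for CPU inference; the actual speedup depends on the hardware and should be confirmed with benchmarks.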
import onnxruntime as ort
from optimum.onnxruntime import ORTModelForSequenceClassification, ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig
class ModelOptimizer:
    def __init__(self, model_path: str):
        self.model_path = model_path
    def convert_to_onnx(self, output_path: str = "./models/onnx") -> str:
        # export=True converts the PyTorch checkpoint to ONNX (it replaces the deprecated from_transformers flag)
        model = ORTModelForSequenceClassification.from_pretrained(self.model_path, export=True)
        model.save_pretrained(output_path)
        return output_path
    def quantize_model(self, model_path: str, quantization_config: str = "avx512_vnni") -> str:
        # Dynamic quantization of an exported ONNX model; quantization_config names an
        # AutoQuantizationConfig preset such as "avx512_vnni", "avx2" or "arm64"
        quantizer = ORTQuantizer.from_pretrained(model_path)
        qconfig = getattr(AutoQuantizationConfig, quantization_config)(is_static=False, per_channel=False)
        save_dir = f"{model_path}_quantized"
        quantizer.quantize(save_dir=save_dir, quantization_config=qconfig)
        return save_dir
    def optimize_with_onnxruntime(self, model_path: str, file_name: str = "model.onnx"):
        sess_options = ort.SessionOptions()
        sess_options.intra_op_num_threads = 4  # threads used inside a single operator
        sess_options.inter_op_num_threads = 2  # only used with the parallel execution mode
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.enable_cpu_mem_arena = True
        sess_options.enable_mem_pattern = True
        # Quantized exports are typically saved as model_quantized.onnx; pass file_name accordingly
        session = ort.InferenceSession(f"{model_path}/{file_name}", sess_options=sess_options, providers=['CPUExecutionProvider'])
        return session
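Quantization stores weights as 8-bit integers plus a floating-point scale, which is where both the size reduction and the small accuracy loss come from. The plain-Python sketch below shows the arithmetic of symmetric per-tensor int8 quantization, one of the schemes these configs can select; it is a simplified illustration, not ONNX Runtime's actual kernels, and the function names are hypothetical.

```python
from typing import List, Sequence, Tuple


def quantize_symmetric_int8(values: Sequence[float]) -> Tuple[List[int], float]:
    """Map floats to int8 in [-127, 127] with a single per-tensor scale."""
    max_abs = max((abs(v) for v in values), default=0.0)
    # The largest magnitude maps to +/-127; an all-zero tensor gets a dummy scale
    scale = max_abs / 127 if max_abs > 0 else 1.0
    quantized = [max(-127, min(127, round(v / scale))) for v in values]
    return quantized, scale


def dequantize(quantized: Sequence[int], scale: float) -> List[float]:
    """Recover approximate floats; each error is at most scale / 2."""
    return [q * scale for q in quantized]
```

For example, weights whose largest magnitude is 1.27 get a scale of 0.01, so every reconstructed weight is within 0.005 of the original. Values much smaller than the scale (such as 0.003) collapse to 0, which is why `per_channel=True` can preserve more accuracy when channel ranges differ widely.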
Caching Strategy
import hashlib
from datetime import datetime, timedelta
from typing import Any, Dict, Optional
class PredictionCache:
    """In-memory prediction cache with a TTL and least-recently-used eviction (not thread-safe)."""
    def __init__(self, max_size: int = 10000, ttl: int = 3600):
        self.max_size = max_size
        self.ttl = timedelta(seconds=ttl)
        self.cache = {}
        self.access_times = {}
        self.hits = 0
        self.misses = 0
    def _generate_key(self, text: str, model_version: str) -> str:
        # MD5 serves only as a compact cache key here, not for security
        content = f"{model_version}:{text}"
        return hashlib.md5(content.encode()).hexdigest()
    def get(self, text: str, model_version: str) -> Optional[Any]:
        key = self._generate_key(text, model_version)
        if key in self.cache:
            cached_time, value = self.cache[key]
            if datetime.now() - cached_time < self.ttl:
                self.access_times[key] = datetime.now()
                self.hits += 1
                return value
            # Expired entry: drop it so it no longer occupies space
            del self.cache[key]
            del self.access_times[key]
        self.misses += 1
        return None
    def set(self, text: str, model_version: str, value: Any) -> None:
        key = self._generate_key(text, model_version)
        # Evict only when inserting a new key into a full cache
        if key not in self.cache and len(self.cache) >= self.max_size:
            oldest_key = min(self.access_times, key=self.access_times.get)
            del self.cache[oldest_key]
            del self.access_times[oldest_key]
        now = datetime.now()
        self.cache[key] = (now, value)
        self.access_times[key] = now
    def clear_expired(self) -> None:
        now = datetime.now()
        expired_keys = [key for key, (cached_time, _) in self.cache.items() if now - cached_time >= self.ttl]
        for key in expired_keys:
            del self.cache[key]
            del self.access_times[key]
    def get_stats(self) -> Dict[str, Any]:
        total = self.hits + self.misses
        hit_rate = self.hits / total if total > 0 else 0
        return {
            "size": len(self.cache),
            "hits": self.hits,
            "misses": self.misses,
            "hit_rate": hit_rate,
            "max_size": self.max_size,
            "ttl_seconds": self.ttl.total_seconds()
        }
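`PredictionCache.set` locates the least recently used entry with `min()` over all access times, which costs O(n) per eviction once the cache is full. The sketch below applies the same LRU-plus-TTL policy on `collections.OrderedDict`, where refresh and eviction are O(1). The class name `LRUTTLCache` and the injectable `clock` parameter are hypothetical; `clock` exists only to make expiry testable. Like the original, it is not thread-safe.

```python
import time
from collections import OrderedDict
from typing import Any, Callable, Hashable, Optional


class LRUTTLCache:
    """LRU cache whose entries also expire `ttl` seconds after being written."""

    def __init__(self, max_size: int = 10000, ttl: float = 3600.0,
                 clock: Callable[[], float] = time.monotonic):
        self.max_size = max_size
        self.ttl = ttl
        self.clock = clock
        self._data = OrderedDict()  # key -> (written_at, value), oldest first
        self.hits = 0
        self.misses = 0

    def get(self, key: Hashable) -> Optional[Any]:
        entry = self._data.get(key)
        if entry is not None:
            written_at, value = entry
            if self.clock() - written_at < self.ttl:
                self._data.move_to_end(key)  # mark as most recently used
                self.hits += 1
                return value
            del self._data[key]  # expired: drop it eagerly
        self.misses += 1
        return None

    def set(self, key: Hashable, value: Any) -> None:
        if key in self._data:
            self._data.move_to_end(key)
        elif len(self._data) >= self.max_size:
            self._data.popitem(last=False)  # evict the least recently used entry
        self._data[key] = (self.clock(), value)

    def __len__(self) -> int:
        return len(self._data)
```

In the API, the key would typically be the same `model_version:text` digest that `_generate_key` produces, so different model versions never share cached predictions.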
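Summary and Extensions
Key Takeaways
- End-to-end AI integration workflow: hands-on experience with every stage, from model selection to production deployment.
- Engineering best practices: test-driven development, continuous integration, and monitoring with alerting.
- Performance optimization techniques: model quantization, caching strategies, and parallel processing.
- Scalable architecture: a system design that supports multiple models, multiple versions, and high concurrency.
- Comprehensive quality assurance: full coverage with unit, integration, and end-to-end tests.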
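Directions for Extension
- Multimodal AI integration: incorporate image, audio, and other multimodal models.
- Model version management: implement advanced deployment strategies such as A/B testing and canary releases.
- Autoscaling: adjust resources automatically based on traffic forecasts.
- Federated learning support: distributed training across edge devices.
- Explainable AI: add model explanation and visualization features.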
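One common way to implement the canary releases mentioned under model version management is deterministic, hash-based traffic splitting: each request key (for example a user ID) always maps to the same model version, and a configurable fraction of keys goes to the candidate. A minimal sketch follows; the function name, the version labels, and the 10% default are illustrative assumptions, not part of the service above.

```python
import hashlib


def choose_model_version(request_key: str, canary_version: str, stable_version: str,
                         canary_fraction: float = 0.1) -> str:
    """Route a stable fraction of request keys to the canary model version."""
    if not 0.0 <= canary_fraction <= 1.0:
        raise ValueError("canary_fraction must be between 0 and 1")
    digest = hashlib.sha256(request_key.encode("utf-8")).digest()
    # Interpret the first 8 bytes as an integer in [0, 2**64) and compare against an
    # integer threshold, which avoids float rounding at the edges (0.0 and 1.0)
    bucket = int.from_bytes(digest[:8], "big")
    threshold = int(canary_fraction * 2**64)
    return canary_version if bucket < threshold else stable_version
```

Because the assignment depends only on the key, a given user sees consistent predictions during the rollout, and raising `canary_fraction` from 0.1 to 0.5 keeps the original 10% of users on the canary while adding new ones.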
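Deployment Checklist
- All tests pass (unit, integration, and end-to-end)
- Performance benchmarks completed and SLA requirements met
- Monitoring and alerting configured
- Log collection and analysis in place
- Backup and recovery strategy defined
- Security audit and vulnerability scan completed
- Documentation and runbooks written
- Disaster recovery plan defined
A successful AI project is measured not only by model accuracy but by the overall quality of its systems engineering: maintainability, scalability, and reliability. The methodology and code framework presented here can be adapted to a wide range of AI projects and provide a solid foundation for AI application development.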