基于 Flask 与 PyTorch 的图像分类 API 服务搭建
本文介绍如何基于 Flask 和 PyTorch 构建图像分类 API 服务。服务端负责加载预训练模型、接收图片上传、预处理及返回预测结果;客户端负责读取本地图片、发送 POST 请求并解析返回的 JSON 数据。通过该流程可实现训练好的模型通过 HTTP 接口对外提供预测服务。

本文介绍如何基于 Flask 和 PyTorch 构建图像分类 API 服务。服务端负责加载预训练模型、接收图片上传、预处理及返回预测结果;客户端负责读取本地图片、发送 POST 请求并解析返回的 JSON 数据。通过该流程可实现训练好的模型通过 HTTP 接口对外提供预测服务。

在深度学习项目落地时,将训练好的模型封装成可远程调用的 API 接口是核心环节。本文将完整讲解如何基于 Flask(轻量级 Web 框架)和 PyTorch(深度学习框架),实现图像分类服务端 API 开发 + 客户端调用的全流程,让训练好的模型通过 HTTP 接口对外提供预测服务。
服务端的核心职责是:加载预训练模型、接收客户端上传的图片、预处理图片、执行预测、返回 JSON 格式的预测结果。
import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models
# 初始化 Flask 应用
app = flask.Flask(__name__)
# 全局变量:模型实例、是否使用 GPU
model = None
use_gpu = False
# 测试阶段建议关闭 GPU,避免环境问题
def load_model():
"""加载预训练的 ResNet18 模型(可替换为自定义模型)"""
global model
# 1. 加载 ResNet18 主干网络
model = models.resnet18(pretrained=False)
# 不加载默认预训练权重
# 2. 修改全连接层,适配自定义分类任务(示例为 102 分类)
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(nn.Linear(num_ftrs, 102))
# 3. 加载训练好的权重文件
checkpoint = torch.load('best.pth', map_location='cpu')
# 强制使用 CPU,避免 GPU/CPU 不匹配
model.load_state_dict(checkpoint['state_dict'])
# 4. 设置模型为评估模式(禁用 Dropout/BatchNorm 更新)
model.eval()
# 5. 可选:使用 GPU(需确保环境有 CUDA)
if use_gpu and torch.cuda.is_available():
model.cuda()
def prepare_image(image, target_size=(224, 224)):
"""预处理图片:转为 RGB、resize、归一化、增加 batch 维度"""
# 1. 统一转为 RGB 格式(避免灰度图等格式异常)
if image.mode != 'RGB':
image = image.convert('RGB')
# 2. 图像变换:resize -> 转 Tensor -> 归一化
transform = transforms.Compose([
transforms.Resize(target_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # ImageNet 均值/方差
])
image = transform(image)
# 3. 增加 batch 维度(模型要求输入为 [batch, channel, h, w])
image = image.unsqueeze(0)
# 4. 可选:转到 GPU
if use_gpu and torch.cuda.is_available():
image = image.cuda()
return image
@app.route("/predict", methods=["POST"])
def predict():
"""API 核心接口:接收图片,返回预测结果"""
# 初始化返回数据
data = {"success": False}
# 仅处理 POST 请求
if flask.request.method == "POST":
# 检查是否有图片上传
if flask.request.files.get("image"):
try:
# 1. 读取并解析图片
image_bytes = flask.request.files["image"].read()
image = Image.open(io.BytesIO(image_bytes))
# 2. 预处理图片
image = prepare_image(image)
# 3. 模型预测(禁用梯度计算,提升速度)
with torch.no_grad():
preds = F.softmax(model(image), dim=1)
# 计算类别概率
# 4. 获取概率最高的前 3 个结果
top3_probs, top3_labels = torch.topk(preds.cpu(), k=3, dim=1)
# 5. 格式化结果
data['predictions'] = []
for prob, label in zip(top3_probs.numpy()[0], top3_labels.numpy()[0]):
data['predictions'].append({
"label": str(label),
"probability": float(prob)
})
# 6. 标记请求成功
data["success"] = True
except Exception as e:
print(f"预测出错:{str(e)}")
# 返回 JSON 格式结果(HTTP 响应)
return flask.jsonify(data)
if __name__ == '__main__':
print("加载 PyTorch 模型中...")
load_model()
# 启动前先加载模型
print("模型加载完成,启动 Flask 服务...")
# 启动服务:host=0.0.0.0 允许局域网访问,port 为自定义端口
app.run(host='0.0.0.0', port=5012, debug=False)
load_model()best.pth,设置 map_location='cpu' 避免 GPU/CPU 环境不匹配。model.eval() 将模型设为评估模式(禁用 Dropout、BatchNorm 等训练层)。prepare_image()predict()flask.request.files 读取客户端上传的二进制图片。torch.no_grad() 禁用梯度计算,大幅提升预测速度。torch.topk 获取概率最高的前 3 个类别,格式化后以 JSON 返回。客户端的核心职责是:读取本地图片、以二进制形式上传到服务端 API、解析返回的 JSON 结果并展示。
import requests
# 客户端的程序代码:功能;负责按照指定格式上传图片和接受结果
flask_url = 'http://127.0.0.1:5012/predict' # URL 和端口写成自己的本地 IP
def predict_result(image_path):
image = open(image_path, 'rb').read()
payload = {'image': image}
r = requests.post(flask_url, files=payload).json() # 网络传输逻辑
if r['success']:
# 成功的话再返回
# 输出结果
for i, result in enumerate(r['predictions']):
print('{}.预测类别为{}:的概率:{}'.format(i + 1, result['label'], result['probability']))
else:
print('Request failed')
if __name__ == "__main__":
predict_result('./flower_data/2.jpg')
requests 库:这是 Python 中处理 HTTP 请求的主流第三方库,不仅是爬虫工具,更是接口调用的核心工具,能轻松实现 POST/GET 等请求方式。127.0.0.1 是本地回环地址,对应运行 Flask 服务端的本机;5012 是服务端设置的端口号;/predict 是服务端开放的预测接口路径,必须与服务端的路由地址完全一致。predict_resultopen(image_path, 'rb') 以二进制只读模式打开图片,read() 将文件内容读取为二进制字节流 —— 网络传输图片必须用二进制格式,不能直接传文件路径或文本格式。payload,键名 image 是服务端约定的接收图片的参数名(需与服务端 flask.request.files.get("image") 中的 image 一致),值为读取的二进制图片数据。requests.post():向指定的 flask_url 发送 POST 请求,files 参数专门用于上传文件 / 二进制数据。.json():将服务端返回的 JSON 格式字符串自动解析为 Python 字典,方便后续取值。success 为 True 表示预测成功,predictions 是包含前 3 个高概率类别结果的列表。enumerate 为结果添加序号,通过 result['label'] 和 result['probability'] 分别获取类别标签和对应概率,格式化输出。success 为 False,则打印请求失败提示。服务端正常响应时,客户端控制台会输出:
1.预测类别为 5:的概率:0.9875643849372864
2.预测类别为 8:的概率:0.0089123456789012
3.预测类别为 12:的概率:0.0021098765432109

微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online