PyTorch 通过带有 Flask 的 REST API 在 Python 中部署 PyTorch

2025-06-18 17:13 更新

在机器学习项目的实际应用中,将训练好的模型部署为服务,使其能够接收外部请求并返回预测结果,是实现模型价值的关键一步。Flask 作为 Python 的轻量级 Web 框架,凭借其简洁易用的特性,成为部署 PyTorch 模型的理想选择之一。本文将详细指导您如何使用 Flask 将 PyTorch 模型部署为 REST API 服务,以预训练的 DenseNet 121 模型为例,实现图像分类功能。

一、环境搭建与依赖安装

在开始部署之前,确保已安装所需的依赖库。运行以下命令以安装 Flask 和 torchvision:

pip install Flask torchvision

二、创建简单的 Web 服务器

首先,我们创建一个基本的 Flask Web 服务器,后续将在此基础上添加模型推理功能。

from flask import Flask


app = Flask(__name__)


@app.route('/')
def hello():
    return 'Hello World!'


if __name__ == '__main__':
    app.run()

保存上述代码为 app.py,运行 Flask 开发服务器:

FLASK_ENV=development FLASK_APP=app.py flask run

访问 http://localhost:5000/,您将看到 "Hello World!" 文字,这表明服务器已成功启动。

三、定义 API 端点与推理逻辑

我们将定义一个 /predict 端点,用于接收包含图像文件的 HTTP POST 请求,并返回预测结果。

(一)图像预处理

DenseNet 121 模型要求输入图像为 224 x 224 的 3 通道 RGB 图像,且需进行归一化处理。我们使用 torchvision.transforms 构建图像预处理管道:

import io
import torchvision.transforms as transforms
from PIL import Image


def transform_image(image_bytes):
    my_transforms = transforms.Compose([
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)

(二)加载预训练模型

加载预训练的 DenseNet 121 模型,并设置为评估模式:

from torchvision import models


model = models.densenet121(pretrained=True)
model.eval()

(三)图像分类预测

编写函数以获取图像的预测类别:

import json


imagenet_class_index = json.load(open('imagenet_class_index.json'))  # 请替换为实际文件路径


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

四、整合 Flask API 服务器

将模型推理功能整合到 Flask 服务器中,完成 API 的定义:

from flask import jsonify, request


@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':
    app.run()

完整的 app.py 文件如下:

import io
import json
from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request


app = Flask(__name__)


imagenet_class_index = json.load(open('imagenet_class_index.json'))  # 请替换为实际文件路径
model = models.densenet121(pretrained=True)
model.eval()


def transform_image(image_bytes):
    my_transforms = transforms.Compose([
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]


@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':
    app.run()

五、测试与验证

运行 Flask 服务器:

FLASK_ENV=development FLASK_APP=app.py flask run

使用 requests 库发送 POST 请求进行测试:

import requests


resp = requests.post("http://localhost:5000/predict", files={"file": open('cat.jpg', 'rb')})  # 请替换为实际图像文件路径
print(resp.json())

成功返回结果示例:

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

六、总结与优化建议

通过本文,您已成功使用 Flask 部署了一个 PyTorch 模型,并通过 REST API 提供图像分类服务。然而,当前的实现较为基础,对于生产环境,您可以考虑以下优化措施:

  • 增强错误处理机制 :对请求中未包含图像文件、文件类型错误等情况进行处理,返回友好的错误提示信息。
  • 提升性能与扩展性 :采用更高效的服务器架构(如 Gunicorn)替代开发服务器,结合模型量化、剪枝等技术优化模型性能,以适应高并发场景。
  • 添加认证与授权 :为 API 添加身份验证,确保服务的安全性。
  • 构建前端界面 :开发一个简单的 Web 界面,允许用户上传图像并显示预测结果,提升用户体验。

模型部署是连接模型开发与实际应用的桥梁,掌握这一技能,能够使您的模型真正发挥价值,解决实际问题。编程狮将持续为您带来更多模型部署与应用开发的实用教程。

以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号