PyTorch 通过带有 Flask 的 REST API 在 Python 中部署 PyTorch
在机器学习项目的实际应用中,将训练好的模型部署为服务,使其能够接收外部请求并返回预测结果,是实现模型价值的关键一步。Flask 作为 Python 的轻量级 Web 框架,凭借其简洁易用的特性,成为部署 PyTorch 模型的理想选择之一。本文将详细指导您如何使用 Flask 将 PyTorch 模型部署为 REST API 服务,以预训练的 DenseNet 121 模型为例,实现图像分类功能。
一、环境搭建与依赖安装
在开始部署之前,确保已安装所需的依赖库。运行以下命令以安装 Flask 和 torchvision:
pip install Flask torchvision
二、创建简单的 Web 服务器
首先,我们创建一个基本的 Flask Web 服务器,后续将在此基础上添加模型推理功能。
from flask import Flask
app = Flask(__name__)
@app.route('/')
def hello():
return 'Hello World!'
if __name__ == '__main__':
app.run()
保存上述代码为 app.py
,运行 Flask 开发服务器:
FLASK_ENV=development FLASK_APP=app.py flask run
访问 http://localhost:5000/
,您将看到 "Hello World!" 文字,这表明服务器已成功启动。
三、定义 API 端点与推理逻辑
我们将定义一个 /predict
端点,用于接收包含图像文件的 HTTP POST 请求,并返回预测结果。
(一)图像预处理
DenseNet 121 模型要求输入图像为 224 x 224 的 3 通道 RGB 图像,且需进行归一化处理。我们使用 torchvision.transforms
构建图像预处理管道:
import io
import torchvision.transforms as transforms
from PIL import Image
def transform_image(image_bytes):
my_transforms = transforms.Compose([
transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image = Image.open(io.BytesIO(image_bytes))
return my_transforms(image).unsqueeze(0)
(二)加载预训练模型
加载预训练的 DenseNet 121 模型,并设置为评估模式:
from torchvision import models
model = models.densenet121(pretrained=True)
model.eval()
(三)图像分类预测
编写函数以获取图像的预测类别:
import json
imagenet_class_index = json.load(open('imagenet_class_index.json')) # 请替换为实际文件路径
def get_prediction(image_bytes):
tensor = transform_image(image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
predicted_idx = str(y_hat.item())
return imagenet_class_index[predicted_idx]
四、整合 Flask API 服务器
将模型推理功能整合到 Flask 服务器中,完成 API 的定义:
from flask import jsonify, request
@app.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
file = request.files['file']
img_bytes = file.read()
class_id, class_name = get_prediction(image_bytes=img_bytes)
return jsonify({'class_id': class_id, 'class_name': class_name})
if __name__ == '__main__':
app.run()
完整的 app.py
文件如下:
import io
import json
from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request
app = Flask(__name__)
imagenet_class_index = json.load(open('imagenet_class_index.json')) # 请替换为实际文件路径
model = models.densenet121(pretrained=True)
model.eval()
def transform_image(image_bytes):
my_transforms = transforms.Compose([
transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image = Image.open(io.BytesIO(image_bytes))
return my_transforms(image).unsqueeze(0)
def get_prediction(image_bytes):
tensor = transform_image(image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
predicted_idx = str(y_hat.item())
return imagenet_class_index[predicted_idx]
@app.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
file = request.files['file']
img_bytes = file.read()
class_id, class_name = get_prediction(image_bytes=img_bytes)
return jsonify({'class_id': class_id, 'class_name': class_name})
if __name__ == '__main__':
app.run()
五、测试与验证
运行 Flask 服务器:
FLASK_ENV=development FLASK_APP=app.py flask run
使用 requests
库发送 POST 请求进行测试:
import requests
resp = requests.post("http://localhost:5000/predict", files={"file": open('cat.jpg', 'rb')}) # 请替换为实际图像文件路径
print(resp.json())
成功返回结果示例:
{"class_id": "n02124075", "class_name": "Egyptian_cat"}
六、总结与优化建议
通过本文,您已成功使用 Flask 部署了一个 PyTorch 模型,并通过 REST API 提供图像分类服务。然而,当前的实现较为基础,对于生产环境,您可以考虑以下优化措施:
- 增强错误处理机制 :对请求中未包含图像文件、文件类型错误等情况进行处理,返回友好的错误提示信息。
- 提升性能与扩展性 :采用更高效的服务器架构(如 Gunicorn)替代开发服务器,结合模型量化、剪枝等技术优化模型性能,以适应高并发场景。
- 添加认证与授权 :为 API 添加身份验证,确保服务的安全性。
- 构建前端界面 :开发一个简单的 Web 界面,允许用户上传图像并显示预测结果,提升用户体验。
模型部署是连接模型开发与实际应用的桥梁,掌握这一技能,能够使您的模型真正发挥价值,解决实际问题。编程狮将持续为您带来更多模型部署与应用开发的实用教程。
更多建议: