果蔬识别系统性能优化之路
目录
- 一级目录
- 二级目录
- 三级目录
- 前情提要
- 当前问题
- 优化方案
- 1. 内存化
- 2. 原生化
- 3. 接口化
- 行动
- 实现
- 结语
一级目录
二级目录
三级目录
前情提要
超详细前端AI蔬菜水果生鲜识别应用优化之路
当前问题
- indexddb在webview中确实性能有限,存储量上来后每次读取数据会有明显卡顿
- 目前的余弦相邻算法是基于所有特征向量数据进行的计算,一旦数据量大了后计算量也是一个十分消耗性能的点
- 由于机器性能问题,本地化加载模型在每次进入页面后需要消耗很长一段时间进行模型的加载,体验相当不好
优化方案
1. 内存化
方式:为了解决indexddb读取速度的问题,最直接的方式就是把数据放内存对象中,提前将indexddb的数据取出来,然后在数据变化时进行同步
结果:确实省去了读取这一步会快很多,但是仍有隐患,webview分配的内存是否这种占用内存的方式
2. 原生化
本地化模型加载可以在原生层面进行模型的加载,识别,学习,相当于把整套方案在原生端实现一次,通过bridge调用原生相关方法完成识别,但目前暂无学习原生的意向,所以搁置
3. 接口化
方式:将数据存储在mysql,识别单独起一个python服务,通过接口调用,利用服务器的性能优势使识别速度提升,保证网络消耗在合理范围即可
行动
最终选择了接口化,在保证网络的情况下,识别速度在300ms内可被接受
- nestjs搭建服务端,分feature和img两个模块
- mysql搭建数据库,建立feature和img两个表,通过imgId和feature表关联,同时feature表包含storeCode字段,用来管理门店
- 搭建redis,代替内存化方案,实现快速读取
- 搭建python服务端,与nestjs服务端进行通信,进行识别结果的传输
- 通过IVF方式提升计算速度,保证大量特征值情况下仍然可以快速算出相似结果
实现
- python端,flask搭建http服务,当然后续可以改成其他和服务端通信方式,提升速度
- 实现识别接口恶化同步接口,用于识别图片特征值和同步数据库存储的特征向量
app.py
from flask import Flask, request, jsonify
from flask_cors import CORS
from detect import MainDetect
from tensorflow.keras import layers, models
app = Flask(__name__)
CORS(app) # 允许所有路由上的跨域请求
detector = MainDetect()
@app.route('/')
def home():
return "Welcome to the Vegetable Recognize App!"
@app.route('/predict', methods=['POST'])
def predict():
if 'file' not in request.files:
return 'No file part', 400
file = request.files['file']
store_code = request.form["storeCode"]
if file.filename == '':
return 'No selected file', 400
try:
image_data = file.read()
data = detector.classify_image(image_data, store_code)
outputs = data["outputs"]
time = data["time"]
index = data["index"]
return jsonify({"predictTime": time, "features": outputs, "index": index})
except Exception as e:
return jsonify({'error': str(e)}), 500
@app.route('/sync', methods=['POST'])
def sync():
data = request.get_json()
arr = data.get('data')
store_code = data.get('storeCode')
detector.sync(store_code, arr)
return jsonify({"message": 'ok'})
detect.py
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.applications.mobilenet import preprocess_input, decode_predictions
from PIL import Image
import numpy as np
import cv2
import time
import io
import gc
from ivf import IVFPQ
model = MobileNet(input_shape=(224, 224, 3), weights='imagenet', include_top=False, pooling='avg')
class MainDetect:
# 初始化
def __init__(self):
super().__init__()
# 模型初始化
self.image_id = None
self.image_features = None
self.model = tf.keras.models.load_model("models/custom/my-model.h5")
self.ivfObj = {}
def classify_image(self, image_data, store_code):
# Load and preprocess image
img = tf.image.decode_image(image_data, channels=3)
img = tf.image.resize(img, [224, 224])
img = tf.expand_dims(img, axis=0) # Add batch dimension
# Run model prediction
start_time = time.time()
outputs = model.predict(img)
# outputs = self.model.predict(outputs)
# prediction = tf.divide(outputs, tf.norm(outputs))
i = []
if store_code + '-featureDatabase' in self.ivfObj:
i = self.ivfObj[store_code + '-featureDatabase'].search(outputs)
i = i.flatten().tolist()
end_time = time.time()
# Calculate elapsed time
elapsed_time = end_time - start_time
# Flatten the outputs and return them
# output_data = prediction.numpy().flatten().tolist()
output_data = outputs.flatten().tolist()
# Force garbage collection to free up memory
del img, outputs, end_time, start_time # Ensure variables are deleted
gc.collect()
return {"outputs": output_data, "time": f"{elapsed_time * 1000:.2f}ms", "index": i}
def sync(self, store_code, data):
if store_code + '-featureDatabase' in self.ivfObj:
del self.ivfObj[store_code + '-featureDatabase']
self.ivfObj[store_code + '-featureDatabase'] = IVFPQ()
for item in data:
feature = item['features']
self.ivfObj[store_code + '-featureDatabase'].add(np.array([feature], dtype=np.float32))
return 'ok'
ivf.py
import faiss
import numpy as np
class IVFPQ:
def __init__(self, d=1024, nlist=1, m=16, n_bits=8):
# 创建量化器
quantizer = faiss.IndexFlatL2(d) # 使用L2距离进行量化
self.index = faiss.IndexIVFPQ(quantizer, d, nlist, m, n_bits)
np.random.seed(1234)
xb = np.random.random((256, d)).astype('float32') # 模拟数据库中的特征向量
# 训练索引
self.index.train(xb)
self.index.add(xb) # 将特征向量添加到索引中
def search(self, xq, k=50):
d, i = self.index.search(xq, k)
return i
def add(self, xb):
self.index.add(xb)
def sync(self, features):
for i in range(len(features)):
self.add(features[i])
- nestjs进行接口的转发和数据处理
识别的service,调用python服务,并通过返回的索引找到正确的目标label
/**
* 预测
* @param file
* @param num
* @param storeCode
* @param justPredict
*/
async predict(file: Express.Multer.File, num: string = '5', storeCode: string, justPredict: Boolean = false) {
const url = 'http://localhost:5000/predict'; // Python 服务的 URL
const startTime = Date.now();
try {
// 返回 Python 服务的响应数据
const formData = new FormData();
formData.append('file', file.buffer, file.originalname);
formData.append('storeCode', storeCode);
const response = await firstValueFrom(this.httpService.post(url, formData));
const endTime = Date.now();
const features = response.data.features;
const index = response.data.index;
if (justPredict) {
return features;
}
// const top5 = await this.findTopNSimilar(features, parseInt(num), storeCode);
const featureDatabaseStr = await this.redisService.get(`${storeCode}-featureDatabase`);
if (!featureDatabaseStr) {
return response.data = {
...response.data,
[`top${num}`]: [],
features,
totalTime: `${endTime - startTime}ms`,
};
}
const featureDatabase = JSON.parse(featureDatabaseStr);
const list = [];
index.forEach((i: number) => {
const ide = i - 256;
if (ide >= 0) {
const item = featureDatabase[ide];
if (!list.some(l => l.label === item.label)) {
list.push({ label: item.label });
}
}
});
return response.data = {
...response.data,
[`top${num}`]: list,
features,
totalTime: `${endTime - startTime}ms`,
};
} catch (error) {
// 错误处理
console.error('Error calling Python service:', error);
throw error;
}
}
同步python端特征向量,在数据库的增删改查时进行调用
/**
* 同步redis
* @param storeCode
*/
async syncRedis(storeCode: string) {
const featureDatabase = await this.findAll(storeCode);
await this.redisService.set(`${storeCode}-featureDatabase`, JSON.stringify(featureDatabase));
const url = 'http://localhost:5000/sync'; // Python 服务的 URL
await firstValueFrom(this.httpService.post(url, { data: featureDatabase, storeCode }));
}
结语
没规划好,其实全部用python实现应该更自然,后续有时间再更新,先上线跑跑