当前位置: 首页 > article >正文

果蔬识别系统性能优化之路

目录

  • 一级目录
    • 二级目录
      • 三级目录
    • 前情提要
      • 当前问题
    • 优化方案
      • 1. 内存化
      • 2. 原生化
      • 3. 接口化
    • 行动
    • 实现
    • 结语

一级目录

二级目录

三级目录

前情提要

超详细前端AI蔬菜水果生鲜识别应用优化之路

当前问题

  1. indexddb在webview中确实性能有限,存储量上来后每次读取数据会有明显卡顿
  2. 目前的余弦相邻算法是基于所有特征向量数据进行的计算,一旦数据量大了后计算量也是一个十分消耗性能的点
  3. 由于机器性能问题,本地化加载模型在每次进入页面后需要消耗很长一段时间进行模型的加载,体验相当不好

优化方案

1. 内存化

方式:为了解决indexddb读取速度的问题,最直接的方式就是把数据放内存对象中,提前将indexddb的数据取出来,然后在数据变化时进行同步
结果:确实省去了读取这一步会快很多,但是仍有隐患,webview分配的内存是否这种占用内存的方式

2. 原生化

本地化模型加载可以在原生层面进行模型的加载,识别,学习,相当于把整套方案在原生端实现一次,通过bridge调用原生相关方法完成识别,但目前暂无学习原生的意向,所以搁置

3. 接口化

方式:将数据存储在mysql,识别单独起一个python服务,通过接口调用,利用服务器的性能优势使识别速度提升,保证网络消耗在合理范围即可

行动

最终选择了接口化,在保证网络的情况下,识别速度在300ms内可被接受

  1. nestjs搭建服务端,分feature和img两个模块
  2. mysql搭建数据库,建立feature和img两个表,通过imgId和feature表关联,同时feature表包含storeCode字段,用来管理门店
  3. 搭建redis,代替内存化方案,实现快速读取
  4. 搭建python服务端,与nestjs服务端进行通信,进行识别结果的传输
  5. 通过IVF方式提升计算速度,保证大量特征值情况下仍然可以快速算出相似结果

实现

  1. python端,flask搭建http服务,当然后续可以改成其他和服务端通信方式,提升速度
  2. 实现识别接口恶化同步接口,用于识别图片特征值和同步数据库存储的特征向量
    app.py
from flask import Flask, request, jsonify
from flask_cors import CORS
from detect import MainDetect
from tensorflow.keras import layers, models

app = Flask(__name__)
CORS(app)  # 允许所有路由上的跨域请求
detector = MainDetect()


@app.route('/')
def home():
    return "Welcome to the Vegetable Recognize App!"


@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return 'No file part', 400

    file = request.files['file']
    store_code = request.form["storeCode"]
    if file.filename == '':
        return 'No selected file', 400

    try:
        image_data = file.read()
        data = detector.classify_image(image_data, store_code)
        outputs = data["outputs"]
        time = data["time"]
        index = data["index"]
        return jsonify({"predictTime": time, "features": outputs, "index": index})

    except Exception as e:
        return jsonify({'error': str(e)}), 500


@app.route('/sync', methods=['POST'])
def sync():
    data = request.get_json()
    arr = data.get('data')
    store_code = data.get('storeCode')
    detector.sync(store_code, arr)
    return jsonify({"message": 'ok'})

detect.py

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.applications.mobilenet import preprocess_input, decode_predictions
from PIL import Image
import numpy as np
import cv2
import time
import io
import gc
from ivf import IVFPQ

model = MobileNet(input_shape=(224, 224, 3), weights='imagenet', include_top=False, pooling='avg')


class MainDetect:
    # 初始化
    def __init__(self):
        super().__init__()
        # 模型初始化
        self.image_id = None
        self.image_features = None
        self.model = tf.keras.models.load_model("models/custom/my-model.h5")
        self.ivfObj = {}

    def classify_image(self, image_data, store_code):
        # Load and preprocess image
        img = tf.image.decode_image(image_data, channels=3)
        img = tf.image.resize(img, [224, 224])
        img = tf.expand_dims(img, axis=0)  # Add batch dimension

        # Run model prediction
        start_time = time.time()
        outputs = model.predict(img)
        # outputs = self.model.predict(outputs)
        # prediction = tf.divide(outputs, tf.norm(outputs))
        i = []
        if store_code + '-featureDatabase' in self.ivfObj:
            i = self.ivfObj[store_code + '-featureDatabase'].search(outputs)
            i = i.flatten().tolist()

        end_time = time.time()

        # Calculate elapsed time
        elapsed_time = end_time - start_time

        # Flatten the outputs and return them
        # output_data = prediction.numpy().flatten().tolist()
        output_data = outputs.flatten().tolist()

        # Force garbage collection to free up memory
        del img, outputs, end_time, start_time  # Ensure variables are deleted
        gc.collect()

        return {"outputs": output_data, "time": f"{elapsed_time * 1000:.2f}ms", "index": i}

    def sync(self, store_code, data):
        if store_code + '-featureDatabase' in self.ivfObj:
            del self.ivfObj[store_code + '-featureDatabase']
        self.ivfObj[store_code + '-featureDatabase'] = IVFPQ()
        for item in data:
            feature = item['features']
            self.ivfObj[store_code + '-featureDatabase'].add(np.array([feature], dtype=np.float32))
        return 'ok'

ivf.py

import faiss
import numpy as np


class IVFPQ:
    def __init__(self, d=1024, nlist=1, m=16, n_bits=8):
        # 创建量化器
        quantizer = faiss.IndexFlatL2(d)  # 使用L2距离进行量化
        self.index = faiss.IndexIVFPQ(quantizer, d, nlist, m, n_bits)

        np.random.seed(1234)
        xb = np.random.random((256, d)).astype('float32')  # 模拟数据库中的特征向量
        # 训练索引
        self.index.train(xb)
        self.index.add(xb)  # 将特征向量添加到索引中

    def search(self, xq, k=50):
        d, i = self.index.search(xq, k)
        return i

    def add(self, xb):
        self.index.add(xb)

    def sync(self, features):
        for i in range(len(features)):
            self.add(features[i])

  1. nestjs进行接口的转发和数据处理
    识别的service,调用python服务,并通过返回的索引找到正确的目标label
  /**
   * 预测
   * @param file
   * @param num
   * @param storeCode
   * @param justPredict
   */
  async predict(file: Express.Multer.File, num: string = '5', storeCode: string, justPredict: Boolean = false) {
    const url = 'http://localhost:5000/predict'; // Python 服务的 URL
    const startTime = Date.now();
    try {
      // 返回 Python 服务的响应数据
      const formData = new FormData();
      formData.append('file', file.buffer, file.originalname);
      formData.append('storeCode', storeCode);
      const response = await firstValueFrom(this.httpService.post(url, formData));
      const endTime = Date.now();
      const features = response.data.features;
      const index = response.data.index;
      if (justPredict) {
        return features;
      }
      // const top5 = await this.findTopNSimilar(features, parseInt(num), storeCode);

      const featureDatabaseStr = await this.redisService.get(`${storeCode}-featureDatabase`);
      if (!featureDatabaseStr) {
        return response.data = {
          ...response.data,
          [`top${num}`]: [],
          features,
          totalTime: `${endTime - startTime}ms`,
        };
      }
      const featureDatabase = JSON.parse(featureDatabaseStr);
      const list = [];
      index.forEach((i: number) => {
        const ide = i - 256;
        if (ide >= 0) {
          const item = featureDatabase[ide];
          if (!list.some(l => l.label === item.label)) {
            list.push({ label: item.label });
          }
        }
      });
      return response.data = {
        ...response.data,
        [`top${num}`]: list,
        features,
        totalTime: `${endTime - startTime}ms`,
      };
    } catch (error) {
      // 错误处理
      console.error('Error calling Python service:', error);
      throw error;
    }
  }

同步python端特征向量,在数据库的增删改查时进行调用

  /**
   * 同步redis
   * @param storeCode
   */
  async syncRedis(storeCode: string) {
    const featureDatabase = await this.findAll(storeCode);
    await this.redisService.set(`${storeCode}-featureDatabase`, JSON.stringify(featureDatabase));
    const url = 'http://localhost:5000/sync'; // Python 服务的 URL
    await firstValueFrom(this.httpService.post(url, { data: featureDatabase, storeCode }));
  }

结语

没规划好,其实全部用python实现应该更自然,后续有时间再更新,先上线跑跑


http://www.kler.cn/a/290998.html

相关文章:

  • 26考研资料分享 百度网盘
  • 深入探索 Nginx 的高级用法:解锁 Web 服务器的强大潜能
  • (三)线性代数之二阶和三阶行列式详解
  • 【重庆市乡镇界】面图层shp格式arcgis数据乡镇名称和编码wgs84坐标无偏移内容测评
  • 【0x04】HCI_Connection_Request事件详解
  • 微信小程序:实现单选,多选,通过变量控制单选/多选
  • 集成电路学习:什么是MCU微控制器
  • 软件测试中错误推断法(错误猜测法或错误推测法)
  • 排序
  • MATLAB基础应用精讲-【数模应用】极差分析(附MATLAB、python和R语言代码实现)
  • 深度学习(八)-图像色彩操作
  • 基于FCM模糊聚类算法的图像分割matlab仿真
  • 【小设计】基于宏实现的C++ 可复用setter 和getter设计
  • 嵌入式面经 嵌入式软件开发 嵌入式硬件开发 经纬恒润嵌入式面试汇总总结
  • RK3588平台开发系列讲解(显示篇)图像的宽高和跨距
  • scss中的mix函数
  • 基于深度学习的人机交互中的认知模型
  • Google 的 9 年职业生涯回顾
  • ubuntu通过smba访问华为设备
  • 线性表之栈
  • ThreadLocal 在线程池中的内存泄漏问题
  • JavaAgent技术原理
  • MybatisPlus入门
  • Android Radio2.0——交通公告状态设置(二)
  • 【20.1 python中的Web基础】
  • 云计算之数据库