当前位置: 首页 > article >正文

果蔬识别系统性能优化之路(一)

目录

    • 前情提要
      • 当前问题
    • 优化方案
      • 1. 内存化
      • 2. 原生化
      • 3. 接口化
    • 行动
    • 实现
    • 结语

前情提要

超详细前端AI蔬菜水果生鲜识别应用优化之路

当前问题

  1. indexddb在webview中确实性能有限,存储量上来后每次读取数据会有明显卡顿
  2. 目前的余弦相邻算法是基于所有特征向量数据进行的计算,一旦数据量大了后计算量也是一个十分消耗性能的点
  3. 由于机器性能问题,本地化加载模型在每次进入页面后需要消耗很长一段时间进行模型的加载,体验相当不好

优化方案

1. 内存化

方式:为了解决indexddb读取速度的问题,最直接的方式就是把数据放内存对象中,提前将indexddb的数据取出来,然后在数据变化时进行同步
结果:确实省去了读取这一步会快很多,但是仍有隐患,webview分配的内存是否这种占用内存的方式

2. 原生化

本地化模型加载可以在原生层面进行模型的加载,识别,学习,相当于把整套方案在原生端实现一次,通过bridge调用原生相关方法完成识别,但目前暂无学习原生的意向,所以搁置

3. 接口化

方式:将数据存储在mysql,识别单独起一个python服务,通过接口调用,利用服务器的性能优势使识别速度提升,保证网络消耗在合理范围即可

行动

最终选择了接口化,在保证网络的情况下,识别速度在300ms内可被接受

  1. nestjs搭建服务端,分feature和img两个模块
  2. mysql搭建数据库,建立feature和img两个表,通过imgId和feature表关联,同时feature表包含storeCode字段,用来管理门店
  3. 搭建redis,代替内存化方案,实现快速读取
  4. 搭建python服务端,与nestjs服务端进行通信,进行识别结果的传输
  5. 通过IVF方式提升计算速度,保证大量特征值情况下仍然可以快速算出相似结果

实现

  1. python端,flask搭建http服务,当然后续可以改成其他和服务端通信方式,提升速度
  2. 实现识别接口恶化同步接口,用于识别图片特征值和同步数据库存储的特征向量
    app.py
from flask import Flask, request, jsonify
from flask_cors import CORS
from detect import MainDetect
from tensorflow.keras import layers, models

app = Flask(__name__)
CORS(app)  # 允许所有路由上的跨域请求
detector = MainDetect()


@app.route('/')
def home():
    return "Welcome to the Vegetable Recognize App!"


@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return 'No file part', 400

    file = request.files['file']
    store_code = request.form["storeCode"]
    if file.filename == '':
        return 'No selected file', 400

    try:
        image_data = file.read()
        data = detector.classify_image(image_data, store_code)
        outputs = data["outputs"]
        time = data["time"]
        index = data["index"]
        return jsonify({"predictTime": time, "features": outputs, "index": index})

    except Exception as e:
        return jsonify({'error': str(e)}), 500


@app.route('/sync', methods=['POST'])
def sync():
    data = request.get_json()
    arr = data.get('data')
    store_code = data.get('storeCode')
    detector.sync(store_code, arr)
    return jsonify({"message": 'ok'})

detect.py

import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
import numpy as np
import time
import gc
from ivf import IVFPQ

# 加载预训练的 MobileNetV2 模型,不包含顶部的分类层
model = MobileNetV2(input_shape=(224, 224, 3), weights='imagenet', include_top=False, pooling='avg')


class MainDetect:
    # 初始化
    def __init__(self):
        super().__init__()
        # 模型初始化
        self.image_id = None
        self.image_features = None
        self.model = tf.keras.models.load_model("models/custom/my-model.h5")
        self.ivfObj = {}

    def classify_image(self, image_data, store_code, just_predict):
        # Load and preprocess image
        img = tf.image.decode_image(image_data, channels=3)
        img = tf.image.resize(img, [224, 224])
        img = tf.expand_dims(img, axis=0)  # Add batch dimension

        # Run model prediction
        start_time = time.time()
        outputs = model.predict(img)
        # outputs = self.model.predict(outputs)
        # prediction = tf.divide(outputs, tf.norm(outputs))
        i = []
        if just_predict == "false":
            if store_code + '-featureDatabase' in self.ivfObj:
                i = self.ivfObj[store_code + '-featureDatabase'].search(outputs)
                i = i.flatten().tolist()

        end_time = time.time()

        # Calculate elapsed time
        elapsed_time = end_time - start_time

        # Flatten the outputs and return them
        # output_data = prediction.numpy().flatten().tolist()
        output_data = outputs.flatten().tolist()

        # Force garbage collection to free up memory
        del img, outputs, end_time, start_time  # Ensure variables are deleted
        gc.collect()

        return {"outputs": output_data, "time": f"{elapsed_time * 1000:.2f}ms", "index": i}

    def sync(self, store_code, data):
        if store_code + '-featureDatabase' in self.ivfObj:
            del self.ivfObj[store_code + '-featureDatabase']
        # 提取所有特征并转换为 NumPy 数组
        features = np.array([item['features'] for item in data], dtype=np.float32)
        self.ivfObj[store_code + '-featureDatabase'] = IVFPQ(features)
        return 'ok'

ivf.py

import faiss
import numpy as np


class IVFPQ:
    def __init__(self, features, nlist=100, m=16, n_bits=8):
        d = features.shape[1]
        # 创建量化器
        quantizer = faiss.IndexFlatL2(d)  # 使用L2距离进行量化
        self.index = faiss.IndexIVFPQ(quantizer, d, nlist, m, n_bits)
        # 训练索引
        self.index.train(features)
        self.index.add(features)  # 将特征向量添加到索引中

    def search(self, xq, k=100):
        d, i = self.index.search(xq, k)
        return i

    def add(self, xb):
        self.index.add(xb)

    def train(self, xb):
        self.index.train(xb)

    def sync(self, features):
        for i in range(len(features)):
            self.add(features[i])


  1. nestjs进行接口的转发和数据处理
    识别的service,调用python服务,并通过返回的索引找到正确的目标label
  /**
   * 预测
   * @param file
   * @param num
   * @param storeCode
   * @param justPredict
   */
  async predict(file: Express.Multer.File, num: string = '5', storeCode: string, justPredict: Boolean = false) {
    const url = 'http://localhost:5000/predict'; // Python 服务的 URL
    const startTime = Date.now();
    try {
      // 返回 Python 服务的响应数据
      const formData = new FormData();
      formData.append('file', file.buffer, file.originalname);
      formData.append('storeCode', storeCode);
      const response = await firstValueFrom(this.httpService.post(url, formData));
      const endTime = Date.now();
      const features = response.data.features;
      const index = response.data.index;
      if (justPredict) {
        return features;
      }
      // const top5 = await this.findTopNSimilar(features, parseInt(num), storeCode);

      const featureDatabaseStr = await this.redisService.get(`${storeCode}-featureDatabase`);
      if (!featureDatabaseStr) {
        return response.data = {
          ...response.data,
          [`top${num}`]: [],
          features,
          totalTime: `${endTime - startTime}ms`,
        };
      }
      const featureDatabase = JSON.parse(featureDatabaseStr);
      const list = [];
      index.forEach((i: number) => {
        const ide = i - 256;
        if (ide >= 0) {
          const item = featureDatabase[ide];
          if (!list.some(l => l.label === item.label)) {
            list.push({ label: item.label });
          }
        }
      });
      return response.data = {
        ...response.data,
        [`top${num}`]: list,
        features,
        totalTime: `${endTime - startTime}ms`,
      };
    } catch (error) {
      // 错误处理
      console.error('Error calling Python service:', error);
      throw error;
    }
  }

同步python端特征向量,在数据库的增删改查时进行调用

  /**
   * 同步redis
   * @param storeCode
   */
  async syncRedis(storeCode: string) {
    const featureDatabase = await this.findAll(storeCode);
    await this.redisService.set(`${storeCode}-featureDatabase`, JSON.stringify(featureDatabase));
    const url = 'http://localhost:5000/sync'; // Python 服务的 URL
    await firstValueFrom(this.httpService.post(url, { data: featureDatabase, storeCode }));
  }

结语

没规划好,其实全部用python实现应该更自然,后续有时间再更新,先上线跑跑


http://www.kler.cn/a/299857.html

相关文章:

  • 【LeetCode:3174】清除数字(Java)
  • 《JavaEE进阶》----15.<Spring Boot 日志>
  • Day 31: 贪心算法基础 V
  • 【linux-Day2】linux下的基本指令
  • Ubuntu基本命令的熟悉和使用
  • 插件maven-search:Maven导入依赖时,使用插件maven-search拷贝需要的依赖的GAV
  • Rickdiculously Easy靶场渗透测试
  • 【Python 学习】Numpy的基础和应用
  • node.js实现阿里云短信发送
  • Android之LiveTemplate注释模板
  • 基于云原生向量数据库 PieCloudVector 的 RAG 实践
  • 面试官:说说你对keep-alive的理解是什么?
  • vue element时间选择不能超过今天 时间选中长度不能超过7天
  • 动手学深度学习(pytorch土堆)-02TensorBoard的使用
  • 防患于未然,智能监控新视角:EasyCVR视频平台在高校安全防控中的关键角色
  • Azure OpenAI models being unable to correctly identify model
  • [001-03-007].第26节:分布式锁迭代3->优化基于setnx命令实现的分布式锁-防锁的误删
  • openharmony 应用支持常驻和自启动
  • Web安全之XSS跨站脚本攻击:如何预防及解决
  • 2024年最新版Ajax+Axios 学习【包含原理、Promise、报文、接口等...】