当前位置: 首页 > article >正文

Canmv k230 C++案例1.2——image classify项目 C++代码分析(待完成)

这部分为初学,所以手头最好有本工具书便于查阅

01 代码初步注释

// 这里是一些定义配置
// 时间的标准库
#include <chrono>
// 写入或读取文件的标准库
#include <fstream>
// 文件输入输出的标准库,流模型
#include <iostream>
// k230的头文件,用于AI模型推断
#include <nncase/runtime/interpreter.h>
#include <nncase/runtime/runtime_op_utility.h>
// opencv开关  及  预处理过程的开关
#define USE_OPENCV 1
#define preprocess 1

// opencv 文件
#if USE_OPENCV
// 显示、编码、处理?待查手册验证
#include <opencv2/highgui.hpp>
#include <opencv2/imgcodecs.hpp>
#include <opencv2/imgproc.hpp>
#endif

// nncase的命名空间
using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::detail;

// 数据的输入尺寸定义 224*224*3
#define INTPUT_HEIGHT 224
#define INTPUT_WIDTH 224
#define INTPUT_CHANNELS 3

02

// 定理的类模版
template <class T>
// 读取二进制文件  输入参数为文件名
std::vector<T> read_binary_file(const std::string &file_name)
{
    // 打开数据流  构筑一个ifstream并打开给定文件 以二进制方式进行IO
    std::ifstream ifs(file_name, std::ios::binary);
    // 从文件末尾开始搜寻
    ifs.seekg(0, ifs.end);
    // 获取矢量数据的长度
    size_t len = ifs.tellg();
    // 定义矢量,依赖与机器
    std::vector<T> vec(len / sizeof(T), 0);
    // 从文件开始的地方进行搜寻
    ifs.seekg(0, ifs.beg);
    // 强制类型转换,非常危险,但只有这个vector是正确的
    ifs.read(reinterpret_cast<char *>(vec.data()), len);
    // 关闭文件
    ifs.close();
    return vec;
}
// 读入文件,打开文件,判断文件长度,写入数据
void read_binary_file(const char *file_name, char *buffer)
{
   
    std::ifstream ifs(file_name, std::ios::binary);
    ifs.seekg(0, ifs.end);
    // sizeof 返回类型
    size_t len = ifs.tellg();
    
    ifs.seekg(0, ifs.beg);
    ifs.read(buffer, len);
    ifs.close();
}

static std::vector<std::string> read_txt_file(const char *file_name)
{
    // 定义字符串vector变量vec,不含任何元素
    std::vector<std::string> vec;
    // 分类至少能容纳n个元素的内容空间
    vec.reserve(1024);

    // 打开file_name的文件名变量fp
    std::ifstream fp(file_name);
    // 定义字符串变量label
    std::string label;

    // 从fp中读取一行赋给label,返回fp
    // 每次读如一整行,直至到达文件末尾
    while (getline(fp, label))
    {
        // 矢量后面增加数据
        vec.push_back(label);
    }

    return vec;
}

softmax函数,函数公式

template<typename T>
static int softmax(const T* src, T* dst, int length)
{
    // 算法函数 寻找数组的最大值 const赋值
    const T alpha = *std::max_element(src, src + length);
    // 分母
    T denominator{ 0 };

    for (int i = 0; i < length; ++i) {
        dst[i] = std::exp(src[i] - alpha);
        denominator += dst[i];
    }
    // 指数输出
    for (int i = 0; i < length; ++i) {
        dst[i] /= denominator;
    }

    return 0;
}


数据格式转换

// 调用OPENCV的函数转换
#if USE_OPENCV
// hwc转chw 转换以适应tensorflow?
std::vector<uint8_t> hwc2chw(cv::Mat &img)
{
    std::vector<uint8_t> vec;
    std::vector<cv::Mat> rgbChannels(3);
    cv::split(img, rgbChannels);
    for (auto i = 0; i < rgbChannels.size(); i++)
    {
        std::vector<uint8_t> data = std::vector<uint8_t>(rgbChannels[i].reshape(1, 1));
        vec.insert(vec.end(), data.begin(), data.end());
    }

    return vec;
}
#endif

模型推断代码

// 可以看到推断需要执行文件和三个参数模型、图片、标签
static int inference(const char *kmodel_file, const char *image_file, const char *label_file)
{
    // load kmodel
    interpreter interp;
    // 模型也被保存为二进制文件?但格式未知
    std::ifstream ifs(kmodel_file, std::ios::binary);
    // 判断是否载入正常
    interp.load_model(ifs).expect("load_model failed");

    // create input tensor 创建输入?
    auto input_desc = interp.input_desc(0);
    // create input tensor 创建输入尺寸?
    auto input_shape = interp.input_shape(0);
    
    auto input_tensor = host_runtime_tensor::create(input_desc.datatype, input_shape, hrt::pool_shared).expect("cannot create input tensor");
    interp.input_tensor(0, input_tensor).expect("cannot set input tensor");

    // create output tensor
    // auto output_desc = interp.output_desc(0);
    // auto output_shape = interp.output_shape(0);
    // auto output_tensor = host_runtime_tensor::create(output_desc.datatype, output_shape, hrt::pool_shared).expect("cannot create output tensor");
    // interp.output_tensor(0, output_tensor).expect("cannot set output tensor");

    // set input data
    auto dst = input_tensor.impl()->to_host().unwrap()->buffer().as_host().unwrap().map(map_access_::map_write).unwrap().buffer();
#if USE_OPENCV
    cv::Mat img = cv::imread(image_file);
    cv::resize(img, img, cv::Size(INTPUT_WIDTH, INTPUT_HEIGHT), cv::INTER_NEAREST);
    auto input_vec = hwc2chw(img);
    memcpy(reinterpret_cast<char *>(dst.data()), input_vec.data(), input_vec.size());
#else
    read_binary_file(image_file, reinterpret_cast<char *>(dst.data()));
#endif
    hrt::sync(input_tensor, sync_op_t::sync_write_back, true).expect("sync write_back failed");

    // run
    size_t counter = 1;
    auto start = std::chrono::steady_clock::now();
    for (size_t c = 0; c < counter; c++)
    {
        interp.run().expect("error occurred in running model");
    }
    auto stop = std::chrono::steady_clock::now();
    double duration = std::chrono::duration<double, std::milli>(stop - start).count();
    std::cout << "interp.run() took: " << duration / counter << " ms" << std::endl;

    // get output data
    auto output_tensor = interp.output_tensor(0).expect("cannot set output tensor");
    dst = output_tensor.impl()->to_host().unwrap()->buffer().as_host().unwrap().map(map_access_::map_read).unwrap().buffer();
    float *output_data = reinterpret_cast<float *>(dst.data());
    auto out_shape = interp.output_shape(0);
    auto size = compute_size(out_shape);

    // postprogress softmax by cpu
    std::vector<float> softmax_vec(size, 0);
    auto buf = softmax_vec.data();
    softmax(output_data, buf, size);
    auto it = std::max_element(buf, buf + size);
    size_t idx = it - buf;

    // load label
    auto labels = read_txt_file(label_file);
    std::cout << "image classify result: " << labels[idx] << "(" << *it << ")" << std::endl;

    return 0;
}

主函数

// 主函数 要判断各个环节是否正确输出
int main(int argc, char *argv[])
{
    // 输出argv[0]一般是文件名称
    std::cout << "case " << argv[0] << " built at " << __DATE__ << " " << __TIME__ << std::endl;
    if (argc != 4)
    {
        // 判断输入argc个数
        std::cerr << "Usage: " << argv[0] << " <kmodel> <image> <label>" << std::endl;
        return -1;
    }

    int ret = inference(argv[1], argv[2], argv[3]);
    if (ret)
    {
        std::cerr << "inference failed: ret = " << ret << std::endl;
        return -2;
    }

    return 0;
}

02 代码的整体结构

03 部分代码说明

04 附录 一些代码资料说明


http://www.kler.cn/news/350707.html

相关文章:

  • STM32-Modbus协议(一文通)
  • ESlint代码规范
  • 设计模式: Pimpl(Pointer to Implementation)
  • Python实现非线性数据结构-字典、集合、树、图
  • 微信小程序/uniapp动态修改tabBar信息及常见报错
  • NewStarCTF 2023 公开赛道 Web week1-week2
  • 网络安全公司及其主要产品介绍
  • Django CORS跨域支持
  • 基于Python的博客系统
  • Android10 recent键相关总结
  • Spring Boot框架下大创项目流程自动化
  • 【论文阅读】03-Diffusion Models and Representation Learning: A Survey
  • C++ | Leetcode C++题解之第486题预测赢家
  • Android activity 启动流程
  • 指针——函数指针数组
  • 计算机网络 2024 11 10
  • windows上的git bash中会将~设为哪个目录?
  • vector的深入剖析与底层逻辑
  • css-背景图片全屏显示适配不同尺寸覆盖
  • 股票分析软件设计
  • 003_django基于Django高校岗位招聘平台与数据可视化分析设计和实现2024_414pr4jc
  • 大数据-174 Elasticsearch Query DSL - 全文检索 full-text query 匹配、短语、多字段 详细操作
  • 法规标准-懂车帝智能化实测标准(2024版)
  • 嵌入式:Keil的Code、RW、RO、ZI段的解析
  • 解决 Qt 中提升控件后样式表无法正确应用的问题
  • 导致动态代理无法使用的原因有哪些?