当前位置: 首页 > article >正文

Clang插件演示-直接调用AI模型定义的变量完成模型推理

Clang插件演示-直接调用AI模型定义的变量完成模型推理

  • 一.需求描述
  • 二.实现方案
  • 三.复现过程
    • 1.搭建docker环境
    • 2.准备推理SDK DEMO
    • 3.准备Clang插件
    • 4.测试DEMO

一.需求描述

  • 1.用户在c++代码里定义一个AI模型的描述(文件路径,数据类型,输入、输出等)
  • 2.编译器识别该描述,编译模型为二进制,并且嵌入到elf文件中,对外暴露一个API
  • 3.用户直接调用该API,传入模型输入输出参数,完成模型推理

二.实现方案

  • 1.封装一个AI模型编译和推理的SDK,供后续使用
  • 2.利用Clang PluginASTAction机制,准备一个Clang插件,它会识别某个类型的变量,变量的内容是json格式的模型描述信息
  • 3.加载以上插件编译用户代码,该过程会调用AI模型编译API编译模型,生成二进制
    并且将内容赋给一个数组,之后会修改源码,生成一个与变量同名的函数,在该API中调用AI推理API,同时删除模型定义的变量
  • 4.再次使用clang,去掉插件选项编译,生成目标文件
  • 5.用法
    #include <iostream>
    typedef const char* ai_model_define;
    int main() 
    {   
    	//定义模型,编译器会将这条语言编译成模型文件及void vamodel(int argc,void *argv[]) 
    	ai_model_define vamodel = R"(
    	{
    	 "path":"./model.pt",
    	 "input":[
    		"1":"float16"
    	 ];
    	})";
    	void *args[]={(void*)"Hello",(void*)"World"};
    	vamodel(2,args); // 调用自动生成的 vamodel 函数
    	return 0;
    }
    
  • 6.也可以在AST中插入函数定义,这样就不用二次编译(本文没有演示)

三.复现过程

1.搭建docker环境

mkdir ClangPlugin
cd ClangPlugin

docker run -it --net=host --privileged=true -v $PWD:/home --name ClangPlugin ubuntu:22.04 /bin/bash
cd ClangPlugin

apt-get update
apt-get install clang llvm llvm-dev -y
apt-get install libclang-dev -y

clang --version
llvm-config --version

2.准备推理SDK DEMO

tee infer_sdk.c<<-'EOF'
#include <stdio.h>
#include <string.h>
int AiInferenceBuild(const char *model_name,const char *model_cfg_json,char *output_model_bin,int max_model_size)
{
    printf("AiInferenceBuild:%s %s\n",model_name,model_cfg_json);
    sprintf(output_model_bin,"const char *%s_modelbin=\"This is Model:%s\\n\";",model_name,model_name);	
}
int AiInferenceForward(const char *model_bin,int argc,void *argv[])
{
    printf("AiInferenceForward:%s\n",model_bin);
    for(int i=0;i<argc;i++)
    {
        printf("AiInferenceForward arg:%d value:%s\n",i,(char*)argv[i]);
    }
}
EOF
clang -fPIC -shared infer_sdk.c -o libinfer_sdk.so -fno-rtti

3.准备Clang插件

tee MyPlugin.cpp<<-'EOF'
#include <iostream>
#include "clang/AST/ASTConsumer.h"
#include "clang/AST/ASTContext.h"
#include "clang/Frontend/ASTUnit.h"
#include "clang/Frontend/FrontendActions.h"
#include "clang/Frontend/CompilerInvocation.h"
#include "clang/Frontend/PreprocessorOutputOptions.h"
#include "clang/Frontend/TextDiagnosticPrinter.h"
#include "clang/Tooling/Tooling.h"
#include "llvm/Support/raw_ostream.h"
#include "clang/Lex/PreprocessorOptions.h"
#include "clang/AST/AST.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Frontend/FrontendPluginRegistry.h"
#include "clang/Rewrite/Core/Rewriter.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Basic/Diagnostic.h"

using namespace clang;

extern "C" int AiInferenceBuild(const char *model_name,const char *model_cfg_json,char *output_model_bin,int max_model_size);

namespace {

class MyASTVisitor : public RecursiveASTVisitor<MyASTVisitor> {
public:
    explicit MyASTVisitor(Rewriter &R) : TheRewriter(R) {}

    bool VisitVarDecl(VarDecl *VD) {
        // 检查变量类型是否为 ai_model_define
        QualType qt = VD->getType();
        std::string typeStr = qt.getAsString();

        if (typeStr == "ai_model_define") {
            HandleAIModeDefineVar(VD);
        }
        return true;
    }

private:
    Rewriter &TheRewriter;

    void HandleAIModeDefineVar(VarDecl *VD) {
        if (Expr *Init = VD->getInit()) {
            if (StringLiteral *StrLit = dyn_cast<StringLiteral>(Init->IgnoreImpCasts())) {
                std::string Content = StrLit->getString().str();
                GenerateAndInsertVAModel(Content, VD);
            }
        }
    }
    void GenerateAndInsertVAModel(const std::string &Content, VarDecl *VD) {
        std::string VarName = VD->getNameAsString();
        std::string FuncName = VarName; // 函数名与变量名相同
        
        //编译模型
        const int max_model_size=2<<20;
        char output_model_bin[max_model_size];
        AiInferenceBuild(FuncName.c_str(),Content.c_str(),output_model_bin,max_model_size);
        
        //std::cout<<"[[["<<Content<<"]]]"<<std::endl;
        
        // 生成函数代码
        std::string FunctionCode =output_model_bin;
        FunctionCode +="\nextern \"C\" int AiInferenceForward\(const char *model_name,int argc,void *argv[]\);\n";
        FunctionCode =FunctionCode+"void "+FuncName+"(int argc,void *argv[]) {\n";
        FunctionCode += "    AiInferenceForward("+FuncName+"_modelbin,argc,argv);\n";
        FunctionCode += "}\n";
        //std::cout<<FunctionCode<<std::endl;
        InsertVAModelImplementation(FunctionCode, VD);
    }

    void InsertVAModelImplementation(const std::string &FunctionCode, VarDecl *VD) {

        DiagnosticsEngine &DE = VD->getASTContext().getDiagnostics();
        unsigned DiagID = DE.getCustomDiagID(DiagnosticsEngine::Warning, "Found ai_model_define variable.");
        DE.Report(VD->getLocation(), DiagID);
    
        SourceManager &SM = TheRewriter.getSourceMgr();
        
        SourceLocation StartOfFileLoc = SM.getLocForStartOfFile(SM.getMainFileID());
        //插入生成的函数代码到源文件末尾
        TheRewriter.InsertText(StartOfFileLoc, FunctionCode, false, true);
        
        LangOptions LangOpts = TheRewriter.getLangOpts();
        SourceRange VarRange = VD->getSourceRange();

        // 获取要替换的文本范围的起始位置和结束位置
        SourceLocation StartLoc = SM.getExpansionLoc(VarRange.getBegin());
        SourceLocation EndLoc = SM.getExpansionLoc(VarRange.getEnd());

        // 计算替换范围的长度
        unsigned Length = SM.getCharacterData(EndLoc) - SM.getCharacterData(StartLoc) + Lexer::MeasureTokenLength(EndLoc, SM, TheRewriter.getLangOpts());
        // 删除变量
        TheRewriter.ReplaceText(StartLoc, Length+1, "");

        // 将修改后的文件打印到终端
        // TheRewriter.getEditBuffer(SM.getMainFileID()).write(llvm::outs());
        
        // 将修改后的文保存到文件
        std::string FileName = SM.getFileEntryForID(SM.getMainFileID())->getName().str();
        std::string OutputFileName = FileName + ".rewritten.cpp";        
        std::error_code error_code;
        llvm::raw_fd_ostream outFile(OutputFileName, error_code);
        TheRewriter.getEditBuffer(SM.getMainFileID()).write(outFile);
        outFile.close();
    }
};

class MyASTConsumer : public ASTConsumer {
public:
    explicit MyASTConsumer(Rewriter &R) : Visitor(R) {}

    void HandleTranslationUnit(ASTContext &Context) override {
        Visitor.TraverseDecl(Context.getTranslationUnitDecl());
    }

private:
    MyASTVisitor Visitor;
};

class MyPluginAction : public PluginASTAction {
public:
    bool ParseArgs(const CompilerInstance &CI, const std::vector<std::string> &args) override {
        return true;
    }

    std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI, llvm::StringRef) override {
        TheRewriter.setSourceMgr(CI.getSourceManager(), CI.getLangOpts());
        return std::make_unique<MyASTConsumer>(TheRewriter);
    }

private:
    Rewriter TheRewriter;
};
} // namespace

// 注册插件
static FrontendPluginRegistry::Add<MyPluginAction>
X("MyPluginAction", "Insert vamodel implementation");
EOF

clang++ -std=c++14 -g -fPIC -shared MyPlugin.cpp -o plugin.so \
    `llvm-config --cxxflags --ldflags --system-libs --libs all` -fno-rtti ./libinfer_sdk.so

4.测试DEMO

tee main.cpp<<-'EOF'
#include <iostream>
typedef const char* ai_model_define;

int main() 
{   
    //定义模型,编译器会将这条语言编译成模型文件及void vamodel(int argc,void *argv[]) 
    ai_model_define vamodel = R"(
    {
     "path":"./model.pt",
     "input":[
        "1":"float16"
     ];
    })";
    void *args[]={(void*)"Hello",(void*)"World"};
    vamodel(2,args); // 调用自动生成的 vamodel 函数
    return 0;
}
EOF

clang++ -Xclang -load -Xclang ./plugin.so -Xclang -plugin -Xclang MyPluginAction main.cpp
clang++ -o demo main.cpp.rewritten.cpp ./libinfer_sdk.so
./demo

输出

AiInferenceBuild:vamodel
    {
     "path":"./model.pt",
     "input":[
        "1":"float16"
     ];
    }
main.cpp:8:21: warning: Found ai_model_define variable.
    ai_model_define vamodel = R"(
                    ^
1 warning and 1 error generated.
AiInferenceForward:This is Model:vamodel

AiInferenceForward arg:0 value:Hello
AiInferenceForward arg:1 value:World

说明

  • 1.使用Rewriter修改的代码,编译器默认不会编译修改后的代码
  • 2.因此,需要在插件中使用Rewriter进行代码修改,并将修改后的代码输出到文件,然后重新编译修改后的源文件
  • 3.或者在AST中插入函数定义,使编译器在编译时能够看到新的函数,如下:
    void GenerateAndInsertVAModelAST(const std::string &Content, VarDecl *VD) {
        ASTContext &Context = VD->getASTContext();
    
        // 1. 准备函数的返回类型和参数类型
        QualType ReturnType = Context.VoidTy;
        QualType ParamType = Context.VoidPtrTy;
    
        // 2. 创建函数名标识符
        IdentifierInfo *FuncId = &Context.Idents.get("vamodel");
    
        // 3. 创建函数参数列表
        ParmVarDecl *Param = ParmVarDecl::Create(
            Context,
            nullptr, // 没有父上下文
            SourceLocation(),
            SourceLocation(),
            &Context.Idents.get("input"),
            ParamType,
            nullptr,
            SC_None,
            nullptr
        );
        SmallVector<ParmVarDecl *, 1> Params;
        Params.push_back(Param);
    
        // 4. 创建函数声明
        FunctionDecl *FuncDecl = FunctionDecl::Create(
            Context,
            Context.getTranslationUnitDecl(), // 所属的上下文
            SourceLocation(),
            SourceLocation(),
            FuncId,
            ReturnType,
            nullptr,
            SC_None,
            false,
            false,
            false
        );
        FuncDecl->setParams(Params);
    
        // 5. 创建函数体(此处生成示例代码,可根据需要完善)
        // 构建函数体内容
        // 这里只是一个空的复合语句,可以添加实际的语句
        CompoundStmt *FuncBody = new (Context) CompoundStmt(SourceLocation(), SourceLocation());
    
        FuncDecl->setBody(FuncBody);
    
        // 6. 将函数声明添加到翻译单元
        Context.getTranslationUnitDecl()->addDecl(FuncDecl);
    }
    

http://www.kler.cn/a/312731.html

相关文章:

  • 使用vite构建一个react网站,并部署到Netlify上
  • 【Linux-进程间通信】了解信号量 + 共享内存 + 消息队列的应用
  • WPF自定义翻页控件
  • 【react】Redux基础用法
  • 前端页面性能优化的常见问题与解决方案
  • 负载均衡式在线oj项目开发文档(个人项目)
  • IP Source Guard技术原理与应用
  • 如何在GitHub上克隆仓库:HTTPS、SSH和GitHub CLI的区别
  • 【机器学习(五)】分类和回归任务-AdaBoost算法-Sentosa_DSML社区版
  • 【算法题】300. 最长递增子序列-力扣(LeetCode)
  • 【资料分析】刷题日记3
  • node前端开发基本设置
  • 计算机毕业设计 公寓出租系统的设计与实现 Java实战项目 附源码+文档+视频讲解
  • 冷热电气多能互补的微能源网优化调度(含matlab代码)
  • MinIO自动化下载及部署脚本(Windows)
  • macOS Sequoia 15 发布,iPhone 镜像、密码应用程序、窗口平铺更新等带来全新体验
  • 数据中心一体化智能运维方案
  • tomcat中间件漏洞CVE-2017-12615,后台弱口令部署war包,CVE-2020-1938
  • 如何查看WSL默认安装位置以及挪动其到指定安装路径
  • A. Closest Point
  • LabVIEW提高开发效率技巧----使用事件结构优化用户界面响应
  • 【计算机网络 - 基础问题】每日 3 题(二)
  • JUC学习笔记(一)
  • 【Kubernetes】常见面试题汇总(十四)
  • 攻防演练篇:攻防演练场景中面临的常见加密威胁-HTTP隐蔽隧道
  • Lombok -----此java库 常用的注解及其功能总结