Clang插件演示-直接调用AI模型定义的变量完成模型推理
Clang插件演示-直接调用AI模型定义的变量完成模型推理
- 一.需求描述
- 二.实现方案
- 三.复现过程
- 1.搭建docker环境
- 2.准备推理SDK DEMO
- 3.准备Clang插件
- 4.测试DEMO
一.需求描述
- 1.用户在c++代码里定义一个AI模型的描述(文件路径,数据类型,输入、输出等)
- 2.编译器识别该描述,编译模型为二进制,并且嵌入到elf文件中,对外暴露一个API
- 3.用户直接调用该API,传入模型输入输出参数,完成模型推理
二.实现方案
- 1.封装一个AI模型编译和推理的SDK,供后续使用
- 2.利用Clang PluginASTAction机制,准备一个Clang插件,它会识别某个类型的变量,变量的内容是json格式的模型描述信息
- 3.加载以上插件编译用户代码,该过程会调用AI模型编译API编译模型,生成二进制
并且将结果赋给一个全局字符串变量(`<变量名>_modelbin`);之后修改源码,生成一个与变量同名的函数,在该函数中调用AI推理API,同时删除模型定义的变量,并将修改后的源码保存为`<源文件名>.rewritten.cpp`
- 4.再次使用clang,去掉插件选项,编译改写后的源文件,生成可执行文件
- 5.用法
#include <iostream> typedef const char* ai_model_define; int main() { //定义模型,编译器会将这条语言编译成模型文件及void vamodel(int argc,void *argv[]) ai_model_define vamodel = R"( { "path":"./model.pt", "input":[ "1":"float16" ]; })"; void *args[]={(void*)"Hello",(void*)"World"}; vamodel(2,args); // 调用自动生成的 vamodel 函数 return 0; }
- 6.也可以直接在AST中插入函数定义,这样就不用二次编译(本文只在末尾给出示意代码,未实际演示)
三.复现过程
1.搭建docker环境
# Create a working directory on the host; it is mounted into the container at /home.
mkdir -p ClangPlugin
cd ClangPlugin
docker run -it --net=host --privileged=true -v "$PWD":/home --name ClangPlugin ubuntu:22.04 /bin/bash
# The commands below run inside the container shell. The host directory is
# mounted at /home (the shell starts in /, where no ClangPlugin dir exists).
cd /home
apt-get update
apt-get install clang llvm llvm-dev -y
apt-get install libclang-dev -y
clang --version
llvm-config --version
2.准备推理SDK DEMO
tee infer_sdk.c<<-'EOF'
#include <stdio.h>
#include <string.h>
/*
 * Demo "model compiler": instead of compiling a real model it emits a C
 * snippet that defines `const char *<model_name>_modelbin`.
 *
 * model_name       identifier used to name the generated variable
 * model_cfg_json   model description; only echoed, not parsed
 * output_model_bin caller-provided buffer that receives the snippet
 * max_model_size   size of output_model_bin in bytes (including the NUL)
 *
 * Returns 0 on success, -1 on invalid arguments or when the snippet does not
 * fit (the buffer is still NUL-terminated when max_model_size > 0).
 */
int AiInferenceBuild(const char *model_name,const char *model_cfg_json,char *output_model_bin,int max_model_size)
{
    int written;
    if (model_name == NULL || output_model_bin == NULL || max_model_size <= 0)
        return -1;
    printf("AiInferenceBuild:%s %s\n", model_name, model_cfg_json ? model_cfg_json : "(null)");
    /* snprintf honours the caller's buffer size; sprintf could overflow it. */
    written = snprintf(output_model_bin, (size_t)max_model_size,
                       "const char *%s_modelbin=\"This is Model:%s\\n\";",
                       model_name, model_name);
    if (written < 0 || written >= max_model_size)
        return -1; /* encoding error or truncated output */
    return 0;
}
/*
 * Demo "inference": prints the embedded model blob, then every argument,
 * interpreting each argv[i] as a NUL-terminated string.
 *
 * Returns 0 on success, -1 when argc > 0 but argv is NULL.
 */
int AiInferenceForward(const char *model_bin,int argc,void *argv[])
{
    int i;
    if (argc > 0 && argv == NULL)
        return -1;
    /* Passing NULL to %s is undefined behaviour, so substitute a marker. */
    printf("AiInferenceForward:%s\n", model_bin ? model_bin : "(null)");
    for (i = 0; i < argc; i++)
    {
        const char *arg = (const char *)argv[i];
        printf("AiInferenceForward arg:%d value:%s\n", i, arg ? arg : "(null)");
    }
    return 0;
}
EOF
# Build the demo SDK as a shared library (plain C: -fno-rtti is a C++-only flag).
clang -fPIC -shared infer_sdk.c -o libinfer_sdk.so
3.准备Clang插件
tee MyPlugin.cpp<<-'EOF'
#include <iostream>
#include <string>
#include <system_error>
#include "clang/AST/ASTConsumer.h"
#include "clang/AST/ASTContext.h"
#include "clang/Frontend/ASTUnit.h"
#include "clang/Frontend/FrontendActions.h"
#include "clang/Frontend/CompilerInvocation.h"
#include "clang/Frontend/PreprocessorOutputOptions.h"
#include "clang/Frontend/TextDiagnosticPrinter.h"
#include "clang/Tooling/Tooling.h"
#include "llvm/Support/raw_ostream.h"
#include "clang/Lex/PreprocessorOptions.h"
#include "clang/Lex/Lexer.h"
#include "clang/AST/AST.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Frontend/FrontendPluginRegistry.h"
#include "clang/Rewrite/Core/Rewriter.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Basic/Diagnostic.h"
using namespace clang;

// Model "compiler" provided by libinfer_sdk.so (C linkage). It writes the
// generated model snippet for `model_name` into `output_model_bin`.
extern "C" int AiInferenceBuild(const char *model_name,
                                const char *model_cfg_json,
                                char *output_model_bin,
                                int max_model_size);
namespace {
/// Visitor that turns every `ai_model_define` variable of the main file into
/// generated code. For each such variable it:
///   1. passes the string-literal initializer (the model description) to
///      AiInferenceBuild, which returns a C snippet defining `<name>_modelbin`;
///   2. inserts that snippet plus `void <name>(int argc,void *argv[])`, a
///      wrapper calling AiInferenceForward, at the start of the main file;
///   3. removes the variable declaration;
///   4. saves the edited main file as `<main file>.rewritten.cpp`.
class MyASTVisitor : public RecursiveASTVisitor<MyASTVisitor> {
public:
    explicit MyASTVisitor(Rewriter &R) : TheRewriter(R) {}

    /// Called for every VarDecl during traversal; returns true to keep visiting.
    bool VisitVarDecl(VarDecl *VD) {
        // Match on the type as spelled, i.e. the `ai_model_define` typedef name.
        QualType qt = VD->getType();
        if (qt.getAsString() == "ai_model_define") {
            HandleAIModeDefineVar(VD);
        }
        return true;
    }

private:
    Rewriter &TheRewriter;

    /// Extracts the model description from VD's string-literal initializer.
    void HandleAIModeDefineVar(VarDecl *VD) {
        SourceManager &SM = TheRewriter.getSourceMgr();
        // Only the main file's edit buffer is saved, so declarations located in
        // other files (e.g. headers) cannot be rewritten; leave them alone.
        if (!SM.isInMainFile(SM.getExpansionLoc(VD->getLocation())))
            return;
        Expr *Init = VD->getInit();
        if (!Init)
            return;
        auto *StrLit = dyn_cast<StringLiteral>(Init->IgnoreImpCasts());
        // StringLiteral::getString() is only valid for 1-byte character literals.
        if (!StrLit || StrLit->getCharByteWidth() != 1)
            return;
        GenerateAndInsertVAModel(StrLit->getString().str(), VD);
    }

    /// Compiles the model description and builds the replacement source code.
    void GenerateAndInsertVAModel(const std::string &Content, VarDecl *VD) {
        const std::string FuncName = VD->getNameAsString();  // wrapper reuses the variable's name

        // Compile the model into a 2 MiB heap buffer; a stack array of that size
        // can overflow the compiler's stack.
        const int max_model_size = 2 << 20;
        std::string ModelBin(max_model_size, '\0');
        AiInferenceBuild(FuncName.c_str(), Content.c_str(), &ModelBin[0], max_model_size);
        // Do not rely on the SDK to terminate the buffer: force a final NUL and
        // keep only the text before the first NUL.
        ModelBin.back() = '\0';
        ModelBin.resize(ModelBin.find('\0'));
        if (ModelBin.empty()) {
            DiagnosticsEngine &DE = VD->getASTContext().getDiagnostics();
            unsigned DiagID = DE.getCustomDiagID(DiagnosticsEngine::Error,
                                                 "AiInferenceBuild produced no model binary.");
            DE.Report(VD->getLocation(), DiagID);
            return;
        }

        // Generated code: model blob, SDK prototype and the wrapper function.
        // (Plain parentheses: `\(` is not a valid C++ escape sequence.)
        std::string FunctionCode = ModelBin;
        FunctionCode += "\nextern \"C\" int AiInferenceForward(const char *model_name,int argc,void *argv[]);\n";
        FunctionCode += "void " + FuncName + "(int argc,void *argv[]) {\n";
        FunctionCode += " AiInferenceForward(" + FuncName + "_modelbin,argc,argv);\n";
        FunctionCode += "}\n";
        InsertVAModelImplementation(FunctionCode, VD);
    }

    /// Reports the match, applies the source edits and saves the main file.
    void InsertVAModelImplementation(const std::string &FunctionCode, VarDecl *VD) {
        DiagnosticsEngine &DE = VD->getASTContext().getDiagnostics();
        unsigned DiagID = DE.getCustomDiagID(DiagnosticsEngine::Warning, "Found ai_model_define variable.");
        DE.Report(VD->getLocation(), DiagID);

        SourceManager &SM = TheRewriter.getSourceMgr();
        const LangOptions &LangOpts = TheRewriter.getLangOpts();
        const FileID MainID = SM.getMainFileID();

        // Insert the generated code at the START of the main file so the wrapper
        // is declared before main() calls it.
        TheRewriter.InsertText(SM.getLocForStartOfFile(MainID), FunctionCode, false, true);

        // Delete the variable declaration: from its first token through the ';'
        // that directly follows it. If the next token is not ';', only the
        // declaration's own tokens are removed.
        SourceRange VarRange = VD->getSourceRange();
        SourceLocation StartLoc = SM.getExpansionLoc(VarRange.getBegin());
        SourceLocation LastTokLoc = SM.getExpansionLoc(VarRange.getEnd());
        SourceLocation EndLoc = Lexer::findLocationAfterToken(LastTokLoc, tok::semi, SM, LangOpts, false);
        if (EndLoc.isInvalid())
            EndLoc = Lexer::getLocForEndOfToken(LastTokLoc, 0, SM, LangOpts);
        TheRewriter.RemoveText(CharSourceRange::getCharRange(StartLoc, EndLoc));

        // Save the edited main file next to the original.
        const FileEntry *MainEntry = SM.getFileEntryForID(MainID);
        if (!MainEntry)
            return;
        const std::string OutputFileName = MainEntry->getName().str() + ".rewritten.cpp";
        std::error_code EC;
        llvm::raw_fd_ostream OutFile(OutputFileName, EC);
        if (EC) {
            unsigned ErrID = DE.getCustomDiagID(DiagnosticsEngine::Error, "cannot write '%0': %1");
            DE.Report(VD->getLocation(), ErrID) << OutputFileName << EC.message();
            return;
        }
        TheRewriter.getEditBuffer(MainID).write(OutFile);
    }
};
class MyASTConsumer : public ASTConsumer {
public:
explicit MyASTConsumer(Rewriter &R) : Visitor(R) {}
void HandleTranslationUnit(ASTContext &Context) override {
Visitor.TraverseDecl(Context.getTranslationUnitDecl());
}
private:
MyASTVisitor Visitor;
};
class MyPluginAction : public PluginASTAction {
public:
bool ParseArgs(const CompilerInstance &CI, const std::vector<std::string> &args) override {
return true;
}
std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI, llvm::StringRef) override {
TheRewriter.setSourceMgr(CI.getSourceManager(), CI.getLangOpts());
return std::make_unique<MyASTConsumer>(TheRewriter);
}
private:
Rewriter TheRewriter;
};
} // namespace
// Register the plugin under the name "MyPluginAction" so it can be selected
// with `-Xclang -plugin -Xclang MyPluginAction`.
static FrontendPluginRegistry::Add<MyPluginAction> RegisterMyPluginAction(
    "MyPluginAction", "Insert vamodel implementation");
EOF
# Build the Clang plugin as a shared object, linking the demo SDK so that
# AiInferenceBuild is available when clang loads plugin.so.
clang++ -std=c++14 -g -fPIC -shared MyPlugin.cpp -o plugin.so \
    $(llvm-config --cxxflags --ldflags --system-libs --libs all) \
    -fno-rtti ./libinfer_sdk.so
4.测试DEMO
tee main.cpp<<-'EOF'
#include <iostream>
// The plugin matches variables whose type is spelled `ai_model_define`.
typedef const char* ai_model_define;
int main()
{
// Model definition: the plugin compiles this statement into a model blob and
// a generated function void vamodel(int argc,void *argv[]), then removes it.
// NOTE: compiled as-is (without rewriting) the call below is an error, because
// `vamodel` is still a `const char*` here.
// NOTE(review): the description is not valid JSON ("1":"float16" inside [],
// trailing ';'); the demo SDK only echoes it, so it works here.
ai_model_define vamodel = R"(
{
"path":"./model.pt",
"input":[
"1":"float16"
];
})";
void *args[]={(void*)"Hello",(void*)"World"};
vamodel(2,args); // calls the auto-generated vamodel function
return 0;
}
EOF
# Pass 1: run only the plugin (-plugin replaces the normal compile action;
# -fsyntax-only ensures the driver never attempts a link). It writes
# main.cpp.rewritten.cpp. The reported error is expected: main.cpp itself
# still calls the string variable `vamodel` as a function.
clang++ -fsyntax-only -Xclang -load -Xclang ./plugin.so -Xclang -plugin -Xclang MyPluginAction main.cpp
# Pass 2: compile the rewritten source normally and link the inference SDK.
clang++ -o demo main.cpp.rewritten.cpp ./libinfer_sdk.so
./demo
输出
AiInferenceBuild:vamodel
{
"path":"./model.pt",
"input":[
"1":"float16"
];
}
main.cpp:8:21: warning: Found ai_model_define variable.
ai_model_define vamodel = R"(
^
1 warning and 1 error generated.
AiInferenceForward:This is Model:vamodel
AiInferenceForward arg:0 value:Hello
AiInferenceForward arg:1 value:World
说明
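- 1.Rewriter只修改内存中的源码缓冲区,编译器不会重新解析、编译修改后的代码;因此第一次编译仍按原始的main.cpp做语义检查,输出中的“1 error”(把字符串变量vamodel当作函数调用)是预期的
- 2.因此,需要在插件中使用Rewriter进行代码修改,并将修改后的代码输出到文件,然后重新编译修改后的源文件
- 3.或者在AST中插入函数定义,使编译器在编译时能够看到新的函数。注意:插入需发生在调用点被语义检查之前,并且要保证代码生成阶段能收到该函数(本文未验证),示意代码如下: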
void GenerateAndInsertVAModelAST(const std::string &Content, VarDecl *VD) { ASTContext &Context = VD->getASTContext(); // 1. 准备函数的返回类型和参数类型 QualType ReturnType = Context.VoidTy; QualType ParamType = Context.VoidPtrTy; // 2. 创建函数名标识符 IdentifierInfo *FuncId = &Context.Idents.get("vamodel"); // 3. 创建函数参数列表 ParmVarDecl *Param = ParmVarDecl::Create( Context, nullptr, // 没有父上下文 SourceLocation(), SourceLocation(), &Context.Idents.get("input"), ParamType, nullptr, SC_None, nullptr ); SmallVector<ParmVarDecl *, 1> Params; Params.push_back(Param); // 4. 创建函数声明 FunctionDecl *FuncDecl = FunctionDecl::Create( Context, Context.getTranslationUnitDecl(), // 所属的上下文 SourceLocation(), SourceLocation(), FuncId, ReturnType, nullptr, SC_None, false, false, false ); FuncDecl->setParams(Params); // 5. 创建函数体(此处生成示例代码,可根据需要完善) // 构建函数体内容 // 这里只是一个空的复合语句,可以添加实际的语句 CompoundStmt *FuncBody = new (Context) CompoundStmt(SourceLocation(), SourceLocation()); FuncDecl->setBody(FuncBody); // 6. 将函数声明添加到翻译单元 Context.getTranslationUnitDecl()->addDecl(FuncDecl); }