OCR for Multiple Data Sources: Deploying the GOT-OCR2.0 Model
GOT-OCR2.0 is an end-to-end model developed and open-sourced by the Ucas-HaoranWei team; it marks the transition of OCR technology from the traditional OCR-1.0 era to the more advanced OCR-2.0 era.
GOT-OCR2.0 not only supports multiple languages and many kinds of text images (both handwritten and printed), it also handles a wide range of complex text scenarios, such as text in natural scenes, documents, and fine-grained OCR tasks.
Compared with earlier approaches that required a separate model for each task, GOT-OCR2.0 supports many OCR tasks through a single unified architecture, which greatly simplifies deployment and improves overall performance.
GOT-OCR2.0 pairs a vision encoder and a linear input embedding layer with a decoder. The encoder is a ViTDet-style architecture with local attention, which keeps GPU memory usage under control so the model performs well even on resource-constrained hardware.
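To make that data flow concrete, here is a minimal sketch of the encoder-projection-decoder wiring. The module names (vision_encoder, decoder, embed_tokens) are hypothetical placeholders for illustration, not the project's actual classes:

import torch
import torch.nn as nn

class GOTStyleSketch(nn.Module):
    """Sketch of GOT-OCR2.0's design: vision encoder -> linear embedding
    layer -> autoregressive decoder. Placeholder modules, not the real code."""

    def __init__(self, vision_encoder, decoder, vision_dim=1024, hidden_dim=1024):
        super().__init__()
        self.vision_encoder = vision_encoder                # ViTDet-style, local attention
        self.projector = nn.Linear(vision_dim, hidden_dim)  # image features -> token space
        self.decoder = decoder                              # language-model decoder

    def forward(self, image, input_ids):
        vision_feats = self.vision_encoder(image)             # (B, num_patches, vision_dim)
        image_embeds = self.projector(vision_feats)           # (B, num_patches, hidden_dim)
        text_embeds = self.decoder.embed_tokens(input_ids)    # (B, T, hidden_dim)
        # Image embeddings are prepended to the text embeddings and decoded jointly
        return self.decoder(inputs_embeds=torch.cat([image_embeds, text_embeds], dim=1))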
Training proceeds in three stages: the encoder is first pre-trained efficiently (using the small OPT-125M as a lightweight decoder so that large amounts of data can be ingested quickly), the encoder is then joint-trained with the full decoder, and finally the whole model is fine-tuned for best results.
GitHub project: https://github.com/Ucas-HaoranWei/GOT-OCR2.0
I. Environment Setup
1. Python environment
Python 3.10 or later is recommended.
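If you prefer an isolated environment, conda works well (the environment name below is just an example):

conda create -n got python=3.10 -y
conda activate got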
2. Installing the pip packages
cd GOT-OCR-2.0-master/
pip install -e .
pip install ninja natsort
pip install flash-attn --no-build-isolation
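If the flash-attn build succeeded, the package should import cleanly; a quick sanity check:

python -c "import flash_attn; print(flash_attn.__version__)"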
3. Downloading the GOT-OCR2.0 model weights:
git lfs install
git clone https://huggingface.co/stepfun-ai/GOT-OCR2_0
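Once the clone finishes, you can verify that the weights load from the local directory (the path below assumes you cloned into ./GOT-OCR2_0):

python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('./GOT-OCR2_0', trust_remote_code=True)"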
II. Functional Testing
1. Running the test:
(1) Calling OCR from Python code
import argparse
import os
from io import BytesIO

import requests
import torch
from PIL import Image
from transformers import AutoTokenizer, TextStreamer

from GOT.demo.process_results import punctuation_dict, svg_to_html
from GOT.model import GOTQwenForCausalLM
from GOT.model.plug.blip_process import BlipImageEvalProcessor
from GOT.utils.conversation import conv_templates, SeparatorStyle
from GOT.utils.utils import disable_torch_init, KeywordsStoppingCriteria

# Special tokens that mark the image region inside the prompt
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'

translation_table = str.maketrans(punctuation_dict)
image_token_len = 256  # number of image patch tokens reserved in the prompt
def load_image(image_file):
    """Load an image from a local path or an HTTP(S) URL."""
    if image_file.startswith(('http://', 'https://')):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image
def process_query_string(type, box, color, w, h):
    """Build the text prompt, optionally restricted to a box or a color."""
    qs = 'OCR with format: ' if type == 'format' else 'OCR: '
    if box:
        bbox = eval(box)  # box arrives as a string such as "[x1,y1,x2,y2]"
        if len(bbox) in {2, 4}:
            # Normalize pixel coordinates to the 0-1000 grid the model expects
            bbox = [int(coord / dimension * 1000) for coord, dimension in zip(bbox, [w, h] * (len(bbox) // 2))]
        qs = f"{bbox} {qs}"
    if color:
        qs = f"[{color}] {qs}"
    # Wrap the query with the image placeholder tokens
    qs = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_PATCH_TOKEN * image_token_len}{DEFAULT_IM_END_TOKEN}\n{qs}"
    return qs
def render_output(outputs, type):
    """Render the model output to an HTML file (music notation or formatted text)."""
    if '**kern' in outputs:
        # '**kern' marks Humdrum music notation; render it via verovio
        svg = convert_to_svg(outputs)
        svg_to_html(svg, "./results/demo.html")
    elif type == 'format':
        web_content = process_math_output(outputs)
        with open("./results/demo.html", 'w') as f:
            f.write(web_content)
def convert_to_svg(outputs):
    import verovio  # imported lazily; only needed for music-notation rendering
    tk = verovio.toolkit()
    tk.loadData(outputs)
    tk.setOptions({
        "pageWidth": 2100, "footer": 'none',
        "barLineWidth": 0.5, "beamMaxSlope": 15,
        "staffLineWidth": 0.2, "spacingStaff": 6
    })
    svg = tk.renderToSVG()
    svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
    return svg
def process_math_output(outputs):
    """Embed the formatted output into the HTML rendering template."""
    template_path = "./render_tools/content-mmd-to-html.html"
    outputs = sanitize_math_outputs(outputs)
    outputs_list = outputs.split('\n')
    gt = ''
    for line in outputs_list:
        # Escape backslashes so the text survives embedding in JavaScript.
        # (Done outside an f-string: backslashes inside f-string expressions
        # are a syntax error before Python 3.12.)
        gt += '"' + line.replace('\\', '\\\\') + '\\n"+\n'
    gt = gt[:-2]  # drop the trailing '+\n'
    with open(template_path, 'r') as f:
        template = f.read()
    web_content = template.replace("const text =", f"const text ={gt}")
    return web_content
def sanitize_math_outputs(outputs):
    """Drop unbalanced \\left/\\right pairs that would break math rendering."""
    right_num = outputs.count('\\right')
    left_num = outputs.count('\\left')
    if right_num != left_num:
        outputs = (outputs.replace('\\left(', '(').replace('\\right)', ')')
                   .replace('\\left[', '[').replace('\\right]', ']')
                   .replace('\\left{', '{').replace('\\right}', '}')
                   .replace('\\left|', '|').replace('\\right|', '|')
                   .replace('\\left.', '.').replace('\\right.', '.'))
    return outputs.replace('"', '``').replace('$', '')
def eval_model(args):
    # Skip torch's default weight initialization to speed up model loading
    disable_torch_init()
    model_name = os.path.expanduser(args.model_name)

    # Load the tokenizer and the GOT model onto the GPU
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = GOTQwenForCausalLM.from_pretrained(
        model_name, low_cpu_mem_usage=True, device_map='cuda',
        use_safetensors=True, pad_token_id=151643).eval()
    model.to(device='cuda', dtype=torch.bfloat16)

    # Load the image and record its size for box-coordinate normalization
    image = load_image(args.image_file)
    w, h = image.size

    # Build the query string from type, box and color
    qs = process_query_string(args.type, args.box, args.color, w, h)

    # Wrap the query in the "mpt" conversation template
    conv_mode = "mpt"
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    # Tokenize the prompt
    inputs = tokenizer([prompt])
    input_ids = torch.as_tensor(inputs.input_ids).cuda()

    # Preprocess the image (base and high-resolution branches)
    image_processor = BlipImageEvalProcessor(image_size=1024)
    image_tensor = image_processor(image)
    image_tensor_high = BlipImageEvalProcessor(image_size=1024)(image.copy())

    # Stop generation at the conversation separator and stream tokens as they come
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    with torch.autocast("cuda", dtype=torch.bfloat16):
        output_ids = model.generate(
            input_ids,
            images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_high.unsqueeze(0).half().cuda())],
            do_sample=False,
            num_beams=1,
            no_repeat_ngram_size=20,
            streamer=streamer,
            max_new_tokens=4096,
            stopping_criteria=[stopping_criteria]
        )

    if args.render:
        # Decode only the newly generated tokens and strip the trailing stop string.
        # (str.rstrip(stop_str) would strip characters, not the suffix string.)
        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)].strip()
        render_output(outputs, args.type)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Point --model-name at the local GOT-OCR2_0 weights directory; the default
    # below is a leftover placeholder and will not load with GOTQwenForCausalLM
    parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
    parser.add_argument("--image-file", type=str, required=True)
    parser.add_argument("--type", type=str, required=True)  # 'ocr' or 'format'
    parser.add_argument("--box", type=str, default='')
    parser.add_argument("--color", type=str, default='')
    parser.add_argument("--render", action='store_true')
    args = parser.parse_args()
    eval_model(args)
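Assuming the script above is saved as ocr_demo.py and the weights live in ./GOT-OCR2_0 (both names are just examples), typical invocations look like:

python ocr_demo.py --model-name ./GOT-OCR2_0 --image-file ./images/test.jpg --type ocr
python ocr_demo.py --model-name ./GOT-OCR2_0 --image-file ./images/test.jpg --type format --render

The first command streams plain OCR text to the terminal; the second requests formatted output and, with --render, writes an HTML rendering to ./results/demo.html (make sure the ./results directory exists first).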
To be continued...
For more details, follow: 杰哥新技术