
How to Run a Custom OCR Model in OpenCV

We first describe how to obtain a custom OCR model, then how to convert your own OCR model so that it can be run correctly by the opencv_dnn module, and finally we provide some pretrained models.

Training Your Own OCR Model

This repository is a good starting point for training your own OCR model. In the repository, MJSynth+SynthText is set as the training set by default. You can also configure the model structure and dataset you want; a hedged example invocation is shown below.
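The command below is only a sketch, assuming the training repository is clovaai/deep-text-recognition-benchmark (the project the pretrained models referenced later come from) and a CRNN-style configuration; verify the flag names against the repository's README:

$ python3 train.py --train_data data_lmdb_release/training --valid_data data_lmdb_release/validation --select_data MJ-ST --batch_ratio 0.5-0.5 --Transformation None --FeatureExtraction VGG --SequenceModeling BiLSTM --Prediction CTC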

Converting an OCR Model to ONNX Format and Using It in OpenCV DNN

After training the model, use transform_to_onnx.py to convert it to ONNX format.
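Before wiring the converted model into the full pipeline, it helps to verify that OpenCV's DNN module can load and run it. Below is a minimal sanity check, assuming the exported file is named crnn.onnx and takes the 1x1x32x100 grayscale input used by the export snippet in the source code further down:

import numpy as np
import cv2 as cv

# Load the converted model; readNet infers the format from the file extension
recognizer = cv.dnn.readNet("crnn.onnx")

# Feed a dummy grayscale patch of the size the model was exported with (100x32)
dummy = np.random.randint(0, 256, (32, 100), dtype=np.uint8)
blob = cv.dnn.blobFromImage(dummy, scalefactor=1 / 127.5, size=(100, 32), mean=127.5)
recognizer.setInput(blob)
out = recognizer.forward()

# A CRNN-style model emits one score vector per time step,
# i.e. a shape like (sequence_length, 1, num_classes)
print(out.shape)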

Running It on a Webcam

Source code:

'''
    Text detection model: https://github.com/argman/EAST
    Download link: https://www.dropbox.com/s/r2ingd0l3zt8hxs/frozen_east_text_detection.tar.gz?dl=1

    CRNN Text recognition model taken from here: https://github.com/meijieru/crnn.pytorch
    How to convert from pb to onnx:
    Using classes from here: https://github.com/meijieru/crnn.pytorch/blob/master/models/crnn.py

    More converted onnx text recognition models can be downloaded directly here:
    Download link: https://drive.google.com/drive/folders/1cTbQ3nuZG-EKWak6emD_s8_hHXWz7lAr?usp=sharing
    And these models were taken from here: https://github.com/clovaai/deep-text-recognition-benchmark

    import torch
    from models.crnn import CRNN

    model = CRNN(32, 1, 37, 256)
    model.load_state_dict(torch.load('crnn.pth'))
    dummy_input = torch.randn(1, 1, 32, 100)
    torch.onnx.export(model, dummy_input, "crnn.onnx", verbose=True)
'''


# Import required modules
import numpy as np
import cv2 as cv
import math
import argparse

############ Add argument parser for command line arguments ############
parser = argparse.ArgumentParser(
    description="Use this script to run the TensorFlow implementation (https://github.com/argman/EAST) of "
                "EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2). "
                "The OCR model can be obtained by converting the pretrained CRNN model from the GitHub repository "
                "https://github.com/meijieru/crnn.pytorch to .onnx format, or you can download a trained OCR model "
                "directly from https://drive.google.com/drive/folders/1cTbQ3nuZG-EKWak6emD_s8_hHXWz7lAr?usp=sharing")
parser.add_argument('--input',
                    help='Path to input image or video file. Skip this argument to capture frames from a camera.')
parser.add_argument('--model', '-m', required=True,
                    help='Path to a binary .pb file containing the trained detector network.')
parser.add_argument('--ocr', default="crnn.onnx",
                    help='Path to a binary .pb or .onnx file containing the trained recognition network.')
parser.add_argument('--width', type=int, default=320,
                    help='Preprocess input image by resizing to a specific width. It should be a multiple of 32.')
parser.add_argument('--height', type=int, default=320,
                    help='Preprocess input image by resizing to a specific height. It should be a multiple of 32.')
parser.add_argument('--thr', type=float, default=0.5,
                    help='Confidence threshold.')
parser.add_argument('--nms', type=float, default=0.4,
                    help='Non-maximum suppression threshold.')
args = parser.parse_args()


############ Utility functions ############

def fourPointsTransform(frame, vertices):
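    # Warp the quadrilateral text region onto an upright 100x32 patch, the
    # fixed input size the CRNN recognizer expects. 'vertices' holds the four
    # corners of the rotated rect as returned by cv.boxPoints
    # (bottom-left, top-left, top-right, bottom-right).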
    vertices = np.asarray(vertices)
    outputSize = (100, 32)
    targetVertices = np.array([
        [0, outputSize[1] - 1],
        [0, 0],
        [outputSize[0] - 1, 0],
        [outputSize[0] - 1, outputSize[1] - 1]], dtype="float32")

    rotationMatrix = cv.getPerspectiveTransform(vertices, targetVertices)
    result = cv.warpPerspective(frame, rotationMatrix, outputSize)
    return result


def decodeText(scores):
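    # Greedy CTC decoding: at each time step pick the most likely class.
    # Class 0 is the CTC blank label; classes 1..36 map to the alphabet below.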
    text = ""
    alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
    for i in range(scores.shape[0]):
        c = np.argmax(scores[i][0])
        if c != 0:
            text += alphabet[c - 1]
        else:
            text += '-'

    # Collapse adjacent repeated letters and drop blank ('-') symbols to get the final output
    char_list = []
    for i in range(len(text)):
        if text[i] != '-' and (not (i > 0 and text[i] == text[i - 1])):
            char_list.append(text[i])
    return ''.join(char_list)


def decodeBoundingBoxes(scores, geometry, scoreThresh):
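    # EAST outputs two maps at 1/4 of the input resolution (hence the 4.0
    # offset multiplier below): 'scores' (1x1xHxW) holds the per-cell text
    # confidence, and 'geometry' (1x5xHxW) holds the distances from each cell
    # to the four edges of its text box plus the box rotation angle.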
    detections = []
    confidences = []

    ############ CHECK DIMENSIONS AND SHAPES OF geometry AND scores ############
    assert len(scores.shape) == 4, "Incorrect dimensions of scores"
    assert len(geometry.shape) == 4, "Incorrect dimensions of geometry"
    assert scores.shape[0] == 1, "Invalid dimensions of scores"
    assert geometry.shape[0] == 1, "Invalid dimensions of geometry"
    assert scores.shape[1] == 1, "Invalid dimensions of scores"
    assert geometry.shape[1] == 5, "Invalid dimensions of geometry"
    assert scores.shape[2] == geometry.shape[2], "Invalid dimensions of scores and geometry"
    assert scores.shape[3] == geometry.shape[3], "Invalid dimensions of scores and geometry"
    height = scores.shape[2]
    width = scores.shape[3]
    for y in range(0, height):

        # Extract data from scores
        scoresData = scores[0][0][y]
        x0_data = geometry[0][0][y]
        x1_data = geometry[0][1][y]
        x2_data = geometry[0][2][y]
        x3_data = geometry[0][3][y]
        anglesData = geometry[0][4][y]
        for x in range(0, width):
            score = scoresData[x]

            # If score is lower than threshold score, move to next x
            if (score < scoreThresh):
                continue

            # Calculate offset
            offsetX = x * 4.0
            offsetY = y * 4.0
            angle = anglesData[x]

            # Calculate cos and sin of angle
            cosA = math.cos(angle)
            sinA = math.sin(angle)
            h = x0_data[x] + x2_data[x]
            w = x1_data[x] + x3_data[x]

            # Calculate offset
            offset = ([offsetX + cosA * x1_data[x] + sinA * x2_data[x], offsetY - sinA * x1_data[x] + cosA * x2_data[x]])

            # Find points for rectangle
            p1 = (-sinA * h + offset[0], -cosA * h + offset[1])
            p3 = (-cosA * w + offset[0], sinA * w + offset[1])
            center = (0.5 * (p1[0] + p3[0]), 0.5 * (p1[1] + p3[1]))
            detections.append((center, (w, h), -1 * angle * 180.0 / math.pi))
            confidences.append(float(score))

    # Return detections and confidences
    return [detections, confidences]


def main():
    # Read and store arguments
    confThreshold = args.thr
    nmsThreshold = args.nms
    inpWidth = args.width
    inpHeight = args.height
    modelDetector = args.model
    modelRecognition = args.ocr

    # Load network
    detector = cv.dnn.readNet(modelDetector)
    recognizer = cv.dnn.readNet(modelRecognition)

    # Create a new named window
    kWinName = "EAST: An Efficient and Accurate Scene Text Detector"
    cv.namedWindow(kWinName, cv.WINDOW_NORMAL)
    outNames = []
    outNames.append("feature_fusion/Conv_7/Sigmoid")
    outNames.append("feature_fusion/concat_3")

    # Open a video file or an image file or a camera stream
    cap = cv.VideoCapture(args.input if args.input else 0)

    tickmeter = cv.TickMeter()
    while cv.waitKey(1) < 0:
        # Read frame
        hasFrame, frame = cap.read()
        if not hasFrame:
            cv.waitKey()
            break

        # Get frame height and width
        height_ = frame.shape[0]
        width_ = frame.shape[1]
        rW = width_ / float(inpWidth)
        rH = height_ / float(inpHeight)

        # Create a 4D blob from frame.
        blob = cv.dnn.blobFromImage(frame, 1.0, (inpWidth, inpHeight), (123.68, 116.78, 103.94), True, False)

        # Run the detection model
        detector.setInput(blob)

        tickmeter.start()
        outs = detector.forward(outNames)
        tickmeter.stop()

        # Get scores and geometry
        scores = outs[0]
        geometry = outs[1]
        [boxes, confidences] = decodeBoundingBoxes(scores, geometry, confThreshold)

        # Apply NMS
        indices = cv.dnn.NMSBoxesRotated(boxes, confidences, confThreshold, nmsThreshold)
        # NMSBoxesRotated returns an Nx1 array on older OpenCV versions and a
        # flat array on newer ones; flatten so indexing works on both
        for i in np.asarray(indices).reshape(-1):
            # get 4 corners of the rotated rect
            vertices = cv.boxPoints(boxes[i])
            # scale the bounding box coordinates based on the respective ratios
            for j in range(4):
                vertices[j][0] *= rW
                vertices[j][1] *= rH

            # get cropped image using perspective transform
            if modelRecognition:
                cropped = fourPointsTransform(frame, vertices)
                cropped = cv.cvtColor(cropped, cv.COLOR_BGR2GRAY)

                # Create a 4D blob from cropped image
                blob = cv.dnn.blobFromImage(cropped, size=(100, 32), mean=127.5, scalefactor=1 / 127.5)
                recognizer.setInput(blob)

                # Run the recognition model
                tickmeter.start()
                result = recognizer.forward()
                tickmeter.stop()

                # decode the result into text
                wordRecognized = decodeText(result)
                cv.putText(frame, wordRecognized, (int(vertices[1][0]), int(vertices[1][1])), cv.FONT_HERSHEY_SIMPLEX,
                           0.5, (255, 0, 0))

            for j in range(4):
                p1 = (int(vertices[j][0]), int(vertices[j][1]))
                p2 = (int(vertices[(j + 1) % 4][0]), int(vertices[(j + 1) % 4][1]))
                cv.line(frame, p1, p2, (0, 255, 0), 1)

        # Put efficiency information
        label = 'Inference time: %.2f ms' % (tickmeter.getTimeMilli())
        cv.putText(frame, label, (0, 15), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))

        # Display the frame
        cv.imshow(kWinName, frame)
        tickmeter.reset()


if __name__ == "__main__":
    main()

Example command line:

$ python text_detection.py -m=[path_to_text_detect_model] --ocr=[path_to_text_recognition_model]

Provided Pretrained ONNX Models

Some pretrained models can be found at https://drive.google.com/drive/folders/1cTbQ3nuZG-EKWak6emD_s8_hHXWz7lAr?usp=sharing.

The following table shows their performance on different text recognition datasets:

The performance of the text recognition models was measured with OpenCV DNN and does not include the text detection model.
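To try one of these models in isolation, the recognizer can be run on a single pre-cropped word image, without the EAST detector. The sketch below is a hedged example: the model file name CRNN_VGG_BiLSTM_CTC.onnx and the image name word.png are illustrative, and the 36-character alphabet is assumed to match the model:

import numpy as np
import cv2 as cv

recognizer = cv.dnn.readNet("CRNN_VGG_BiLSTM_CTC.onnx")

# The recognizer expects a 100x32 grayscale patch normalized to [-1, 1]
image = cv.imread("word.png", cv.IMREAD_GRAYSCALE)
blob = cv.dnn.blobFromImage(image, scalefactor=1 / 127.5, size=(100, 32), mean=127.5)
recognizer.setInput(blob)
scores = recognizer.forward()

# Greedy CTC decode, the same scheme as decodeText in the script above
alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
chars = [np.argmax(scores[t][0]) for t in range(scores.shape[0])]
text = ""
for t, c in enumerate(chars):
    if c != 0 and not (t > 0 and c == chars[t - 1]):
        text += alphabet[c - 1]
print(text)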

Model Selection Suggestions

The input to the text recognition model is the output of the text detection model, so the performance of text detection greatly affects the performance of text recognition.

DenseNet_CTC has the fewest parameters and the best FPS, making it suitable for edge devices that are very sensitive to computation cost. If your computing resources are limited but you want better accuracy, VGG_CTC is a good choice.

CRNN_VGG_BiLSTM_CTC is suitable for scenarios that demand high recognition accuracy. A timing sketch for comparing these models on your own hardware follows.
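Published FPS numbers depend heavily on hardware, so when choosing among these models it can help to time them on your own machine. The sketch below times repeated forward passes with cv.TickMeter; the .onnx file names are illustrative, and each model is assumed to take the same 1x1x32x100 grayscale input as above:

import numpy as np
import cv2 as cv

def benchmark(model_path, runs=100):
    net = cv.dnn.readNet(model_path)
    blob = np.random.rand(1, 1, 32, 100).astype(np.float32)  # dummy 100x32 gray patch
    net.setInput(blob)
    net.forward()  # warm-up run, excluded from timing

    tm = cv.TickMeter()
    tm.start()
    for _ in range(runs):
        net.setInput(blob)
        net.forward()
    tm.stop()
    return tm.getTimeMilli() / runs

for name in ["DenseNet_CTC.onnx", "VGG_CTC.onnx", "CRNN_VGG_BiLSTM_CTC.onnx"]:
    print(name, "%.2f ms/inference" % benchmark(name))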

