当前位置: 首页 > article >正文

利用深度学习实现验证码识别-2-使用Python导出ONNX模型并在Java中调用实现验证码识别

在这里插入图片描述

1. Python部分:导出ONNX模型

首先,我们需要在Python中定义并导出一个已经训练好的验证码识别模型。以下是完整的Python代码:

import string
import torch
import torch.nn as nn
import torch.nn.functional as F

CHAR_SET = string.digits

# 优化后的模型设计
class CaptchaModel(nn.Module):
    def __init__(self):
        super(CaptchaModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.fc1 = nn.Linear(128 * 5 * 12, 256)  # 调整为实际展平维度
        self.fc2 = nn.Linear(256, 4 * len(CHAR_SET))
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = F.relu(F.max_pool2d(self.conv3(x), 2))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x.view(-1, 4, len(CHAR_SET))

# 使用CUDA,如果可用的话
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# 假设你的模型已经训练好并保存在 'best_model.pth'
model = CaptchaModel().to(device)
model.load_state_dict(torch.load('best_model.pth'))

# 生成一个测试输入 (示例输入的形状应与模型输入形状一致)
dummy_input = torch.randn(1, 1, 40, 100).to(device)

# 导出模型为 ONNX 格式
torch.onnx.export(model, dummy_input, "captcha_model.onnx", 
                  input_names=["input"], output_names=["output"], 
                  dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}})

print("Model exported to captcha_model.onnx")

这段代码定义了一个验证码识别模型,并将其导出为ONNX格式,以便在Java中使用。

2. Java部分:调用ONNX模型进行验证码识别

接下来,我们使用Java调用导出的ONNX模型进行验证码识别。以下是完整的Java代码:

  • 引用onnxruntime-1.19.0.jar
package com.tushuoit;

import ai.onnxruntime.*;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;
import java.util.List;

public class CaptchaInference {
    private static final String CHAR_SET = "0123456789";
    private static final int INPUT_WIDTH = 100;
    private static final int INPUT_HEIGHT = 40;
    private static final Random random = new Random();

    public static void main(String[] args) throws Exception {
        // 随机生成4个字符的验证码文本
        String captchaText = generateRandomText(4);
        System.out.println("Generated Captcha Text: " + captchaText);

        // 生成包含文本的Bitmap (BufferedImage)
        BufferedImage captchaImage = generateCaptcha(captchaText, 36, INPUT_WIDTH, INPUT_HEIGHT);

        // 将Bitmap保存为文件(仅用于查看生成的图像,实际使用中可以省略)
        ImageIO.write(captchaImage, "png", new File("generated_captcha.png"));

        // 将图像转换为浮点数数组,并进行归一化处理
        float[] inputData = imageToFloatArray(captchaImage);

        // 创建ONNX Runtime环境
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();

        // 加载ONNX模型
        OrtSession session = env.createSession("captcha_model.onnx", opts);

        // 创建输入张量
        FloatBuffer inputBuffer = FloatBuffer.wrap(inputData);
        OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputBuffer,
                new long[] { 1, 1, INPUT_HEIGHT, INPUT_WIDTH });

        // 进行推理
        OrtSession.Result result = session.run(Collections.singletonMap("input", inputTensor));

        // Extract output tensor and decode it
        float[][][] outputData = (float[][][]) result.get(0).getValue();
        List<String> decodedTexts = decodeOutput(outputData);

        // Print the decoded captcha text
        for (String text : decodedTexts) {
            System.out.println("Predicted Captcha Text: " + text);
        }

        System.out.println("Inference completed.");
        // 释放资源
        session.close();
        env.close();
    }

    // 随机生成指定长度的验证码文本
    private static String generateRandomText(int length) {
        StringBuilder text = new StringBuilder(length);
        for (int i = 0; i < length; i++) {
            text.append(CHAR_SET.charAt(random.nextInt(CHAR_SET.length())));
        }
        return text.toString();
    }

    // 生成包含文本的BufferedImage
    private static BufferedImage generateCaptcha(String text, int fontSize, int width, int height) {
        BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        Graphics2D g2d = image.createGraphics();

        // 设置背景颜色为白色
        g2d.setColor(Color.WHITE);
        g2d.fillRect(0, 0, width, height);

        // 设置字体和颜色
        g2d.setFont(new Font("DroidSansMono", Font.PLAIN, fontSize));
        g2d.setColor(Color.BLACK);

        // 绘制文本
        FontMetrics fm = g2d.getFontMetrics();
        int x = 5; // 文字开始的X坐标
        int y = fm.getAscent() + 5; // 文字开始的Y坐标
        g2d.drawString(text, x, y);

        g2d.dispose();
        return image;
    }

    // 将BufferedImage转换为float数组,并进行归一化处理
    private static float[] imageToFloatArray(BufferedImage image) {
        int width = image.getWidth();
        int height = image.getHeight();
        float[] floatArray = new float[width * height];

        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int rgb = image.getRGB(x, y);
                int gray = (rgb >> 16) & 0xFF; // 因为是灰度图,只需获取一个通道的值
                floatArray[y * width + x] = (gray / 255.0f - 0.5f) * 2.0f; // 归一化到[-1, 1]
            }
        }

        return floatArray;
    }

    private static List<String> decodeOutput(float[][][] outputData) {
        List<String> decodedTexts = new ArrayList<>();
        for (float[][] singleOutput : outputData) {
            StringBuilder decodedText = new StringBuilder();
            for (float[] charProbabilities : singleOutput) {
                int maxIndex = getMaxIndex(charProbabilities);
                decodedText.append(CHAR_SET.charAt(maxIndex));
            }
            decodedTexts.add(decodedText.toString());
        }
        return decodedTexts;
    }

    private static int getMaxIndex(float[] probabilities) {
        int maxIndex = 0;
        float maxProb = probabilities[0];
        for (int i = 1; i < probabilities.length; i++) {
            if (probabilities[i] > maxProb) {
                maxProb = probabilities[i];
                maxIndex = i;
            }
        }
        return maxIndex;
    }
}

这段Java代码首先生成一个随机的验证码图像,然后将其转换为模型输入格式,并通过ONNX Runtime调用导出的模型进行推理,最后解码模型的输出以获取识别的验证码文本。
在这里插入图片描述

总结

通过上述步骤,我们成功地在Python中导出了一个验证码识别模型,并在Java中调用该模型进行验证码识别。这种方法充分利用了Python在深度学习模型训练和导出方面的优势,以及Java在实际应用部署和性能方面的优势,实现了高效的验证码识别系统。


http://www.kler.cn/news/293348.html

相关文章:

  • 对极约束及其性质 —— 公式详细推导
  • ElementUI2.x El-Select组件 处理使用远程查找时下拉箭头丢失问题
  • 用 CSS 实现太阳系运行效果
  • XSS 漏洞检测与利用全解析:守护网络安全的关键洞察
  • 微信小程序请求数据接口封装
  • MutationObserver小试牛刀
  • 计算机基础知识-2
  • 微服务--Nacos
  • 前端进阶:JavaScript实现优雅遮罩层下的表单验证技巧
  • AI聊天应用不能上架?Google play对AI类型应用的规则要求是什么?
  • 高效实用的网站ICP备案查询接口
  • VMEMMAP分析
  • Oracle RAC关于多节点访问同一个数据的过程
  • C 语言指针与数组的深度解析
  • 鸿蒙轻内核M核源码分析系列四 中断Hwi
  • 无人机纪录片航拍认知
  • LLM指令微调实践与分析
  • 用RPC Performance Inspector 优化你的区块链
  • 技术周刊 | Rspack 1.0、v0 支持 Vue、2024 年度编程语言排行榜、Ideogram 2.0、从 0 实现一个 React
  • 深度学习(九)-图像形态操作
  • 《C++进阶之路:探寻预处理宏的替代方案》
  • Spring Boot实现大文件分片下载
  • 谈一谈MVCC
  • 人工智能、机器学习和深度学习有什么区别?应用领域有哪些?
  • Linux 简介
  • HNU-2023电路与电子学-实验1
  • 如何看待AI技术对人们生活的影响?
  • 【网络安全】Sping Boot 未授权访问敏感数据
  • 时下改变AI的6大NLP语言模型
  • 关于 export HF_ENDPOINT=https://hf-mirror.com