当前位置: 首页 > article >正文

第P10周-Pytorch实现车牌号识别

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

目标

具体实现

(一)环境

语言环境:Python 3.10
编 译 器: PyCharm
框 架: Pytorch

(二)具体步骤
1. 文件结构

image.png

2. config.py
import argparse  
  
def get_options(parser=argparse.ArgumentParser()):  
    parser.add_argument('--workers', type=int, default=0, help='Number of parallel workers')  
    parser.add_argument('--batch-size', type=int, default=4, help='input batch size, default=32')  
    parser.add_argument('--size', type=tuple, default=(224, 224), help='input image size')  
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate, default=0.0001')  
    parser.add_argument('--epochs', type=int, default=20, help='number of epochs')  
    parser.add_argument('--seed', type=int, default=112, help='random seed')  
    parser.add_argument('--save-path', type=str, default='./models/', help='path to save checkpoints')  
  
    opt = parser.parse_args()  
  
    if opt:  
        print(f'num_workers:{opt.workers}')  
        print(f'batch_size:{opt.batch_size}')  
        print(f'learn rate:{opt.lr}')  
        print(f'epochs:{opt.epochs}')  
        print(f'random seed:{opt.seed}')  
        print(f'save_path:{opt.save_path}')  
  
    return opt  
  
if __name__ == '__main__':  
    opt = get_options()
3. Utils.py
import torch  
import pathlib  
import matplotlib.pyplot as plt  
from torchvision.transforms import transforms  
  
  
# 第一步:设置GPU  
def USE_GPU():  
    if torch.cuda.is_available():  
        print('CUDA is available, will use GPU')  
        device = torch.device("cuda")  
    else:  
        print('CUDA is not available. Will use CPU')  
        device = torch.device("cpu")  
    return device  
  
temp_dict = dict()  
def recursive_iterate(path):  
    """  
    根据所提供的路径遍历该路径下的所有子目录,列出所有子目录下的文件  
    :param path: 路径  
    :return: 返回最后一级目录的数据  
    """    path = pathlib.Path(path)  
    for file in path.iterdir():  
        if file.is_file():  
            temp_key = str(file).split('\\')[-2]  
            if temp_key in temp_dict:  
                temp_dict.update({temp_key: temp_dict[temp_key] + 1})  
            else:  
                temp_dict.update({temp_key: 1})  
            # print(file)  
        elif file.is_dir():  
            recursive_iterate(file)  
  
    return temp_dict  
  
  
def data_from_directory(directory, train_dir=None, test_dir=None, show=False):  
    """  
    提供是的数据集是文件形式的,提供目录方式导入数据,简单分析数据并返回数据分类  
    :param test_dir: 是否设置了测试集目录  
    :param train_dir: 是否设置了训练集目录  
    :param directory: 数据集所在目录  
    :param show: 是否需要以柱状图形式显示数据分类情况,默认显示  
    :return: 数据分类列表,类型: list  
    """    global total_image  
    print("数据目录:{}".format(directory))  
    data_dir = pathlib.Path(directory)  
  
    # for d in data_dir.glob('**/*'): # **/*通配符可以遍历所有子目录  
    #     if d.is_dir():  
    #         print(d)    class_name = []  
    total_image = 0  
    # temp_sum = 0  
  
    if train_dir is None or test_dir is None:  
        data_path = list(data_dir.glob('*'))  
        class_name = [str(path).split('\\')[-1] for path in data_path]  
        print("数据分类: {}, 类别数量:{}".format(class_name, len(list(data_dir.glob('*')))))  
        total_image = len(list(data_dir.glob('*/*')))  
        print("图片数据总数: {}".format(total_image))  
    else:  
        temp_dict.clear()  
        train_data_path = directory + '/' + train_dir  
        train_data_info = recursive_iterate(train_data_path)  
        print("{}目录:{},{}".format(train_dir, train_data_path, train_data_info))  
  
        temp_dict.clear()  
        test_data_path = directory + '/' + test_dir  
        print("{}目录:{},{}".format(test_dir,  test_data_path, recursive_iterate(test_data_path)))  
        class_name = temp_dict.keys()  
  
    if show:  
        # 隐藏警告  
        import warnings  
        warnings.filterwarnings("ignore")  # 忽略警告信息  
        plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签  
        plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号  
        plt.rcParams['figure.dpi'] = 100  # 分辨率  
  
        for i in class_name:  
            data = len(list(pathlib.Path((directory + '\\' + i + '\\')).glob('*')))  
            plt.title('数据分类情况')  
            plt.grid(ls='--', alpha=0.5)  
            plt.bar(i, data)  
            plt.text(i, data, str(data), ha='center', va='bottom')  
            print("类别-{}:{}".format(i, data))  
            # temp_sum += data  
        plt.show()  
  
    # if temp_sum == total_image:  
    #     print("图片数据总数检查一致")  
    # else:    #     print("数据数据总数检查不一致,请检查数据集是否正确!")  
    return class_name  
  
  
def get_transforms_setting(size):  
    """  
    获取transforms的初始设置  
    :param size: 图片大小  
    :return: transforms.compose设置  
    """    transform_setting = {  
        'train': transforms.Compose([  
            transforms.Resize(size),  
            transforms.ToTensor(),  
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  
        ]),  
        'test': transforms.Compose([  
            transforms.Resize(size),  
            transforms.ToTensor(),  
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  
        ])  
    }  
  
    return transform_setting  
  
  
# 训练循环  
def train(dataloader, device, model, loss_fn, optimizer):  
    size = len(dataloader.dataset)  # 训练集的大小  
    num_batches = len(dataloader)  # 批次数目, (size/batch_size,向上取整)  
  
    train_loss, train_acc = 0, 0  # 初始化训练损失和正确率  
  
    for X, y in dataloader:  # 获取图片及其标签  
        X, y = X.to(device), y.to(device)  
  
        # 计算预测误差  
        pred = model(X)  # 网络输出  
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失  
  
        # 反向传播  
        optimizer.zero_grad()  # grad属性归零  
        loss.backward()  # 反向传播  
        optimizer.step()  # 每一步自动更新  
  
        # 记录acc与loss  
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()  
        train_loss += loss.item()  
  
    train_acc /= size  
    train_loss /= num_batches  
  
    return train_acc, train_loss  
  
  
def test(dataloader, device, model, loss_fn):  
    size = len(dataloader.dataset)  # 测试集的大小  
    num_batches = len(dataloader)  # 批次数目, (size/batch_size,向上取整)  
    test_loss, test_acc = 0, 0  
  
    # 当不进行训练时,停止梯度更新,节省计算内存消耗  
    with torch.no_grad():  
        for imgs, target in dataloader:  
            imgs, target = imgs.to(device), target.to(device)  
  
            # 计算loss  
            target_pred = model(imgs)  
            loss = loss_fn(target_pred, target)  
  
            test_loss += loss.item()  
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()  
  
    test_acc /= size  
    test_loss /= num_batches  
  
    return test_acc, test_loss  
  
  
from PIL import Image  
  
def predict_one_image(image_path, device, model, transform, classes):  
    """  
    预测单张图片  
    :param image_path: 图片路径  
    :param device: CPU or GPU    :param model: cnn模型  
    :param transform:    :param classes:    :return:  
    """    test_img = Image.open(image_path).convert('RGB')  
    plt.imshow(test_img)  # 展示预测的图片  
  
    test_img = transform(test_img)  
    img = test_img.to(device).unsqueeze(0)  
  
    model.eval()  
    output = model(img)  
  
    _, pred = torch.max(output, 1)  
    pred_class = classes[pred]  
    print(f'预测结果是:{pred_class}')
4.dataset.py
import os  
  
import torch  
from PIL import Image  
from torch.utils.data import Dataset  
from torchvision import transforms, datasets  
from Utils import get_transforms_setting  
  
class CarLicenceDataset(Dataset):  
    def __init__(self, root_dir, all_labels, transform=None):  
        self.img_dir = root_dir    # 图像目录路径  
        self.img_labels = all_labels # 获取标签信息  
        self.transform = transform  # 目标转换函数  
  
        # self.total_data = datasets.ImageFolder(root_dir, transform=transform)  
        # print(self.total_data)        # # 划分数据集  
        # train_size = int(0.8 * len(self.total_data))  
        # test_size = len(self.total_data) - train_size        # self.train_dataset, self.test_dataset = torch.utils.data.random_split(self.total_data, [train_size, test_size])        # print(self.train_dataset, self.test_dataset)  
    def __len__(self):  
        return len(self.img_labels)  
  
    def __getitem__(self, idx):  
        image = Image.open(self.img_dir[idx]).convert('RGB')  
        label = self.img_labels[idx]  
  
        if self.transform:  
            image = self.transform(image)  
  
        return image, label  
  
    # def __getds__(self, dstype):  
    #     if dstype == 'train':    #         return self.train_dataset    #     elif dstype == 'test':    #         return self.test_dataset    #     else:    #         pass
5.model.py
import torch  
from torch import nn  
import torch.nn.functional as F  
  
class Network_bn(nn.Module):  
    def __init__(self, label_name_len=1, char_set_len=1):  
        super(Network_bn, self).__init__()  
        """  
        nn.Conv2d()函数:  
        第一个参数(in_channels)是输入的channel数量  
        第二个参数(out_channels)是输出的channel数量  
        第三个参数(kernel_size)是卷积核大小  
        第四个参数(stride)是步长,默认为1  
        第五个参数(padding)是填充大小,默认为0  
        """        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=0)  
        self.bn1 = nn.BatchNorm2d(12)  
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=0)  
        self.bn2 = nn.BatchNorm2d(12)  
        self.pool = nn.MaxPool2d(2 ,2)  
        self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=0)  
        self.bn4 = nn.BatchNorm2d(24)  
        self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=0)  
        self.bn5 = nn.BatchNorm2d(24)  
        self.fc1 = nn.Linear(24 *50 *50, label_name_len * char_set_len)  
        self.reshape = Reshape([label_name_len ,char_set_len])  
  
    def forward(self, x):  
        x = F.relu(self.bn1(self.conv1(x)))  
        x = F.relu(self.bn2(self.conv2(x)))  
        x = self.pool(x)  
        x = F.relu(self.bn4(self.conv4(x)))  
        x = F.relu(self.bn5(self.conv5(x)))  
        x = self.pool(x)  
        x = x.view(-1, 24 *50 *50)  
        x = self.fc1(x)  
  
        # 最终reshape  
        x = self.reshape(x)  
  
        return x  
  
# 定义Reshape层  
class Reshape(nn.Module):  
    def __init__(self, shape):  
        super(Reshape, self).__init__()  
        self.shape = shape  
  
    def forward(self, x):  
        return x.view(x.size(0), *self.shape)  
  
device = "cuda" if torch.cuda.is_available() else "cpu"  
print("Using {} device".format(device))  
  
model = Network_bn().to(device)  
model
6. train.py
import torch  
import os,PIL,random,pathlib  
import matplotlib.pyplot as plt  
from torch import nn  
import numpy as np  
import torchsummary  
  
from dataset import CarLicenceDataset  
from Utils import USE_GPU, get_transforms_setting  
from config import get_options  
from model import Network_bn  
  
  
device = USE_GPU()  # 获取GPU,有则使用GPU,否则使用CPU
opt = get_options()  # 获取训练超参数,预设的
transform = get_transforms_setting((224, 224))  # 获取数据转换配置
  
  
# 支持中文  
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签  
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号  
  
data_dir = './data/licence_plate/'  # 数据集路径
data_dir = pathlib.Path(data_dir)  
  
data_paths  = list(data_dir.glob('*'))  
classeNames = [str(path).split("\\")[2].split("_")[1].split(".")[0] for path in data_paths]  
# print(classeNames)  
  
data_paths     = list(data_dir.glob('*'))  
data_paths_str = [str(path) for path in data_paths]  
# print(data_paths_str)  
  
plt.figure(figsize=(14, 5))  
plt.suptitle("数据示例", fontsize=15)  
  
for i in range(18):  
    plt.subplot(3, 6, i + 1)  
    # plt.xticks([])  
    # plt.yticks([])    # plt.grid(False)  
    # 显示图片  
    images = plt.imread(data_paths_str[i])  
    plt.imshow(images)  
  
plt.show()  

image.png


  
char_enum = ["京","沪","津","渝","冀","晋","蒙","辽","吉","黑","苏","浙","皖","闽","赣","鲁",\  
              "豫","鄂","湘","粤","桂","琼","川","贵","云","藏","陕","甘","青","宁","新","军","使"]  
  
number   = [str(i) for i in range(0, 10)]    # 0 到 9 的数字  
alphabet = [chr(i) for i in range(65, 91)]   # A 到 Z 的字母  
  
char_set       = char_enum + number + alphabet  
char_set_len   = len(char_set)  
label_name_len = len(classeNames[0])  
  
# 将字符串数字化  
def text2vec(text):  
    vector = np.zeros([label_name_len, char_set_len])  
    for i, c in enumerate(text):  
        idx = char_set.index(c)  
        vector[i][idx] = 1.0  
    return vector  
  
all_labels = [text2vec(i) for i in classeNames]  
  
total_data = CarLicenceDataset(data_paths_str, all_labels, transform['train'])  
print(total_data)  
  
train_size = int(0.8 * len(total_data))  
test_size  = len(total_data) - train_size  
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])  
print(train_size,test_size)  
  
train_loader = torch.utils.data.DataLoader(train_dataset,  
                                           batch_size=16,  
                                           shuffle=True)  
test_loader = torch.utils.data.DataLoader(test_dataset,  
                                          batch_size=16,  
                                          shuffle=True)  
  
print("The number of images in a training set is: ", len(train_loader)*16)  
print("The number of images in a test set is: ", len(test_loader)*16)  
print("The number of batches per epoch is: ", len(train_loader))  
  
for X, y in test_loader:  
    print("Shape of X [N, C, H, W]: ", X.shape)  
    print("Shape of y: ", y.shape, y.dtype)  
    break  
  
model = Network_bn(label_name_len, char_set_len).to(device)

torchsummary.summary(model, input_size=(3, 224, 224))  # 打印网络结构

image.png

# 创建一个Adam优化器
optimizer  = torch.optim.Adam(model.parameters(),  
                              lr=opt.learning_rate,  # 从配置文件中取
                              weight_decay=0.0001)  
  
loss_model = nn.CrossEntropyLoss()  # 创建一个交叉熵损失函数
  
from torch.autograd import Variable  
  
  
def test(model, test_loader, loss_model):  
    size = len(test_loader.dataset)  
    num_batches = len(test_loader)  
  
    model.eval()  
    test_loss, correct = 0, 0  
    with torch.no_grad():  
        for X, y in test_loader:  
            X, y = X.to(device), y.to(device)  
            pred = model(X)  
  
            test_loss += loss_model(pred, y).item()  
  
    test_loss /= num_batches  
  
    print(f"Avg loss: {test_loss:>8f} \n")  
    return correct, test_loss  
  
  
def train(model, train_loader, loss_model, optimizer):  
    model = model.to(device)  
    model.train()  
  
    for i, (images, labels) in enumerate(train_loader, 0):  # 0是标起始位置的值。  
  
        images = Variable(images.to(device))  
        labels = Variable(labels.to(device))  
  
        optimizer.zero_grad()  
        outputs = model(images)  
  
        loss = loss_model(outputs, labels)  
        loss.backward()  
        optimizer.step()  
  
        if i % 1000 == 0:  
            print('[%5d] loss: %.3f' % (i, loss))  
  
test_acc_list  = []  
test_loss_list = []  
epochs = 30  
  
for t in range(epochs):  
    print(f"Epoch {t+1}\n-------------------------------")  
    train(model,train_loader,loss_model,optimizer)  
    test_acc,test_loss = test(model, test_loader, loss_model)  
    test_acc_list.append(test_acc)  
    test_loss_list.append(test_loss)  
print("Done!")  
  
  
import numpy as np  
import matplotlib.pyplot as plt  
  
from datetime import datetime  
current_time = datetime.now() # 获取当前时间  
  
x = [i for i in range(1,31)]  
  
plt.plot(x, test_loss_list, label="Loss", alpha=0.8)  
  
plt.xlabel("Epoch")  
plt.ylabel("Loss")  
plt.title(current_time) # 打卡请带上时间戳,否则代码截图无效  
  
plt.legend()  
plt.show()

image.png


http://www.kler.cn/a/556873.html

相关文章:

  • Node.js中不支持require和import两种导入模块的混用
  • python-leetcode-反转链表
  • 游戏引擎学习第115天
  • 力扣-二叉树-669 修剪二叉搜索树
  • 高频网络分析仪中的sdd是什么参数
  • STL介绍1:vector、pair、string、queue、map
  • Ubuntu 的RabbitMQ安装
  • 测试data_management函数
  • 网络安全:DeepSeek已经在自动的挖掘漏洞
  • 如何在 React 中测试高阶组件?
  • Windows 下如何对 node/vue 进行多版本管理?
  • Java常用设计模式及其应用场景
  • [Windows] Umi-OCR 开源批量文字识别 支持图片,文档,二维码,截图等
  • 从0-1搭建mac环境最新版
  • 常用加解密原理及实际使用
  • Vue2 和 Vue3 的区别
  • halcon激光三角测量(二十一)calibrate_sheet_of_light_calplate
  • Ubuntu24安装MongoDB(解压版)
  • 什么是向量化?ElasticSearch如何存储向量化?
  • 如何在 Vue 应用中实现权限管理?