
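Deep Learning Basics: Image Classification on a Dataset with Custom Functions, Using License Plate Recognition as an Example

  • 🍨 This post is a study log from the 🔗365天深度学习训练营 (365-Day Deep Learning Training Camp)
  • 🍖 Original author: K同学啊

Preface

  • In machine learning, data matters most. If the images for a recognition task have not been sorted into classes, how do you label and group the data yourself?
  • With labels built this way, computing accuracy is somewhat tricky, mainly because predictions and labels have to be matched as matrices; it took me more than an hour to get it right.
  • This is course-project week, so updates are slower.
  • Bookmarks and follows are welcome; I will keep updating.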

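    Contents

    • 1. Explanation
    • 2. Case study
      • 1. Data processing
        • 1. Import libraries
        • 2. Inspect the file names
        • 3. Display a batch of images
        • 4. Custom dataset
          • 1. Define the labels
          • 2. Custom dataset loader
        • 5. Data loading and normalization
        • 6. Splitting the data
        • 7. Loading data in batches
      • 2. Build the model
      • 3. Model training
        • 1. Set the hyperparameters
        • 2. Define the training function
        • 3. Define the test function
        • 4. Run the training
      • 4. Results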

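1. Explanation

When the data has already been sorted into one folder per class, as in this folder layout:

[image]

it can usually be loaded directly with datasets.ImageFolder, which has some advantages, for example:

  • Each label maps to a single output neuron, so the output layer is one-dimensional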

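But when the data has not been sorted into classes, as here:

[image]

you have to organize it yourself. The core of the task is binding each image to its label.

In this article the labels are built from the characters of the plate number itself. We create a matrix whose rows correspond to the character positions of the plate (so the row count equals the plate length) and whose columns correspond to the possible characters; every plate can then be written as such a matrix. A few things to note:

  • A class is no longer a single neuron; each label is a whole matrix. At the end of the network, the output of the fully connected layer therefore has to be reshaped to the size of that matrix.
  • Accuracy: this is the tricky part. Predictions and the custom labels have to be brought to the same matrix shape before they can be compared, as the code shows. I found it fairly hard; it took me an hour or two.
  • Accuracy and loss are reported as averages over the batches.
  • The details are in the code that follows.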

2. Case study

1. Data processing

1. Import libraries

import torch  
import torchvision  
import torch.nn as nn 
from torchvision import datasets, transforms 

device = 'cuda' if torch.cuda.is_available() else 'cpu'

device
'cuda'

2. Inspect the file names

import os, pathlib 

data_dir = './data/'
data_dir = pathlib.Path(data_dir)

data_paths = data_dir.glob('*')
data_paths_name = [str(path) for path in data_paths]
data_paths_name
[
'data/000000000_川W9BR26.jpg',
 'data/000000000_藏WP66B0.jpg',
 'data/000000001_沪E264UD.jpg',
 'data/000000001_津D8Z15T.jpg',
]

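File name format: data/000000000_川W9BR26.jpg
Goal: split the file name and extract the plate number

# stem drops the directory and the extension, so this does not depend on the path separator
classnames = [pathlib.Path(path).stem.split('_')[1] for path in data_paths_name]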
classnames
[
'川W9BR26',
 '藏WP66B0',
 '沪E264UD',
 '津D8Z15T',
 '浙E198UJ',
 '陕Z813VB',
]
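
The label matrix defined later takes its row count from the first plate only, so it quietly assumes that every plate has the same number of characters. A minimal sketch of a check for that assumption follows; the helper name `length_report` and the four-item sample list are illustrative, not part of the original code.

```python
from collections import Counter

def length_report(plates):
    """Count how many plate strings there are for each string length."""
    return Counter(len(p) for p in plates)

# Illustrative sample taken from the file names listed above.
sample = ["川W9BR26", "藏WP66B0", "沪E264UD", "津D8Z15T"]
report = length_report(sample)
print(report)

# More than one key would mean that some plates do not fit a fixed-size label matrix.
print("all plates have the same length:", len(report) == 1)
```

On the real data the same function could be called as `length_report(classnames)`.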

3. Display a batch of images

import matplotlib.pyplot as plt 

plt.figure(figsize=(15, 4))

for i in range(18):
    plt.subplot(3, 6, i + 1)
    
    images = plt.imread(data_paths_name[i])
    plt.imshow(images)
    
plt.show()



4. Custom dataset

1. Define the labels

import numpy as np 

char_enum = ["京","沪","津","渝","冀","晋","蒙","辽","吉","黑","苏","浙","皖","闽","赣","鲁",\
              "豫","鄂","湘","粤","桂","琼","川","贵","云","藏","陕","甘","青","宁","新","军","使"]

number = [str(i) for i in range(0, 10)]
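alphabet = [chr(i) for i in range(65, 91)]  # chr(i) converts an integer to its Unicode character (65-90 are A-Z)

'''
Labels: each plate is encoded as a matrix
    rows:    one per character position, so the row count equals the plate length
    columns: one per possible character
'''

content = char_enum + number + alphabet  # characters indexed by the matrix columns
content_length = len(content)           # number of columns
plate_length = len(classnames[0])       # number of rows

# Convert a plate string into its matrix encoding
def text2vec(text):
    vector = np.zeros([plate_length, content_length])
    for i, c in  enumerate(text):
        col = content.index(c)
        vector[i][col] = 1.0    # 1.0 marks the character found at this position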
    
    return vector 

all_labels = [text2vec(i) for i in classnames] 
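
To make the encoding concrete, here is a small self-contained sketch that rebuilds the same 69-character vocabulary, encodes one plate, and decodes it again. The inverse function `vec2text` is a hypothetical helper that does not appear in the original code; it simply takes the arg-max column of each row, so it could in principle also be applied to a model output of the same shape.

```python
import numpy as np

# Same vocabulary order as above: 33 region characters, 10 digits, 26 letters.
provinces = list("京沪津渝冀晋蒙辽吉黑苏浙皖闽赣鲁豫鄂湘粤桂琼川贵云藏陕甘青宁新军使")
content = provinces + [str(i) for i in range(10)] + [chr(i) for i in range(65, 91)]

def text2vec(text):
    """Encode a plate as a (len(text), len(content)) one-hot matrix."""
    vec = np.zeros((len(text), len(content)))
    for i, c in enumerate(text):
        vec[i, content.index(c)] = 1.0
    return vec

def vec2text(matrix):
    """Hypothetical inverse: pick the highest-scoring column in every row."""
    return "".join(content[j] for j in np.argmax(matrix, axis=1))

label = text2vec("川W9BR26")
print(label.shape)       # rows = plate length, columns = vocabulary size
print(vec2text(label))   # round trip back to the plate string
```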

2. Custom dataset loader

Purpose: pair each image with its label, and apply the image transforms

from PIL import Image 

class MyData(torch.utils.data.Dataset):
    def __init__(self, all_data_paths, all_labels, data_transforms):
        super().__init__()
        self.data_paths = all_data_paths 
        self.data_labels = all_labels 
        self.transforms = data_transforms 
        
    # Override __len__ to return the size of the dataset
    def __len__(self):
        return len(self.data_labels)   # use the instance's own labels, not the global list
    
    # Override __getitem__ to fetch one sample by index, returned as (image, label)
    def __getitem__(self, index):
        # convert to RGB so that every image has 3 channels, as the 3-channel Normalize expects
        image = Image.open(self.data_paths[index]).convert('RGB')
        label = self.data_labels[index]
        
        if self.transforms:
            image = self.transforms(image)
            
        return image, label 

5. Data loading and normalization

data_transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

total_data = MyData(data_paths_name, all_labels, data_transform)
total_data
<__main__.MyData at 0x7f3104921be0>

6. Splitting the data

train_size = int(len(total_data) * 0.8)
test_size = len(total_data) - train_size 
train_data, test_data = torch.utils.data.random_split(total_data, [train_size, test_size])
print("train_size: ", len(train_data))
print("test_size: ", len(test_data))
train_size:  10940
test_size:  2735
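
The split above is random, so a rerun will generally put different images in the test set. If a repeatable split is wanted, `random_split` accepts a `generator` argument. The sketch below is an optional variation rather than what the original run used; a plain list stands in for `total_data`, and the seed 42 is an arbitrary choice.

```python
import torch
from torch.utils.data import random_split

dataset = list(range(10))          # stand-in for total_data
train_size = int(len(dataset) * 0.8)
test_size = len(dataset) - train_size

# Two splits driven by generators with the same seed select the same indices.
gen_a = torch.Generator().manual_seed(42)
train_a, test_a = random_split(dataset, [train_size, test_size], generator=gen_a)

gen_b = torch.Generator().manual_seed(42)
train_b, test_b = random_split(dataset, [train_size, test_size], generator=gen_b)

print(list(train_a.indices) == list(train_b.indices))
```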

7. Loading data in batches

batch_size = 16   # 16 samples per batch

train_dl = torch.utils.data.DataLoader(train_data,
                                       shuffle=True,
                                       batch_size=batch_size)

test_dl = torch.utils.data.DataLoader(test_data,
                                      shuffle=True,
                                      batch_size=batch_size)
# Check the shapes of one batch
for images, labels in train_dl:
    print("images: ", images.shape)
    print("labels: ", labels.shape)
    break
images:  torch.Size([16, 3, 224, 224])
labels:  torch.Size([16, 7, 69])

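2. Build the model

The model is fairly simple; its structure is:
1. Convolution kernel: 5 * 5, stride: 1 * 1, padding: 0
2. Channels: 3 -> 12 -> 12 -> pooling -> 24 -> 24 -> pooling -> fully connected layer -> reshape to the custom label shape
3. Batch normalization after every convolution layer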

import torch.nn.functional as F 

class NetWork(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.conv1 = nn.Conv2d(3, 12, kernel_size=5, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(12)
        self.conv2 = nn.Conv2d(12, 12, kernel_size=5, stride=1, padding=0)
        self.bn2 = nn.BatchNorm2d(12)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(12, 24, kernel_size=5, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(24)
        self.conv4 = nn.Conv2d(24, 24, kernel_size=5, stride=1, padding=0)
        self.bn4 = nn.BatchNorm2d(24)
        self.fc1 = nn.Linear(24 * 50 * 50, plate_length * content_length)
        
        # reshape the output to the label shape
        self.reshape = Reshape([plate_length, content_length])
        
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x)
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.pool(x)
        x = x.view(-1, 24 * 50 * 50)
        x = self.fc1(x)
        
        x = self.reshape(x)
        
        return x 
        
# Reshape module ---> makes the output match the label shape
class Reshape(nn.Module):
    def __init__(self, shape):
        super(Reshape, self).__init__()
        self.shape = shape 
        
    def forward(self, x):
        return x.view(x.size(0), *self.shape)   # keep the batch dimension x.size(0); * unpacks self.shape, giving [N, plate_length, content_length]
    
model = NetWork().to(device)
model
NetWork(
  (conv1): Conv2d(3, 12, kernel_size=(5, 5), stride=(1, 1))
  (bn1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(12, 12, kernel_size=(5, 5), stride=(1, 1))
  (bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(12, 24, kernel_size=(5, 5), stride=(1, 1))
  (bn3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(24, 24, kernel_size=(5, 5), stride=(1, 1))
  (bn4): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=60000, out_features=483, bias=True)
  (reshape): Reshape()
)
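
The `in_features=60000` of `fc1` follows from the 224 x 224 input size and the layer settings. A short sketch of the arithmetic, using the usual output-size rule floor((n + 2p - k) / s) + 1 for convolution and pooling layers; the helper name `conv_out` is illustrative.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a convolution or pooling layer (floor rule)."""
    return (size + 2 * padding - kernel) // stride + 1

side = 224
side = conv_out(side, kernel=5)             # conv1: 224 -> 220
side = conv_out(side, kernel=5)             # conv2: 220 -> 216
side = conv_out(side, kernel=2, stride=2)   # pool:  216 -> 108
side = conv_out(side, kernel=5)             # conv3: 108 -> 104
side = conv_out(side, kernel=5)             # conv4: 104 -> 100
side = conv_out(side, kernel=2, stride=2)   # pool:  100 -> 50

flat_features = 24 * side * side            # 24 channels after conv4
print(side, flat_features)
```

If the `Resize` size in the transform is changed, the same calculation gives the new value needed for `fc1` and for the `view` call in `forward`.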

3. Model training

1. Set the hyperparameters

learn_rate = 1e-4 
optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)
loss_fn = nn.CrossEntropyLoss()

2. Define the training function

def train(dataloader, model, optimizer, loss_fn):
    num_batches = len(dataloader)   # number of batches
    
    model.train()
    train_loss, correct = 0.0, 0.0  # accumulate as floats
    
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        
        # forward pass
        pred = model(X)
        
        # flatten pred and y from [N, 7, 69] to one row per character position
        pred_flat = pred.view(-1, content_length)  # [N * 7, 69]
        y_flat = y.view(-1, content_length)  # [N * 7, 69]

        # compute the loss (the one-hot rows are passed as float class probabilities)
        loss = loss_fn(pred_flat, y_flat.float())
        
        # backward pass and optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # accumulate the training loss
        train_loss += loss.item()

        # accuracy: fraction of matrix entries where the thresholded prediction equals the label
        with torch.no_grad():
            pred_probs = torch.sigmoid(pred_flat)
            batch_correct = ((pred_probs > 0.5) == y_flat.bool()).float().mean().item()
            correct += batch_correct

    # average the loss and accuracy over the batches
    train_loss /= num_batches
    train_acc = correct / num_batches

    return train_acc, train_loss

3. Define the test function

def test(dataloader, model, loss_fn):
    num_batches = len(dataloader)  # number of batches
    
    test_loss, correct = 0.0, 0.0  # accumulate as floats
    
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            
            pred = model(X)
            # flatten pred and y from [N, 7, 69] to one row per character position
            pred_flat = pred.view(-1, content_length)  # [N * 7, 69]
            y_flat = y.view(-1, content_length)  # [N * 7, 69]

            # compute the loss
            loss = loss_fn(pred_flat, y_flat.float())
            test_loss += loss.item()
            
            # accuracy: fraction of matrix entries where the thresholded prediction equals the label
            pred_probs = torch.sigmoid(pred_flat)
            batch_correct = ((pred_probs > 0.5) == y_flat.bool()).float().mean().item()
            correct += batch_correct
    
    # average the loss and accuracy over the batches
    test_loss /= num_batches
    test_acc = correct / num_batches
    
    return test_acc, test_loss
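
The accuracy in the two functions above counts individual matrix entries. Every one-hot row has 68 zeros among its 69 entries, so a position can agree on most entries even when the predicted character is wrong, and the resulting percentage is an optimistic one. A stricter alternative, sketched here as a suggestion rather than as the author's method, compares the arg-max character at each position and also reports how many plates are correct in full. The function name `plate_accuracy` and the toy tensors are illustrative.

```python
import torch
import torch.nn.functional as F

def plate_accuracy(pred, target):
    """pred and target have shape [N, plate_length, vocab]; target is one-hot."""
    pred_idx = pred.argmax(dim=-1)      # predicted character index per position
    true_idx = target.argmax(dim=-1)    # labelled character index per position
    hits = pred_idx == true_idx         # [N, plate_length] booleans
    char_acc = hits.float().mean().item()               # per-character accuracy
    plate_acc = hits.all(dim=1).float().mean().item()   # whole-plate accuracy
    return char_acc, plate_acc

# Toy data: 2 plates, 3 positions, 4 possible symbols.
target_idx = torch.tensor([[0, 1, 2], [3, 2, 1]])
target = F.one_hot(target_idx, num_classes=4).float()
pred = target.clone()
pred[1, 2] = torch.tensor([1.0, 0.0, 0.0, 0.0])   # make one character wrong

char_acc, plate_acc = plate_accuracy(pred, target)
print(char_acc, plate_acc)
```

In the loops above this would be called with the unflattened `pred` and `y` of each batch.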

4. Run the training

epochs = 20

train_acc, train_loss, test_acc, test_loss = [], [], [], []

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, optimizer, loss_fn)
    
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn) 
    
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss) 
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    # print the metrics for this epoch
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}')
    print(template.format(epoch + 1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
Epoch: 1, Train_acc:96.7%, Train_loss:2.410, Test_acc:98.3%, Test_loss:1.407
Epoch: 2, Train_acc:98.7%, Train_loss:0.686, Test_acc:98.5%, Test_loss:0.952
Epoch: 3, Train_acc:99.1%, Train_loss:0.251, Test_acc:98.7%, Test_loss:0.813
Epoch: 4, Train_acc:99.3%, Train_loss:0.108, Test_acc:98.8%, Test_loss:0.744
Epoch: 5, Train_acc:99.4%, Train_loss:0.061, Test_acc:99.0%, Test_loss:0.751
Epoch: 6, Train_acc:99.3%, Train_loss:0.083, Test_acc:98.6%, Test_loss:0.892
Epoch: 7, Train_acc:99.3%, Train_loss:0.097, Test_acc:99.0%, Test_loss:0.962
Epoch: 8, Train_acc:99.4%, Train_loss:0.058, Test_acc:99.0%, Test_loss:0.847
Epoch: 9, Train_acc:99.5%, Train_loss:0.041, Test_acc:99.0%, Test_loss:0.920
Epoch:10, Train_acc:99.6%, Train_loss:0.030, Test_acc:99.1%, Test_loss:0.879
Epoch:11, Train_acc:99.5%, Train_loss:0.036, Test_acc:99.1%, Test_loss:1.031
Epoch:12, Train_acc:99.5%, Train_loss:0.040, Test_acc:99.0%, Test_loss:1.004
Epoch:13, Train_acc:99.6%, Train_loss:0.028, Test_acc:99.1%, Test_loss:0.862
Epoch:14, Train_acc:99.6%, Train_loss:0.017, Test_acc:99.2%, Test_loss:0.911
Epoch:15, Train_acc:99.7%, Train_loss:0.020, Test_acc:99.1%, Test_loss:0.864
Epoch:16, Train_acc:99.7%, Train_loss:0.024, Test_acc:99.2%, Test_loss:1.007
Epoch:17, Train_acc:99.7%, Train_loss:0.025, Test_acc:99.2%, Test_loss:1.121
Epoch:18, Train_acc:99.7%, Train_loss:0.017, Test_acc:99.2%, Test_loss:0.897
Epoch:19, Train_acc:99.7%, Train_loss:0.007, Test_acc:99.2%, Test_loss:0.867
Epoch:20, Train_acc:99.7%, Train_loss:0.011, Test_acc:99.3%, Test_loss:0.887

4. Results

import matplotlib.pyplot as plt
# hide warnings
import warnings
warnings.filterwarnings("ignore")               # ignore warning messages

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Test Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Test Loss')
plt.show()

