当前位置: 首页 > article >正文

精准识别花生豆:基于EfficientNetB0的深度学习检测与分类项目

精准检测花生豆:基于EfficientNet的深度学习分类项目

在现代农业生产中,作物的质量检测和分类是确保产品质量的重要环节。针对花生豆的检测与分类需求,我们开发了一套基于深度学习的解决方案,利用EfficientNetB0模型实现高效、准确的花生豆分类。本博客将详细介绍该项目的背景、数据处理、模型架构、训练过程、评估方法及预测应用。

目录

  1. 项目背景
  2. 项目概述
  3. 数据处理
    • 数据集结构
    • 数据增强与规范化
  4. 模型架构
  5. 训练过程
    • 训练脚本 (train.py)
  6. 模型评估
    • 评估脚本 (evaluate.py)
  7. 预测与应用
    • 预测脚本 (predict.py)
  8. 项目成果
  9. 结论与未来工作

项目背景

花生豆作为一种重要的经济作物,其品质直接影响到市场价值和消费者满意度。传统的人工检测方法不仅耗时耗力,而且易受主观因素影响,难以实现大规模、精准的分类。因此,开发一种高效、准确的自动化检测系统显得尤为重要。

项目概述

本项目旨在利用深度学习技术,构建一个能够自动检测和分类花生豆的系统。通过收集和处理大量花生豆图像数据,训练一个高性能的卷积神经网络模型,实现对不同类别花生豆的精准分类。项目主要包括以下几个部分:

  • 数据处理:图像数据的加载、预处理与增强。
  • 模型架构:基于EfficientNetB0的分类模型设计。
  • 训练过程:模型的训练与优化,包括断点续训与学习率调度。
  • 模型评估:在测试集上的性能评估。
  • 预测应用:对新图像进行花生豆分类与标注。

数据处理

数据集结构

项目使用的数据集分为训练集、验证集和测试集,具体结构如下:

./data/dataset/
├── train/
│   ├── baiban/
│   ├── bandian/
│   ├── famei/
│   ├── faya/
│   ├── hongpi/
│   ├── qipao/
│   ├── youwu/
│   └── zhengchang/
├── validation/
│   ├── baiban/
│   ├── bandian/
│   ├── famei/
│   ├── faya/
│   ├── hongpi/
│   ├── qipao/
│   ├── youwu/
│   └── zhengchang/
└── test/
    ├── baiban/
    ├── bandian/
    ├── famei/
    ├── faya/
    ├── hongpi/
    ├── qipao/
    ├── youwu/
    └── zhengchang/

每个子文件夹对应一种花生豆类别,包含相应的图像数据。

数据增强与规范化

为了提高模型的泛化能力,训练过程中对图像数据进行了多种数据增强操作,如随机裁剪、水平翻转、旋转和颜色抖动。同时,使用ImageNet的均值和标准差对图像进行了归一化处理,与预训练模型的输入要求保持一致。

# utils/dataLoader.py

train_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(*stats)
])

validation_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(*stats)
])

test_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(*stats)
])

模型架构

本项目采用了EfficientNetB0作为基础模型。EfficientNet系列通过系统性地平衡网络的宽度、深度和分辨率,在模型性能和计算效率之间取得了优异的平衡。具体来说:

  • 预训练权重:使用在ImageNet上预训练的权重,帮助模型在较小的数据集上快速收敛。
  • 冻结特征提取部分:根据需要,可以选择冻结模型的特征提取层,仅训练最后的分类器,适用于数据量较小的情况。
  • 分类器设计:在原有分类器前添加了Dropout层,减少过拟合风险。
# utils/model.py

class EfficientNetB0(nn.Module):
    def __init__(self, num_classes, pretrained=True, freeze_features=False):
        super(EfficientNetB0, self).__init__()
        if pretrained:
            self.model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
        else:
            self.model = models.efficientnet_b0(weights=None)
        
        if freeze_features:
            for param in self.model.features.parameters():
                param.requires_grad = False

        in_features = self.model.classifier[1].in_features
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=0.4, inplace=True),
            nn.Linear(in_features, num_classes)
        )
    
    def forward(self, x):
        return self.model(x)

训练过程

训练脚本 (train.py)

训练脚本负责模型的训练与验证,包括数据加载、模型初始化、训练循环、学习率调度、模型保存和训练曲线绘制等功能。

关键功能包括:

  • 训练与验证循环:每个epoch包括训练阶段和验证阶段,记录损失与准确率。
  • 优化与调度:使用Adam优化器和ReduceLROnPlateau学习率调度器,根据验证损失动态调整学习率。
  • 模型保存:保存验证集准确率最高的模型,并定期自动保存模型检查点。
  • 断点续训:支持从保存的检查点继续训练,避免重复计算。
  • 训练曲线绘制:训练结束后,生成并保存训练与验证的准确率和损失曲线。
# train.py

import torch
import torch.nn as nn
from utils.dataLoader import load_data
from utils.model import EfficientNetB0
from tqdm import tqdm
import time
import matplotlib.pyplot as plt
import os

def accuracy(predictions, labels):
    pred = torch.argmax(predictions, dim=1)
    correct = (pred == labels).sum().item()
    return correct

def train(net, start_epoch, epochs, train_loader, validation_loader, device, criterion, optimizer, scheduler, model_path, auto_save):
    # 初始化
    train_acc_list, validation_acc_list = [], []
    train_loss_list, validation_loss_list = [], []
    best_validation_acc = 0
    net = net.to(device)
    
    if start_epoch > 0:
        print(f"从 epoch {start_epoch} 开始训练。")
    
    for epoch in range(start_epoch, epochs):
        # 训练阶段
        net.train()
        train_correct, train_loss, total = 0, 0, 0
        with tqdm(train_loader, ncols=100, colour='green', desc=f"Train Epoch {epoch+1}/{epochs}") as pbar:
            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = net(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                train_loss += loss.item() * images.size(0)
                train_correct += accuracy(outputs, labels)
                total += labels.size(0)

                pbar.set_postfix({'loss': f"{train_loss / total:.4f}", 'acc': f"{train_correct / total:.4f}"})
        
        train_acc = train_correct / total
        train_loss = train_loss / total
        train_acc_list.append(train_acc)
        train_loss_list.append(train_loss)

        # 验证阶段
        net.eval()
        validation_correct, validation_loss, total_validation = 0, 0, 0
        with torch.no_grad():
            with tqdm(validation_loader, ncols=100, colour='blue', desc=f"Validation Epoch {epoch+1}/{epochs}") as pbar:
                for images, labels in pbar:
                    images, labels = images.to(device), labels.to(device)
                    outputs = net(images)
                    loss = criterion(outputs, labels)

                    validation_loss += loss.item() * images.size(0)
                    validation_correct += accuracy(outputs, labels)
                    total_validation += labels.size(0)

                    pbar.set_postfix({'loss': f"{validation_loss / total_validation:.4f}", 'acc': f"{validation_correct / total_validation:.4f}"})
        
        validation_acc = validation_correct / total_validation
        validation_loss = validation_loss / total_validation
        validation_acc_list.append(validation_acc)
        validation_loss_list.append(validation_loss)

        # 更新学习率
        scheduler.step(validation_loss)

        # 保存最佳模型
        if validation_acc > best_validation_acc:
            best_validation_acc = validation_acc
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'validation_acc': best_validation_acc
            }
            torch.save(checkpoint, model_path)
            print(f"保存最佳模型,验证准确率: {best_validation_acc:.4f}")

        # 自动保存模型
        if (epoch + 1) % auto_save == 0:
            save_path = model_path.replace('.pth', f'_epoch{epoch+1}.pth')
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'validation_acc': best_validation_acc
            }
            torch.save(checkpoint, save_path)
            print(f"自动保存模型到 {save_path}")

    # 绘制训练曲线
    def plot_training_curves(train_acc_list, validation_acc_list, train_loss_list, validation_loss_list, epochs):
        plt.figure(figsize=(12, 5))

        plt.subplot(1, 2, 1)
        plt.plot(range(1, epochs+1), train_acc_list, 'bo-', label="训练准确率")
        plt.plot(range(1, epochs+1), validation_acc_list, 'ro-', label="验证准确率")
        plt.title("训练准确率 vs 验证准确率")
        plt.xlabel("轮次")
        plt.ylabel("准确率")
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(range(1, epochs+1), train_loss_list, 'bo-', label="训练损失")
        plt.plot(range(1, epochs+1), validation_loss_list, 'ro-', label="验证损失")
        plt.title("训练损失 vs 验证损失")
        plt.xlabel("轮次")
        plt.ylabel("损失")
        plt.legend()

        os.makedirs('logs', exist_ok=True)
        plt.savefig('logs/training_curve.png')
        plt.show()

    plot_training_curves(train_acc_list, validation_acc_list, train_loss_list, validation_loss_list, epochs)

if __name__ == '__main__':
    batch_size = 64
    image_size = 224
    classes_num = 8
    num_epochs = 100
    auto_save = 10
    lr = 1e-4
    weight_decay = 1e-4
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    classify = {'baiban': 0, 'bandian': 1, 'famei': 2, 'faya': 3, 'hongpi': 4, 'qipao': 5, 'youwu': 6, 'zhengchang': 7}
    train_loader, validation_loader, test_loader = load_data(batch_size, image_size, classify)

    net = EfficientNetB0(classes_num, pretrained=True, freeze_features=False)
    model_path = 'model_weights/EfficientNetB0.pth'
    os.makedirs('model_weights', exist_ok=True)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)

    # 检查点续训
    start_epoch = 0
    best_validation_acc = 0
    if os.path.exists(model_path):
        try:
            checkpoint = torch.load(model_path, map_location=device)
            required_keys = ['model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict', 'epoch', 'validation_acc']
            if all(key in checkpoint for key in required_keys):
                net.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
                start_epoch = checkpoint['epoch'] + 1
                best_validation_acc = checkpoint['validation_acc']
                print(f"从 epoch {checkpoint['epoch']} 继续训练,最佳验证准确率: {best_validation_acc:.4f}")
            else:
                print(f"检查点文件缺少必要的键,开始从头训练。")
        except Exception as e:
            print(f"加载检查点时发生错误: {e}")
            print("开始从头训练。")

    print("训练开始")
    time_start = time.time()
    train(net, start_epoch, num_epochs, train_loader, validation_loader, device=device, 
          criterion=criterion, optimizer=optimizer, scheduler=scheduler, 
          model_path=model_path, auto_save=auto_save)
    time_end = time.time()
    seconds = time_end - time_start
    m, s = divmod(seconds, 60)
    h, m = divmod(m, 60)
    print("训练结束")
    print("本次训练时长为:%02d:%02d:%02d" % (h, m, s))

主要特点:

  • 进度条可视化:使用tqdm库实时展示训练和验证进度。
  • 断点续训:支持从上一次中断的epoch继续训练,确保训练过程的连续性。
  • 自动保存:定期保存模型检查点,防止意外中断导致的训练损失。
  • 训练曲线:生成并保存训练与验证的准确率和损失曲线,便于后续分析与调优。

模型评估

评估脚本 (evaluate.py)

评估脚本用于在测试集上评估训练好的模型性能,计算准确率和损失,并将结果保存到文件中。

# evaluate.py

import torch
import torch.nn as nn
from utils.dataLoader import load_data
from utils.model import EfficientNetB0
from tqdm import tqdm
import os

def accuracy(predictions, labels):
    pred = torch.argmax(predictions, dim=1)
    correct = (pred == labels).sum().item()
    return correct

def evaluate(net, test_loader, device, criterion, output_path):
    net.eval()
    test_correct, test_loss, total_test = 0, 0, 0
    with torch.no_grad():
        with tqdm(test_loader, ncols=100, colour='blue', desc=f"Evaluating on Test Set") as pbar:
            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                loss = criterion(outputs, labels)

                test_loss += loss.item() * images.size(0)
                test_correct += accuracy(outputs, labels)
                total_test += labels.size(0)

                pbar.set_postfix({'loss': f"{test_loss / total_test:.4f}", 'acc': f"{test_correct / total_test:.4f}"})
    
    test_acc = test_correct / total_test
    test_loss = test_loss / total_test
    result = f"测试集准确率: {test_acc:.4f}, 测试集损失: {test_loss:.4f}"
    print(result)
    
    # 保存结果到文件
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'a') as f:
        f.write(result + '\n')

if __name__ == '__main__':
    batch_size = 64
    image_size = 224
    classes_num = 8
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    classify = {'baiban': 0, 'bandian': 1, 'famei': 2, 'faya': 3, 'hongpi': 4, 'qipao': 5, 'youwu': 6, 'zhengchang': 7}
    _, _, test_loader = load_data(batch_size, image_size, classify)

    net = EfficientNetB0(classes_num, pretrained=False)
    model_path = 'model_weights/EfficientNetB0.pth'
    if not os.path.exists(model_path):
        print(f"模型权重文件 {model_path} 不存在,请先训练模型。")
        exit()
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.to(device)

    criterion = nn.CrossEntropyLoss()

    evaluation_output_path = 'outputs/evaluation_results.txt'

    # 清空之前的评估结果
    if os.path.exists(evaluation_output_path):
        os.remove(evaluation_output_path)

    print("评估开始")
    evaluate(net, test_loader, device=device, criterion=criterion, output_path=evaluation_output_path)
    print("评估结束")

评估流程:

  1. 加载模型:从保存的权重文件中加载训练好的模型。
  2. 模型评估:在测试集上计算模型的准确率和损失。
  3. 结果保存:将评估结果保存到指定的输出文件中,便于后续查看与分析。

预测与应用

预测脚本 (predict.py)

预测脚本用于对新图像进行花生豆分类,并在图像上标注分类结果和边框。

# predict.py

import os
import cv2
import numpy as np
import torch
from PIL import Image
from utils.model import EfficientNetB0
from torchvision import transforms

def delet_contours(contours, delete_list):
    delta = 0
    for i in range(len(delete_list)):
        del contours[delete_list[i] - delta]
        delta += 1
    return contours

def main():
    input_path = 'data/pic'
    output_dir = 'outputs/predicted_images'
    os.makedirs(output_dir, exist_ok=True)
    image_files = os.listdir(input_path)
    
    classify = {0: 'baiban', 1: 'bandian', 2: 'famei', 3: 'faya', 4: 'hongpi', 5: 'qipao', 6: 'youwu', 7: 'zhengchang'}
    
    # 与训练时相同的预处理
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),  # ImageNet均值
                             std=(0.229, 0.224, 0.225))   # ImageNet标准差
    ])
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = EfficientNetB0(8, pretrained=False)
    model_path = 'model_weights/EfficientNetB0.pth'
    if not os.path.exists(model_path):
        print(f"模型权重文件 {model_path} 不存在,请先训练模型。")
        return
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.to(device)
    net.eval()
    
    min_size = 30
    max_size = 400
    
    for img_name in image_files:
        img_path = os.path.join(input_path, img_name)
        img = cv2.imread(img_path)
        if img is None:
            print(f"无法读取图像: {img_path}")
            continue
        
        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)  # 转换到HSV颜色空间

        # 根据HSV颜色范围进行掩膜操作(根据实际情况调整颜色范围)
        lower_blue = np.array([100, 100, 8])
        upper_blue = np.array([255, 255, 255])

        mask = cv2.inRange(hsv, lower_blue, upper_blue)  # 创建掩膜
        result = cv2.bitwise_and(img, img, mask=mask)    # 应用掩膜
        result = result.astype(np.uint8)

        # 转换为灰度图并二值化
        gray = cv2.cvtColor(result, cv2.COLOR_BGR2GRAY)
        _, binary_image = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)

        # 查找轮廓
        contours, _ = cv2.findContours(binary_image, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        contours = list(contours)

        # 过滤轮廓
        delete_list = []
        for idx, contour in enumerate(contours):
            perimeter = cv2.arcLength(contour, True)
            if perimeter < min_size or perimeter > max_size:
                delete_list.append(idx)
        contours = delet_contours(contours, delete_list)

        # 对每个轮廓进行分类
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            crop = img[y:y+h, x:x+w]
            if crop.size == 0:
                continue
            crop_pil = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
            crop_tensor = transform(crop_pil).unsqueeze(0).to(device)

            with torch.no_grad():
                output = net(crop_tensor)
                pred = torch.argmax(output, dim=1).item()
                label = classify[pred]

            # 标注图像
            cv2.putText(img, label, (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36,255,12), 2)
            cv2.rectangle(img, (x, y), (x + w, y + h), (0, 0, 255), 2)

        # 保存结果图像到outputs/predicted_images/
        output_image_path = os.path.join(output_dir, f"{os.path.splitext(img_name)[0]}_predicted.jpg")
        cv2.imwrite(output_image_path, img)
        print(f"保存预测结果到 {output_image_path}")

    print("所有图像的预测和标注已完成并保存到 ./outputs/predicted_images/")

if __name__ == '__main__':
    main()

预测流程:

  1. 图像预处理:将输入图像转换到HSV颜色空间,应用颜色掩膜提取花生豆区域。
  2. 轮廓检测与过滤:查找并过滤不符合大小要求的轮廓,确保只处理有效的花生豆区域。
  3. 分类与标注:对每个有效轮廓进行裁剪、预处理,并使用训练好的模型进行分类。在图像上标注分类结果和边框。
  4. 结果保存:将标注后的图像保存到指定的输出目录。

预测结果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

项目成果

通过本项目,我们成功构建了一个能够高效、准确地检测和分类花生豆的深度学习模型。主要成果包括:

  • 高准确率:模型在测试集上达到了令人满意的分类准确率。
  • 自动化检测:实现了对新图像的自动检测与分类,大大提高了检测效率。
  • 可视化结果:通过图像标注,直观展示了分类结果,便于用户理解和应用。

训练与验证的准确率和损失曲线示例
在这里插入图片描述

结论与未来工作

本项目展示了基于深度学习的花生豆检测与分类的可行性与有效性。通过采用预训练的EfficientNetB0模型,并结合数据增强与优化策略,模型在花生豆分类任务中表现出色。

未来的工作方向包括:

  • 模型优化:尝试更深更复杂的模型,如EfficientNetB7,以进一步提升分类性能。
  • 数据扩展:收集更多多样化的花生豆图像,增强模型的泛化能力。
  • 实时检测:优化模型推理速度,实现实时花生豆检测与分类。
  • 部署应用:将模型集成到移动设备或嵌入式系统,便于现场检测与应用。

通过持续的优化与扩展,我们相信这一系统将在农业生产中发挥更大的价值,助力智能农业的发展。


感谢阅读本博客!如果您对本项目有任何疑问或建议,欢迎在下方留言交流。


http://www.kler.cn/a/453156.html

相关文章:

  • 后端开发如何高效使用 Apifox?
  • [论文笔记] 从生成到评估:LLM-as-a-judge 的机遇与挑战
  • React 第十九节 useLayoutEffect 用途使用技巧注意事项详解
  • 如何使用MySQL WorkBench操作MySQL数据库
  • 生成对抗网络,边缘计算,知识图谱,解释性AI
  • 人工智能ACA(七)——计算机视觉基础
  • @RequestParam和@PathVariable的解释与区别
  • 从自动驾驶到具身智能漫谈
  • 正则表达式(三剑客之sed)
  • HarmonyOS NEXT 实战之元服务:静态案例效果---每日玩机技巧
  • 跨境电商培训:云手机的新舞台
  • 某车之家appso层签名逆向
  • 2024楚慧杯WP
  • CultureLLM 与 CulturePark:增强大语言模型对多元文化的理解
  • 力扣-数据结构-3【算法学习day.74】
  • 存储块的获取与释放
  • Windows下ESP32-IDF开发环境搭建
  • 智源研究院与安谋科技达成战略合作,共建开源AI“芯”生态
  • 冰狐智能辅助使用插件化开发集成三方ocr
  • Linux中的lseek 函数与fcntl函数
  • CMS(Concurrent Mark Sweep)垃圾回收器的具体流程
  • 使用Python读写文本文件
  • 【2024最新】基于Python+Mysql+django的水果销售系统Lw+PPT
  • 网络层协议--ip协议
  • uni-app 中使用微信小程序第三方 SDK 及资源汇总
  • 常用的Django模板语言