当前位置: 首页 > article >正文

pytorch cnn 实现猫狗分类

文章目录

  • 1. 导入必要的库
  • 2. 定义数据集类
  • 3. 数据预处理和加载
  • 4. 定义 CNN 模型
  • 5. 定义损失函数和优化器
  • 6. 训练模型
  • 7. 保存模型
  • 8. 使用模型进行预测
  • 9. 完整代码
  • 10. 总结

1. 导入必要的库

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os

2. 定义数据集类

我们将创建一个自定义数据集类来加载猫狗图片。该类要求数据根目录下按类别分为 cat 和 dog 两个子目录,并以类别在列表中的下标作为标签(cat 为 0,dog 为 1);第 9 节完整代码中使用的子目录名为 cats 和 dogs。


class CatDogDataset(Dataset):
    """Map-style dataset of cat and dog images stored in per-class folders.

    ``root_dir`` must contain one sub-directory per entry of ``classes``
    (``cat`` and ``dog``).  The label of a sample is the index of its class
    in that list (cat -> 0, dog -> 1).

    Args:
        root_dir: Directory holding the ``cat`` and ``dog`` sub-directories.
        transform: Optional callable applied to each RGB ``PIL.Image``
            before it is returned.

    Raises:
        OSError: If a class directory is missing or unreadable (propagated
            from ``os.listdir``).
    """

    # Only regular files with one of these (case-insensitive) extensions are
    # indexed.  Sub-directories and stray files (e.g. Thumbs.db, .DS_Store)
    # would otherwise make Image.open fail in the middle of an epoch.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif',
                        '.tif', '.tiff', '.webp')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = ['cat', 'dog']
        self.image_paths = []
        self.labels = []

        # Walk the cat and dog directories, recording image paths and labels.
        for idx, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> sample mapping reproducible between runs.
            for img_name in sorted(os.listdir(class_dir)):
                img_path = os.path.join(class_dir, img_name)
                if not self._is_image_file(img_path):
                    continue
                self.image_paths.append(img_path)
                self.labels.append(idx)

    @classmethod
    def _is_image_file(cls, path):
        """Return True if ``path`` is a regular file with an image extension."""
        return (os.path.isfile(path)
                and path.lower().endswith(cls.IMAGE_EXTENSIONS))

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is converted to RGB and, when a transform was supplied,
        passed through it.
        """
        # The context manager closes the underlying file; convert() returns
        # an independent image that stays valid afterwards.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

3. 数据预处理和加载

定义数据预处理方法,并加载数据集。

# Preprocessing pipeline shared by the training and validation sets:
# resize to 64x64, convert to a tensor, then normalise each channel.
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
])

# Datasets: one root directory per split.
train_dataset = CatDogDataset('path_to_train_data', transform)
val_dataset = CatDogDataset('path_to_val_data', transform)

# Loaders: batches of 32; only the training split is shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)

4. 定义 CNN 模型


class CNN(nn.Module):
    """Small two-stage convolutional classifier producing 2 class logits.

    Layout: conv(3->32) -> ReLU -> 2x2 max-pool -> conv(32->64) -> ReLU ->
    2x2 max-pool -> flatten -> fc(64*16*16 -> 512) -> ReLU -> dropout(0.5)
    -> fc(512 -> 2).

    The 3x3 convolutions use padding 1 and keep the spatial size; each
    pooling step halves it.  ``fc1`` expects 64 * 16 * 16 features, i.e. a
    16x16 feature map after both poolings, which a 64x64 input produces.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(64 * 16 * 16, 512)
        self.fc2 = nn.Linear(512, 2)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return logits of shape (batch, 2) for a (batch, 3, 64, 64) input.

        Raises:
            RuntimeError: If the flattened feature count per sample is not
                64 * 16 * 16 (raised by ``fc1``).
        """
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        # Flatten per sample, keeping the batch dimension fixed.  The former
        # view(-1, 64 * 16 * 16) silently turned an over-sized input into
        # extra "samples" (one 128x128 image became a batch of 4) instead of
        # failing; now a size mismatch surfaces as an error in fc1.
        x = x.flatten(1)
        x = self.dropout(self.relu(self.fc1(x)))
        return self.fc2(x)

# Create the network instance that the following sections train, save and use for prediction.
model = CNN()

5. 定义损失函数和优化器

# Cross-entropy loss between the model's two output logits and the integer class labels.
criterion = nn.CrossEntropyLoss()
# Adam optimiser over all model parameters with a learning rate of 0.001.
optimizer = optim.Adam(model.parameters(), lr=0.001)

6. 训练模型

num_epochs = 10
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    # ---- training pass: one optimisation step per batch ----
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    mean_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {mean_loss:.4f}")

    # ---- validation pass: count correct predictions, no gradients ----
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            # Index of the larger logit is the predicted class.
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Validation Accuracy: {accuracy:.2f}%")

7. 保存模型

# Write the model's state_dict (its parameters, not the whole model object) to disk.
torch.save(model.state_dict(), 'cat_dog_classifier.pth')

8. 使用模型进行预测

# Load the saved weights.  map_location remaps the stored tensors onto the
# active device, so a checkpoint written on a GPU machine also loads on a
# CPU-only one (a plain torch.load would fail there).
model.load_state_dict(torch.load('cat_dog_classifier.pth', map_location=device))
# Switch to evaluation mode so that dropout is disabled during prediction.
model.eval()

# Prediction helper
def predict_image(image_path):
    """Classify the image file at ``image_path`` as ``'cat'`` or ``'dog'``.

    The image is opened, converted to RGB, run through the module-level
    ``transform`` and fed to the module-level ``model`` as a batch of one
    on ``device``.  Class index 0 maps to ``'cat'``; anything else maps to
    ``'dog'``.  The model's train/eval mode is left as the caller set it.
    """
    rgb_image = Image.open(image_path).convert('RGB')
    batch = transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
        class_index = logits.argmax(dim=1).item()

    if class_index == 0:
        return 'cat'
    return 'dog'

# Run the model on a single image and report the predicted class.
# (The snippet used to appear twice verbatim, which ran the same inference
# and printed the same line a second time.)
image_path = 'path_to_test_image.jpg'
prediction = predict_image(image_path)
print(f"The image is a {prediction}")

9. 完整代码

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os

class CatDogDataset(Dataset):
    """Map-style dataset of cat and dog images stored in per-class folders.

    ``root_dir`` must contain one sub-directory per entry of ``classes``
    (``cats`` and ``dogs``).  The label of a sample is the index of its
    class in that list (cats -> 0, dogs -> 1).

    Args:
        root_dir: Directory holding the ``cats`` and ``dogs`` sub-directories.
        transform: Optional callable applied to each RGB ``PIL.Image``
            before it is returned.
        max_per_class: Optional cap on the number of images indexed per
            class (handy for quick experiments).  ``None`` means no limit.

    Raises:
        OSError: If a class directory is missing or unreadable (propagated
            from ``os.listdir``).
    """

    # Only regular files with one of these (case-insensitive) extensions are
    # indexed.  Sub-directories and stray files (e.g. Thumbs.db, .DS_Store)
    # would otherwise make Image.open fail in the middle of an epoch.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif',
                        '.tif', '.tiff', '.webp')

    def __init__(self, root_dir, transform=None, max_per_class=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = ['cats', 'dogs']
        self.image_paths = []
        self.labels = []

        # Walk the cats and dogs directories, recording image paths and labels.
        for idx, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            kept = 0
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> sample mapping (and any per-class cap) reproducible.
            for img_name in sorted(os.listdir(class_dir)):
                if max_per_class is not None and kept >= max_per_class:
                    break
                img_path = os.path.join(class_dir, img_name)
                if not self._is_image_file(img_path):
                    continue
                self.image_paths.append(img_path)
                self.labels.append(idx)
                kept += 1

    @classmethod
    def _is_image_file(cls, path):
        """Return True if ``path`` is a regular file with an image extension."""
        return (os.path.isfile(path)
                and path.lower().endswith(cls.IMAGE_EXTENSIONS))

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is converted to RGB and, when a transform was supplied,
        passed through it.
        """
        # The context manager closes the underlying file; convert() returns
        # an independent image that stays valid afterwards.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

# Preprocessing pipeline shared by the training and validation sets:
# resize to 64x64, convert to a tensor, then normalise each channel.
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
])

# Datasets: one root directory per split.
train_dataset = CatDogDataset('D:/Cache/dataset/PetImages/train', transform)
val_dataset = CatDogDataset('D:/Cache/dataset/PetImages/valid', transform)

# Loaders: batches of 32; only the training split is shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)

class CNN(nn.Module):
    """Small two-stage convolutional classifier producing 2 class logits.

    Layout: conv(3->32) -> ReLU -> 2x2 max-pool -> conv(32->64) -> ReLU ->
    2x2 max-pool -> flatten -> fc(64*16*16 -> 512) -> ReLU -> dropout(0.5)
    -> fc(512 -> 2).

    The 3x3 convolutions use padding 1 and keep the spatial size; each
    pooling step halves it.  ``fc1`` expects 64 * 16 * 16 features, i.e. a
    16x16 feature map after both poolings, which a 64x64 input produces.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(64 * 16 * 16, 512)
        self.fc2 = nn.Linear(512, 2)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return logits of shape (batch, 2) for a (batch, 3, 64, 64) input.

        Raises:
            RuntimeError: If the flattened feature count per sample is not
                64 * 16 * 16 (raised by ``fc1``).
        """
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        # Flatten per sample, keeping the batch dimension fixed.  The former
        # view(-1, 64 * 16 * 16) silently turned an over-sized input into
        # extra "samples" (one 128x128 image became a batch of 4) instead of
        # failing; now a size mismatch surfaces as an error in fc1.
        x = x.flatten(1)
        x = self.dropout(self.relu(self.fc1(x)))
        return self.fc2(x)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network and move it to the device *before* creating the
# optimizer: the torch.optim documentation advises that the optimised
# parameters already live in their final location when the optimizer is
# constructed.
model = CNN()
model.to(device)

# Cross-entropy loss between the two output logits and integer class labels.
criterion = nn.CrossEntropyLoss()
# Adam optimiser over all model parameters with a learning rate of 0.001.
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10

def train():
    """Train the global ``model`` for ``num_epochs`` epochs and save it.

    Each epoch runs one optimisation pass over ``train_loader`` and prints
    the mean of the per-batch losses, then measures accuracy on
    ``val_loader`` with gradients disabled and prints it as a percentage.
    After the last epoch the model's ``state_dict`` is written to
    ``cat_dog_classifier.pth``.
    """
    for epoch_index in range(num_epochs):
        # ---- training pass: one optimisation step per batch ----
        model.train()
        loss_sum = 0.0
        for batch_images, batch_labels in train_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(model(batch_images), batch_labels)
            batch_loss.backward()
            optimizer.step()

            loss_sum += batch_loss.item()

        mean_loss = loss_sum / len(train_loader)
        print(f"Epoch [{epoch_index + 1}/{num_epochs}], Loss: {mean_loss:.4f}")

        # ---- validation pass: count correct predictions ----
        model.eval()
        n_correct = 0
        n_seen = 0
        with torch.no_grad():
            for batch_images, batch_labels in val_loader:
                batch_images = batch_images.to(device)
                batch_labels = batch_labels.to(device)
                # Index of the larger logit is the predicted class.
                guesses = model(batch_images).argmax(dim=1)
                n_seen += batch_labels.size(0)
                n_correct += (guesses == batch_labels).sum().item()

        accuracy = 100 * n_correct / n_seen
        print(f"Validation Accuracy: {accuracy:.2f}%")

    # Persist only the learned parameters once training has finished.
    torch.save(model.state_dict(), 'cat_dog_classifier.pth')

# Prediction helper
def predict_image(image_path):
    """Classify the image file at ``image_path`` as ``'cat'`` or ``'dog'``.

    The image is opened, converted to RGB, run through the module-level
    ``transform`` and fed to the module-level ``model`` as a batch of one
    on ``device``.  Class index 0 maps to ``'cat'``; anything else maps to
    ``'dog'``.  The model's train/eval mode is left as the caller set it.
    """
    rgb_image = Image.open(image_path).convert('RGB')
    batch = transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
        class_index = logits.argmax(dim=1).item()

    if class_index == 0:
        return 'cat'
    return 'dog'


def test(image_path='D:/develop/pytorch/dogcat/img/training/cats/4.jpg'):
    """Load the saved weights and print the predicted class of one image.

    Args:
        image_path: Image file to classify.  The default is the path the
            earlier hard-coded version actually used (its two preceding
            assignments were overwritten before use and have been removed).
    """
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint written on a GPU machine also loads on a CPU-only one.
    model.load_state_dict(torch.load('cat_dog_classifier.pth', map_location=device))
    # Evaluation mode disables dropout for prediction.
    model.eval()
    prediction = predict_image(image_path)
    print(f"The image is a {prediction}")

import matplotlib.pyplot as plt

def test1(image_path='D:/develop/pytorch/dogcat/img/training/cats/3.jpg'):
    """Predict the class of one image, print it, and display the image.

    Args:
        image_path: Image file to classify and show.  Defaults to the
            sample path that was previously hard-coded.
    """
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint written on a GPU machine also loads on a CPU-only one.
    model.load_state_dict(torch.load('cat_dog_classifier.pth', map_location=device))
    model.eval()
    prediction = predict_image(image_path)
    print("Predicted class:", prediction)

    # Display-only pipeline: resize and convert to a tensor, but do NOT
    # normalise.  The previous version normalised with mean/std 0.5, which
    # puts values in [-1, 1]; imshow clips float RGB data to [0, 1], so the
    # picture was shown with wrong, darkened colours.
    to_display = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    with Image.open(image_path) as img:
        shown = to_display(img.convert('RGB'))

    # The tensor is channel-first; imshow wants (height, width, channel).
    plt.imshow(shown.permute(1, 2, 0).numpy())
    plt.show()

# Script entry point: train the model (train() also saves the weights), then
# reload them and show the prediction for one sample image.
if __name__ == '__main__':
    train()
    test1()

D:\develop\pytorch\dogcat>python3.7 dogVsCat.py
C:\Users\yosola\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\PIL\TiffImagePlugin.py:864: UserWarning: Truncated File Read
  warnings.warn(str(msg))
Epoch [1/10], Loss: 0.5782
Validation Accuracy: 76.25%
Epoch [2/10], Loss: 0.4676
Validation Accuracy: 81.10%
Epoch [3/10], Loss: 0.4201
Validation Accuracy: 84.90%
Epoch [4/10], Loss: 0.3605
Validation Accuracy: 88.25%
Epoch [5/10], Loss: 0.2949
Validation Accuracy: 92.50%
Epoch [6/10], Loss: 0.2234
Validation Accuracy: 95.90%
Epoch [7/10], Loss: 0.1562
Validation Accuracy: 98.00%
Epoch [8/10], Loss: 0.1069
Validation Accuracy: 98.60%
Epoch [9/10], Loss: 0.0907
Validation Accuracy: 99.70%
Epoch [10/10], Loss: 0.0785
Validation Accuracy: 99.50%
Predicted class: cat

在这里插入图片描述

10. 总结

我们定义了一个自定义数据集类 CatDogDataset 来加载猫狗图片。

使用 PyTorch 的 DataLoader 加载数据。

定义了一个简单的 CNN 模型进行训练。

保存训练好的模型,并使用模型进行预测。

你可以根据需要调整模型的架构、超参数和数据增强方法。希望这个示例对你有帮助!


http://www.kler.cn/a/549662.html

相关文章:

  • 【C++】详解 set multiset map multiset 的使用
  • 网络安全讲座
  • Redis 的常见应用场景
  • 【第9章:计算机视觉实战—9.1 目标检测与识别:YOLO、Faster R-CNN等模型的实现与应用】
  • 案例-06.部门管理-根据ID查询
  • ESP32通过MQTT连接阿里云平台实现消息发布与订阅
  • 9. Docker 当中的复杂安装(MySQL主从复制的安装,Redis 的3主3从的安装配置)的详细说明[附加步骤图]
  • Linux(ubuntu)下载ollama速度慢解决办法
  • 今日写题06work(链表)
  • Golang GORM系列:GORM数据库迁移
  • Apache 服务启动失败的故障诊断
  • Leetcode 221-最大正方形
  • 使用css实现镂空效果
  • 小初高各学科教材,PDF电子版下载
  • 数据库-通用数据接口标准
  • mysql系列8—Innodb的undolog
  • MVC模式和MVVM模式
  • 【漫话机器学习系列】092.模型的一致性(Consistency of a Model)
  • 国产FPGA开发板选择
  • Spring Boot实战:拦截器