当前位置: 首页 > article >正文

ccc-pytorch-宝可梦自定义数据集实战-加载数据部分(9)

文章目录

      • 第一步:构建路径与种类的映射关系
      • 第二步:载入所有的宝可梦图像
      • 第三步:打散顺序并通过路径名提取映射关系构建映射文件
      • 第四步:完善选取、获取图片信息功能并可视化
      • 第五步:对数据进行预处理
      • 第六步:批量读取图片

文件/数据结构:
image-20230313185704642
在这里插入图片描述

第一步:构建路径与种类的映射关系

import os
from torch.utils.data import Dataset

class Pokeman(Dataset):
    """Step 1 of the custom dataset: map each class folder to an integer label.

    Every sub-directory of ``root`` is treated as one class. The constructor
    prints ``root`` and the resulting ``name2label`` dictionary.

    Args:
        root: Directory that contains one sub-directory per class.
        resize: Stored on ``self.resize``; not used yet at this step.
        model: Accepted but not used at this step.
    """

    def __init__(self, root, resize, model):
        super().__init__()
        self.root = root
        self.resize = resize
        print(root)

        # Sorted folder names keep the name -> label assignment stable between runs.
        class_dirs = [
            entry
            for entry in sorted(os.listdir(root))
            if os.path.isdir(os.path.join(root, entry))
        ]
        self.name2label = {entry: index for index, entry in enumerate(class_dirs)}

        print(self.name2label)

    def __len__(self):
        """Placeholder at this step: returns None."""

    def __getitem__(self, idx):
        """Placeholder at this step: returns None."""

def main():
    """Build the dataset once so its constructor prints the class -> label mapping."""
    # Raw string: the Windows path contains backslash sequences ("\p") that are
    # not valid escapes and raise an invalid-escape warning in a normal literal.
    Pokeman(r'D:\pythonProject\pythonProject39\pokeman', 224, 'train')


if __name__ == '__main__':
    main()

image-20230311204936302

第二步:载入所有的宝可梦图像

import os,glob

from torch.utils.data import Dataset

class Pokeman(Dataset):
    """Step 2 of the custom dataset: additionally collect every image path.

    Every sub-directory of ``root`` is one class. The constructor prints
    ``root`` and the ``name2label`` mapping, then calls :meth:`load_csv`,
    which prints the number of images found and their paths.

    Args:
        root: Directory that contains one sub-directory per class.
        resize: Stored on ``self.resize``; not used yet at this step.
        model: Accepted but not used at this step.
    """

    def __init__(self, root, resize, model):
        super().__init__()
        self.root = root
        self.resize = resize
        print(root)

        # Sorted folder names keep the name -> label assignment stable between runs.
        class_dirs = [
            entry
            for entry in sorted(os.listdir(root))
            if os.path.isdir(os.path.join(root, entry))
        ]
        self.name2label = {entry: index for index, entry in enumerate(class_dirs)}

        print(self.name2label)
        self.load_csv('images.csv')

    def load_csv(self, filename):
        """Print the count and list of png/jpg/jpeg files found in the class folders.

        ``filename`` is not used at this step and nothing is returned.
        """
        found = []
        for class_name in self.name2label:
            class_dir = os.path.join(self.root, class_name)
            for pattern in ('*.png', '*.jpg', '*.jpeg'):
                found.extend(glob.glob(os.path.join(class_dir, pattern)))
        # e.g. 1167, 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
        print(len(found), found)

    def __len__(self):
        """Placeholder at this step: returns None."""

    def __getitem__(self, idx):
        """Placeholder at this step: returns None."""

def main():
    """Build the dataset once so its constructor prints the mapping and image list."""
    # Raw string: the Windows path contains backslash sequences ("\p") that are
    # not valid escapes and raise an invalid-escape warning in a normal literal.
    Pokeman(r'D:\pythonProject\pythonProject39\pokeman', 224, 'train')


if __name__ == '__main__':
    main()

image-20230311210101708

第三步:打散顺序并通过路径名提取映射关系构建映射文件

import csv
import os,glob
import random

from torch.utils.data import Dataset

class Pokeman(Dataset):
    """Step 3 of the custom dataset: index every image in a shuffled CSV file.

    Every sub-directory of ``root`` is one class. The constructor prints
    ``root`` and the ``name2label`` mapping, then fills ``self.images`` and
    ``self.labels`` from ``root/images.csv`` (created on first use).

    Args:
        root: Directory that contains one sub-directory per class.
        resize: Stored on ``self.resize``; not used yet at this step.
        model: Accepted but not used at this step.
    """

    def __init__(self, root, resize, model):
        super(Pokeman, self).__init__()
        self.root = root
        self.resize = resize
        self.name2label = {}
        print(root)
        # Sorted folder names keep the name -> label assignment stable between runs.
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)

        print(self.name2label)
        self.images, self.labels = self.load_csv('images.csv')

    def load_csv(self, filename):
        """Return ``(images, labels)`` from ``root/filename``, creating the file if missing.

        When the CSV does not exist, all png/jpg/jpeg files of every class
        folder are collected, shuffled and written as ``path,label`` rows.
        The file is then read back, so an existing CSV is reused as is.

        Args:
            filename: Name of the CSV file, relative to ``self.root``.

        Returns:
            Two lists of equal length: image paths and their integer labels.
        """
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))
            # e.g. 1167, 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
            print(len(images), images)

            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent directory of an image is its class name.
                    name = img.split(os.sep)[-2]
                    writer.writerow([img, self.name2label[name]])
                print('writen into csv file', filename)

        # Fix: the read-back used to be nested inside the branch above, so the
        # method returned None (and __init__ failed to unpack it) whenever the
        # CSV already existed. It now runs unconditionally.
        images, labels = [], []
        with open(csv_path) as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        """Placeholder at this step: returns None."""

    def __getitem__(self, idx):
        """Placeholder at this step: returns None."""

def main():
    """Build the dataset once so its constructor creates or loads the CSV index."""
    # Raw string: the Windows path contains backslash sequences ("\p") that are
    # not valid escapes and raise an invalid-escape warning in a normal literal.
    Pokeman(r'D:\pythonProject\pythonProject39\pokeman', 224, 'train')


if __name__ == '__main__':
    main()

image-20230311211317832
image-20230311211306693

第四步:完善选取、获取图片信息功能并可视化

import csv
import os,glob
import random
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

class Pokeman(Dataset):
    """Step 4 of the custom dataset: train/val/test split and image loading.

    Every sub-directory of ``root`` is one class. The shuffled index stored
    in ``root/images.csv`` is split 60% / 20% / 20% into train, validation
    and test parts.

    Args:
        root: Directory that contains one sub-directory per class.
        resize: Edge length (pixels) that every image is resized to.
        model: ``'train'`` selects the first 60% of the index, ``'val'`` the
            next 20%; any other value selects the final 20%.
    """

    def __init__(self, root, resize, model):
        super(Pokeman, self).__init__()
        self.root = root
        self.resize = resize

        self.name2label = {}
        print(root)
        # Sorted folder names keep the name -> label assignment stable between runs.
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)

        print(self.name2label)
        self.images, self.labels = self.load_csv('images.csv')

        # Fix: both cut points are computed once from the full length and the
        # same slice is applied to both lists. The previous code re-measured
        # self.images after it had already been cut, which left 'val' without
        # labels and gave the test part labels that did not match its images.
        total = len(self.images)
        train_end = int(0.6 * total)
        val_end = int(0.8 * total)
        if model == 'train':
            part = slice(0, train_end)
        elif model == 'val':
            part = slice(train_end, val_end)
        else:
            part = slice(val_end, total)
        self.images = self.images[part]
        self.labels = self.labels[part]

    def load_csv(self, filename):
        """Return ``(images, labels)`` from ``root/filename``, creating the file if missing.

        When the CSV does not exist, all png/jpg/jpeg files of every class
        folder are collected, shuffled and written as ``path,label`` rows.
        The file is then read back, so an existing CSV is reused as is.

        Args:
            filename: Name of the CSV file, relative to ``self.root``.

        Returns:
            Two lists of equal length: image paths and their integer labels.
        """
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))
            # e.g. 1167, 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
            print(len(images), images)

            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent directory of an image is its class name.
                    name = img.split(os.sep)[-2]
                    writer.writerow([img, self.name2label[name]])
                print('writen into csv file', filename)

        images, labels = [], []
        with open(csv_path) as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, label_tensor)``."""
        # img is a path such as 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
        img, label = self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),  # path -> RGB image
            transforms.Resize((self.resize, self.resize)),
            transforms.ToTensor(),
        ])

        img = tf(img)
        label = torch.tensor(label)

        return img, label

def main():
    """Load the first training sample and display it in visdom."""
    import visdom
    viz = visdom.Visdom()

    # Raw string: the Windows path contains backslash sequences ("\p") that are
    # not valid escapes and raise an invalid-escape warning in a normal literal.
    db = Pokeman(r'D:\pythonProject\pythonProject39\pokeman', 224, 'train')
    # Fetch the first sample by iterating over the dataset.
    x, y = next(iter(db))
    print('sample:', x.shape, y.shape)
    viz.images(x, win='sample_x', opts=dict(title='sample_x'))


if __name__ == '__main__':
    main()

在这里插入图片描述

image-20230312210040000

第五步:对数据进行预处理

import csv
import os,glob
import random
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

class Pokeman(Dataset):
    """Step 5 of the custom dataset: adds augmentation and normalisation.

    Every sub-directory of ``root`` is one class. The shuffled index stored
    in ``root/images.csv`` is split 60% / 20% / 20% into train, validation
    and test parts.

    Args:
        root: Directory that contains one sub-directory per class.
        resize: Edge length (pixels) of the square crop returned per image.
        model: ``'train'`` selects the first 60% of the index, ``'val'`` the
            next 20%; any other value selects the final 20%.
    """

    def __init__(self, root, resize, model):
        super(Pokeman, self).__init__()
        self.root = root
        self.resize = resize

        self.name2label = {}
        print(root)
        # Sorted folder names keep the name -> label assignment stable between runs.
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)

        print(self.name2label)
        self.images, self.labels = self.load_csv('images.csv')

        # Fix: both cut points are computed once from the full length and the
        # same slice is applied to both lists. The previous code re-measured
        # self.images after it had already been cut, which left 'val' without
        # labels and gave the test part labels that did not match its images.
        total = len(self.images)
        train_end = int(0.6 * total)
        val_end = int(0.8 * total)
        if model == 'train':
            part = slice(0, train_end)
        elif model == 'val':
            part = slice(train_end, val_end)
        else:
            part = slice(val_end, total)
        self.images = self.images[part]
        self.labels = self.labels[part]

    def load_csv(self, filename):
        """Return ``(images, labels)`` from ``root/filename``, creating the file if missing.

        When the CSV does not exist, all png/jpg/jpeg files of every class
        folder are collected, shuffled and written as ``path,label`` rows.
        The file is then read back, so an existing CSV is reused as is.

        Args:
            filename: Name of the CSV file, relative to ``self.root``.

        Returns:
            Two lists of equal length: image paths and their integer labels.
        """
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))
            # e.g. 1167, 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
            print(len(images), images)

            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent directory of an image is its class name.
                    name = img.split(os.sep)[-2]
                    writer.writerow([img, self.name2label[name]])
                print('writen into csv file', filename)

        images, labels = [], []
        with open(csv_path) as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def denormalize(self, x_hat):
        """Undo the Normalize step of ``__getitem__`` so the image can be displayed.

        Normalisation computes ``x_hat = (x - mean) / std``; this returns
        ``x = x_hat * std + mean`` using the same per-channel constants.
        """
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # Reshape [3] -> [3, 1, 1] so the constants broadcast over height and width.
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        x = x_hat * std + mean
        return x

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, label_tensor)``."""
        # img is a path such as 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
        img, label = self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),  # path -> RGB image
            # Resize to 1.25x the target so the centre crop below still fits after rotation.
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),  # random rotation
            transforms.CenterCrop(self.resize),  # centre crop to the target size
            transforms.ToTensor(),
            # Same mean/std constants as denormalize().
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        img = tf(img)
        label = torch.tensor(label)

        return img, label

def main():
    """Load the first training sample and display its de-normalised image in visdom."""
    import visdom
    viz = visdom.Visdom()

    # Raw string: the Windows path contains backslash sequences ("\p") that are
    # not valid escapes and raise an invalid-escape warning in a normal literal.
    db = Pokeman(r'D:\pythonProject\pythonProject39\pokeman', 224, 'train')
    # Fetch the first sample by iterating over the dataset.
    x, y = next(iter(db))
    print('sample:', x.shape, y.shape)
    # The sample is normalised; undo that before displaying it.
    viz.images(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))


if __name__ == '__main__':
    main()

image-20230312211459931
如果不做 denormalize(反归一化),直接显示归一化后的张量,生成的图片如下:
image-20230312210836559

第六步:批量读取图片

import csv
import os,glob
import random
import torch
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image

class Pokeman(Dataset):
    """Final version of the custom dataset, suitable for use with a DataLoader.

    Every sub-directory of ``root`` is one class. The shuffled index stored
    in ``root/images.csv`` is split 60% / 20% / 20% into train, validation
    and test parts.

    Args:
        root: Directory that contains one sub-directory per class.
        resize: Edge length (pixels) of the square crop returned per image.
        model: ``'train'`` selects the first 60% of the index, ``'val'`` the
            next 20%; any other value selects the final 20%.
    """

    def __init__(self, root, resize, model):
        super(Pokeman, self).__init__()
        self.root = root
        self.resize = resize

        self.name2label = {}
        print(root)
        # Sorted folder names keep the name -> label assignment stable between runs.
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)

        print(self.name2label)
        self.images, self.labels = self.load_csv('images.csv')

        # Fix: both cut points are computed once from the full length and the
        # same slice is applied to both lists. The previous code re-measured
        # self.images after it had already been cut, which left 'val' without
        # labels and gave the test part labels that did not match its images.
        total = len(self.images)
        train_end = int(0.6 * total)
        val_end = int(0.8 * total)
        if model == 'train':
            part = slice(0, train_end)
        elif model == 'val':
            part = slice(train_end, val_end)
        else:
            part = slice(val_end, total)
        self.images = self.images[part]
        self.labels = self.labels[part]

    def load_csv(self, filename):
        """Return ``(images, labels)`` from ``root/filename``, creating the file if missing.

        When the CSV does not exist, all png/jpg/jpeg files of every class
        folder are collected, shuffled and written as ``path,label`` rows.
        The file is then read back, so an existing CSV is reused as is.

        Args:
            filename: Name of the CSV file, relative to ``self.root``.

        Returns:
            Two lists of equal length: image paths and their integer labels.
        """
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))
            # e.g. 1167, 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
            print(len(images), images)

            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent directory of an image is its class name.
                    name = img.split(os.sep)[-2]
                    writer.writerow([img, self.name2label[name]])
                print('writen into csv file', filename)

        images, labels = [], []
        with open(csv_path) as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def denormalize(self, x_hat):
        """Undo the Normalize step of ``__getitem__`` so images can be displayed.

        Normalisation computes ``x_hat = (x - mean) / std``; this returns
        ``x = x_hat * std + mean`` using the same per-channel constants.
        """
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # Reshape [3] -> [3, 1, 1] so the constants broadcast over height and width.
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        x = x_hat * std + mean
        return x

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, label_tensor)``."""
        # img is a path such as 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
        img, label = self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),  # path -> RGB image
            # Resize to 1.25x the target so the centre crop below still fits after rotation.
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),  # random rotation
            transforms.CenterCrop(self.resize),  # centre crop to the target size
            transforms.ToTensor(),
            # Same mean/std constants as denormalize().
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        img = tf(img)
        label = torch.tensor(label)

        return img, label

def main():
    """Display one sample, then send shuffled batches of 32 images to visdom."""
    import visdom
    import time
    viz = visdom.Visdom()

    # Raw string: the Windows path contains backslash sequences ("\p") that are
    # not valid escapes and raise an invalid-escape warning in a normal literal.
    db = Pokeman(r'D:\pythonProject\pythonProject39\pokeman', 64, 'train')
    # Fetch the first sample by iterating over the dataset.
    x, y = next(iter(db))
    print('sample:', x.shape, y.shape)
    # The sample is normalised; undo that before displaying it.
    viz.images(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))
    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)

    for x, y in loader:
        viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
        # Pause 10 seconds before sending the next batch.
        time.sleep(10)


if __name__ == '__main__':
    main()

image-20230312212317555
对于按类别分文件夹存放、目录结构规整的数据集,可以更简单地直接调用 API(torchvision.datasets.ImageFolder):

# Self-contained version of the snippet: the original used torchvision,
# DataLoader, transforms, viz and time without importing or defining them.
import time

import torchvision
import visdom
from torch.utils.data import DataLoader
from torchvision import transforms

viz = visdom.Visdom()

tf = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])
# ImageFolder takes the class of each image from its sub-directory under root.
db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf)
loader = DataLoader(db, batch_size=32, shuffle=True)

# Print the class-name -> index mapping built by ImageFolder.
print(db.class_to_idx)

for x, y in loader:
    viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
    viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))

    # Pause 10 seconds before sending the next batch.
    time.sleep(10)

http://www.kler.cn/a/2486.html

相关文章:

  • 《全面解析 QT 各版本:特性、应用与选择策略》
  • 模具生产过程中的标签使用流程图
  • Soul Android端稳定性背后的那些事
  • 怎么将pdf中的某一个提取出来?介绍几种提取PDF中页面的方法
  • 允许某段网络访问Linux服务器上的MariaDB
  • 代码随想录第51天
  • 智慧物业类管理APP开发功能有哪些?
  • Lesson 9.1 集成学习的三大关键领域、Bagging 方法的基本思想和 RandomForestRegressor 的实现
  • Spring入门篇3 --- 依赖注入(DI)方式、集合注入
  • 网络技术与应用概论(上)——“计算机网络”
  • 第29次CCFCSP认证经验总结
  • C语言 结构体进阶 结构体、枚举、联合详解(2)
  • AWS白皮书总结
  • 计算机网络管理 TCP三次握手的建立过程,Wireshark抓包分析并验证TCP三次握手建立连接的报文
  • I2C模块理解
  • Linux系统下gdb调试
  • 【Go】K8s 管理系统项目[Jenkins Pipeline K8s环境–应用部署]
  • Python 项目之实现文件内容的反转再输入(一)完全反转
  • react中渲染企业微信的表情
  • 使用shell 脚本,批量解压一批zip文件,解压后的文件放在以原zip文件名前10个字符的文件夹中的例子
  • Java stream性能比较
  • java基础面试题(四)
  • TU-95 strategic bomber气动布局分析
  • 蓝桥杯训练day3
  • 深入理解JVM虚拟机(六)
  • 梳理LVM逻辑卷管理,