PyTorch中如何读取数据(Dataset类的使用)

PyTorch中如何读取数据(Dataset类的使用)

在pytorch中如何读取数据主要有两个类。

分别是Dataset和Dataloader。

dataset可以理解为:提供一种方式去获取数据及其label(标签)。

可以实现(1)如何获取每一个数据及其label;(2)总共有多少数据。这两个功能。

dataloader可以理解为:为后面的网络提供不同的数据形势。

Dataset类怎么去用?

from torch.utils.data import Dataset

这段代码可以理解为:从torch大工具箱里面utils常用的工具区,关于数据的data区的。

可以使用help()函数查看,在jupyter或者pycharm控制台里面查询。

也可以直接在jupyter里输入Dataset??,直接可以查询。

Dataset的运用

class MyData( Dataset ) : //创建一个class(MyData)继承Dataset类

class MyData( Dataset ) :

def __init__(self): //初始化类,比如说我们要根据这个类去创建一个特例的时候,它就要运行的一个函数。这个函数它一般会为整个class提供一个全局变量。为后面的一些函数提供一些所需要的量。

def __init__(self):

def __getitem__(self, item) :

它默认为item,我们改为def __getitem__(self, idx): // idx可以看作一个编号

def __getitem__(self, item) :

如果我们要通过这个idx(索引)来获取图片的地址的话,首先要获取这些图片的列表(list)。

如果需要获取所有图片的地址的话,我们就需要用到os(python中关于系统的一个库)

dir_path = "" // ""中输入所有图片文件夹地址,我使用全地址报错了,改用相对地址后没问题

import os //使用os

img_path_list = os.listdir(dir_path) //将文件夹中的所有图片变成列表

如果我们要使用idxa去获取想要的图片的话,首先就要去创建图片地址的列表

def __init__(self, root_dir, label_dir)

使用python console验证。

import os

root_dir = "" // “”中输入放图片文件上一个文件的地址

label_dir = "" // “”中输入放图片的地址

path = os.path.join(root_dir, label_dir) //join这个给函数的作用就是在root_dir,

label_dir两个地址之间添加一个\,将这两个路径进行拼接

接着,

def __init__(self, root_dir, label_dir)

self.root_dir = root_dir //为什么用self,我们知道一个函数中的变量是不能传

递给另外一个函数的变量的。而这个self,它可以把self指定的一个变量给后面的函数使用。它就

相当于指定了一个类中的全局变量。

self.label_dir = label_dir

self.path = os.path.join(self.root_dir, self.label_dir) //获得图片的路径地址

self.img_path = os.listdir(self.path) // 获得所有图片列表

如果我们想验证这个函数,可以在python console中验证。

如果要获取所有图片中某一个图片的话,

def __getitem__(self, idx):

img_name = self.img_path[idx] // 名称,从list里面读取遥感对应位置, 加self是

指全局的,引用上面的 self.img_path

img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) //获

取某一个图片的路径

自此可以使用python console验证。

接着,可以使用import PIL import Image来读取图片

img = Image.open(img_item_path) //读取图片

label = self.label_dir

return img, label

def __len__(self): //查看这个列表的长度有多长

return len(self.img_path)

怎么读取电脑中的一张图片

from PIL import Image //一个读取图片的方法

可以先在Python控制台进行调试。

from PIL import Image

img_path = "" //获取图片地址 “”中输入图片地址

img = Image.open(img_path)

img.show() //显示该图片

全部代码

from torch.utils.data import Dataset

from PIL import Image

import os

class MyData(Dataset):

def __init__(self, root_dir, label_dir):

self.root_dir = root_dir

self.label_dir = label_dir

self.path = os.path.join(self.root_dir, self.label_dir)

self.img_path = os.listdir(self.path)

def __getitem__(self, idx):

img_name = self.img_path[idx]

img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)

img = Image.open(img_item_path)

label = self.label_dir

return img, label

def __len__(self):

return len(self.img_path)

root_dir = "dataset/train"

ants_label_dir = "ants"

bees_label_dir = "bees"

ants_dataset = MyData(root_dir, ants_label_dir)

bees_dataset = MyData(root_dir, bees_label_dir)

train_dataset = ants_dataset + bees_dataset

# "len(train_dataset)"指令可以在Python console中查看train_dataset数据集中有多少个元素。

img, label = train_dataset[230]

img.show()

✨ 相关推荐

中国啶虫脒原药行业发展趋势分析与未来前景研究报告(2022-2029年)
《lol》吸血装备有哪些 吸血装备汇总介绍
365体育管网登录网站

《lol》吸血装备有哪些 吸血装备汇总介绍

📅 07-13 👀 5556
武进概况
365平台怎么注册

武进概况

📅 07-02 👀 4013