PyTorch读取图片数据

CNN实战基础——读取图片数据
实现结果如下图:
在这里插入图片描述

取数据

1.导包

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

2.下载数据

trans = transforms.ToTensor()
train_dataset = datasets.MNIST(
    root='./drive/MyDrive/DataSet', train=True, transform=trans, download=True)
test_dataset = datasets.MNIST(
    root='./drive/MyDrive/DataSet', train=False, transform=trans, download=True)

root='xx' :表示 数据集所在路径,数据集不用解压
train=True or False:表明是 训练集或测试集
transform=trans :把读取的Img类型图片转为Tensor
down=True :表示 若数据集在root路径下,则直接加载;若不在root里,则下载(外网很慢,可以提前下载好放进去)

下图为下载之后,自动创建了raw文件夹,数据集在raw里
在这里插入图片描述

3.加载数据

train_dl = DataLoader(train_dataset, batch_size=32)

param1:指明加载什么数据集
param2:一批有32张图

4.打印数据
我们往往想看看第一批(32张)图,但不能通过train_dl[0]等下标方式访问,下面有两种方式都可以:

4.1 next + enumerate

examples = enumerate(train_dl) # 返回数字下标+迭代器(可用next访问)
next(examples)

4.2 next + iter

examples = iter(train_dl) # 返回迭代器,无下标
imgs, labels = next(examples)

next()iter()函数

结果可视化

第一版

import matplotlib.pyplot as plt
fig = plt.figure() # 创建一个窗口
for i in range(6):
  plt.subplot(2,3,i+1)
  plt.imshow(imgs[i][0])
  plt.title('label:{}'.format(labels[i]))# 格式化输出,否则是tensor数据

plt.imshow(imgs[i][0])这里指:imgs是4维数据,后面2维是28*28图片大小,不用管,关键是用前两维定位到图片
结果:
在这里插入图片描述
发现布局有问题,故增加plt.tight_layout()语句,结果明显分开

在这里插入图片描述
发现x,y轴的数字可以不要,添加plt.xticks([]), plt.yticks([]),结果:
在这里插入图片描述

想要灰度图,设置一个参数即可plt.imshow(imgs[i][0],cmap='gray')
在这里插入图片描述

完整代码:

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

trans = transforms.ToTensor()
train_dataset = datasets.MNIST(
    root='./drive/MyDrive/DataSet', train=True, transform=trans, download=True)
test_dataset = datasets.MNIST(
    root='./drive/MyDrive/DataSet', train=False, transform=trans, download=True)

train_dl = DataLoader(train_dataset, batch_size=32)

# train_dl里有很多batch,每个batch里有batch_size张图片
#examples = enumerate(train_dl)
#next(examples)
# 用iter会少序号,和enumerate有一点不同,其它差不多可以打印出来看
examples = iter(train_dl)
imgs, labels = next(examples)
# print(len(imgs), len(labels))

#-------------------------------数据显示--------------------------------------------
import matplotlib.pyplot as plt
fig = plt.figure()
for i in range(6):
  plt.subplot(2,3,i+1)
  plt.tight_layout()
  plt.imshow(imgs[i][0],cmap='gray') # 4维数据,后面2维是28*28图片大小,不用管,关键是用前两维定位到图片
  plt.title('label:{}'.format(labels[i]))# 格式化输出,否则是tensor数据
  plt.xticks([])            # 清空x,y轴的数字
  plt.yticks([])

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
分享
二维码
< <上一篇
下一篇>>