# 使用 PyTorch 自定义 Dataset (building a custom PyTorch Dataset)
import os

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
# Paths of the training images: every filename in `file` joined onto the
# training directory.
# NOTE(review): `file` is not defined in the visible source; it must be
# provided elsewhere, presumably as a list of image filenames — confirm.
file_train = list(
    map(lambda name: os.path.join("./dog_breed/train_path", name), file)
)
def train_transform(rgb, *, crop_size=(228, 304), jitter=0.4):
    """Apply training-time augmentation to a PIL image and return a tensor.

    The pipeline does the following, in order:

    * randomly flips the image horizontally with probability 0.5;
    * center-crops it to ``crop_size``;
    * jitters brightness, contrast and saturation;
    * converts it to a float tensor in ``[0, 1]``;
    * normalizes it with the standard ImageNet channel mean/std.

    Args:
        rgb: A PIL image. Normalize uses 3-channel statistics, so the image
            should be in RGB mode.
        crop_size: ``(height, width)`` of the center crop.
        jitter: Strength passed to ``ColorJitter`` for brightness, contrast
            and saturation.

    Returns:
        A float tensor of shape ``(3, crop_size[0], crop_size[1])`` for an
        RGB input.
    """
    transform = transforms.Compose([
        # Replaces the hand-rolled np.random coin flip (torchvision has no
        # `HorizontalFlip`); same 50% chance of flipping.
        transforms.RandomHorizontalFlip(p=0.5),
        # NOTE(review): default_loader resizes to 224x224 before calling this,
        # so a 228x304 crop is larger than the input and torchvision pads the
        # border with zeros — confirm the intended crop size.
        transforms.CenterCrop(crop_size),
        transforms.ColorJitter(jitter, jitter, jitter),
        transforms.ToTensor(),  # HWC uint8 [0, 255] -> CHW float [0, 1]
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # ToTensor already scales to [0, 1] and Normalize standardizes the
    # channels, so no further division by 255 (or numpy conversion) is needed.
    return transform(rgb)
def default_loader(path, size=(224, 224)):
    """Load an image file and return its augmented training tensor.

    Args:
        path: Filesystem path of the image.
        size: ``(width, height)`` passed to PIL's ``Image.resize`` before
            augmentation.

    Returns:
        Whatever ``train_transform`` returns for the resized image.
    """
    # The context manager closes the file handle. convert() reads the pixels
    # before the file closes and guarantees 3 channels: grayscale, RGBA or
    # palette images would otherwise break the 3-channel Normalize step.
    with Image.open(path) as img:
        img_rgb = img.convert("RGB")
    img_rgb = img_rgb.resize(size)
    return train_transform(img_rgb)
class trainset(Dataset):
    """Map-style dataset pairing training image paths with their targets.

    By default the samples come from the module-level ``file_train`` (image
    paths) and ``number_train`` (targets; NOTE(review): presumably class
    labels — confirm) lists. Both can also be passed explicitly.
    """

    def __init__(self, loader=None, images=None, targets=None):
        """Create the dataset.

        Args:
            loader: Callable mapping an image path to a sample. Defaults to
                ``default_loader`` (looked up when the dataset is created).
            images: Sequence of image paths; defaults to ``file_train``.
            targets: Sequence of targets aligned with ``images``; defaults to
                ``number_train``.

        Raises:
            ValueError: If ``images`` and ``targets`` differ in length.
        """
        self.images = file_train if images is None else images
        self.target = number_train if targets is None else targets
        # Fail early instead of raising IndexError mid-epoch (or silently
        # ignoring surplus targets) when the two lists are misaligned.
        if len(self.images) != len(self.target):
            raise ValueError(
                f"got {len(self.images)} images but {len(self.target)} targets"
            )
        self.loader = default_loader if loader is None else loader

    def __getitem__(self, index):
        """Return ``(loaded_image, target)`` for sample ``index``."""
        img = self.loader(self.images[index])
        return img, self.target[index]

    def __len__(self):
        """Return the number of samples."""
        return len(self.images)