"""
Load a dataset of images by specifying the folder where its located.
"""
# Utils
import logging
import os
import sys
from multiprocessing import Pool
import numpy as np
# Torch related stuff
import torch.utils.data as data
import torchvision
from torchvision.datasets.folder import pil_loader
from util.misc import get_all_files_in_folders_and_subfolders, has_extension
[docs]def load_dataset(dataset_folder, in_memory=False, workers=1):
"""
Loads the dataset from file system and provides the dataset splits for train validation and test
The dataset is expected to be in the following structure, where 'dataset_folder' has to point to
the root of the three folder train/val/test.
Example:
dataset_folder = "~/../../data/cifar"
which contains the splits sub-folders as follow:
'dataset_folder'/train
'dataset_folder'/val
'dataset_folder'/test
In each of the three splits (train, val, test) should have different classes in a separate folder
with the class name. The file name can be arbitrary i.e. it does not have to be 0-* for classes 0
of MNIST.
Example:
train/dog/whatever.png
train/dog/you.png
train/dog/like.png
train/cat/123.png
train/cat/nsdf3.png
train/cat/asd932_.png
train/"class_name"/*.png
Parameters
----------
dataset_folder : string
Path to the dataset on the file System
in_memory : boolean
Load the whole dataset in memory. If False, only file names are stored and images are loaded
on demand. This is slower than storing everything in memory.
workers: int
Number of workers to use for the dataloaders
Returns
-------
train_ds : data.Dataset
val_ds : data.Dataset
test_ds : data.Dataset
Train, validation and test splits
"""
# Get the splits folders
train_dir = os.path.join(dataset_folder, 'train')
val_dir = os.path.join(dataset_folder, 'val')
test_dir = os.path.join(dataset_folder, 'test')
# Sanity check on the splits folders
if not os.path.isdir(train_dir):
logging.error("Train folder not found in the dataset_folder=" + dataset_folder)
sys.exit(-1)
if not os.path.isdir(val_dir):
logging.error("Val folder not found in the dataset_folder=" + dataset_folder)
sys.exit(-1)
if not os.path.isdir(test_dir):
logging.error("Test folder not found in the dataset_folder=" + dataset_folder)
sys.exit(-1)
# If its requested online, delegate to torchvision.datasets.ImageFolder()
if not in_memory:
# Get an online dataset for each split
train_ds = torchvision.datasets.ImageFolder(train_dir)
val_ds = torchvision.datasets.ImageFolder(val_dir)
test_ds = torchvision.datasets.ImageFolder(test_dir)
return train_ds, val_ds, test_ds
else:
# Get an offline (in-memory) dataset for each split
train_ds = ImageFolderInMemory(train_dir, workers)
val_ds = ImageFolderInMemory(val_dir, workers)
test_ds = ImageFolderInMemory(test_dir, workers)
return train_ds, val_ds, test_ds
[docs]class ImageFolderInMemory(data.Dataset):
"""
This class loads the data provided and stores it entirely in memory as a dataset.
It makes use of torchvision.datasets.ImageFolder() to create a dataset. Afterward all images are
sequentially stored in memory for faster use when paired with dataloders. It is responsibility of
the user ensuring that the dataset actually fits in memory.
"""
def __init__(self, path, transform=None, target_transform=None, workers=1):
"""
Load the data in memory and prepares it as a dataset.
Parameters
----------
path : string
Path to the dataset on the file System
transform : torchvision.transforms
Transformation to apply on the data
target_transform : torchvision.transforms
Transformation to apply on the labels
workers: int
Number of workers to use for the dataloaders
"""
self.dataset_folder = os.path.expanduser(path)
self.transform = transform
self.target_transform = target_transform
# Get an online dataset
dataset = torchvision.datasets.ImageFolder(path)
# Shuffle the data once (otherwise you get clusters of samples of same class in each minibatch for val and test)
np.random.shuffle(dataset.imgs)
# Extract the actual file names and labels as entries
file_names = np.asarray([item[0] for item in dataset.imgs])
self.labels = np.asarray([item[1] for item in dataset.imgs])
# Load all samples
pool = Pool(workers)
self.data = pool.map(pil_loader, file_names)
pool.close()
# Set expected class attributes
self.classes = np.unique(self.labels)
def __getitem__(self, index):
"""
Retrieve a sample by index
Parameters
----------
index : int
Returns
-------
img : FloatTensor
target : int
label of the image
"""
img, target = self.data[index], self.labels[index]
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
return len(self.data)
[docs]class ImageFolderApply(data.Dataset):
"""
TODO fill me
"""
def __init__(self, path, transform=None, target_transform=None, classify=False):
"""
TODO fill me
Parameters
----------
path : string
Path to the dataset on the file System
transform : torchvision.transforms
Transformation to apply on the data
target_transform : torchvision.transforms
Transformation to apply on the labels
"""
self.dataset_folder = os.path.expanduser(path)
self.transform = transform
self.target_transform = target_transform
if classify is True:
# Get an online dataset
dataset = torchvision.datasets.ImageFolder(path)
# Extract the actual file names and labels as entries
self.file_names = np.asarray([item[0] for item in dataset.imgs])
self.labels = np.asarray([item[1] for item in dataset.imgs])
else:
# Get all files in the folder that are images
self.file_names = self._get_filenames(self.dataset_folder)
# Extract the label for each file (assuming standard format of root_folder/class_folder/img.jpg)
self.labels = [item.split('/')[-2] for item in self.file_names]
# Set expected class attributes
self.classes = np.unique(self.labels)
def _get_filenames(self, path):
file_names = []
for item in get_all_files_in_folders_and_subfolders(path):
if has_extension(item, ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']):
file_names.append(item)
return file_names
def __getitem__(self, index):
"""
Retrieve a sample by index and provides its filename as well
Parameters
----------
index : int
Returns
-------
img : FloatTensor
target : int
label of the image
filename : string
"""
img = pil_loader(self.file_names[index])
target, filename = self.labels[index], self.file_names[index]
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target, filename
def __len__(self):
return len(self.file_names)