Source code for datasets.image_folder_triplet

"""
Load a dataset of images by specifying the folder where it is located, and prepare it for
triplet similarity matching training.
"""

# Utils
import logging
import os
import random
import sys
from multiprocessing import Pool

import numpy as np
from tqdm import trange

# Torch related stuff
import torch.utils.data as data
import torchvision
from torchvision.datasets.folder import pil_loader


def load_dataset(dataset_folder, num_triplets=None, in_memory=False, workers=1, only_evaluate=False):
    """
    Loads the dataset from the file system and provides the dataset splits for train,
    validation and test.

    The dataset is expected to be in the same structure as described in
    image_folder_dataset.load_dataset()

    Parameters
    ----------
    dataset_folder : string
        Path to the dataset on the file system
    num_triplets : int
        Number of triplets [a, p, n] to generate on dataset creation
    in_memory : boolean
        Load the whole dataset in memory. If False, only file names are stored and
        images are loaded on demand. This is slower than storing everything in memory.
    workers : int
        Number of workers to use when loading the dataset into memory
    only_evaluate : boolean
        If True, only the test set is loaded.

    Returns
    -------
    train_ds : data.Dataset
    val_ds : data.Dataset
    test_ds : data.Dataset
        Train, validation and test splits
    """
    if only_evaluate:
        # Get the split folder
        test_dir = os.path.join(dataset_folder, 'test')

        # Sanity check on the split folder
        if not os.path.isdir(test_dir):
            logging.error("Test folder not found in dataset_folder=" + dataset_folder)
            sys.exit(-1)

        test_ds = ImageFolderTriplet(test_dir, train=False, num_triplets=num_triplets,
                                     workers=workers, in_memory=in_memory)
        return None, None, test_ds
    else:
        # Get the split folders
        train_dir = os.path.join(dataset_folder, 'train')
        val_dir = os.path.join(dataset_folder, 'val')
        test_dir = os.path.join(dataset_folder, 'test')

        # Sanity check on the split folders
        if not os.path.isdir(train_dir):
            logging.error("Train folder not found in dataset_folder=" + dataset_folder)
            sys.exit(-1)
        if not os.path.isdir(val_dir):
            logging.error("Val folder not found in dataset_folder=" + dataset_folder)
            sys.exit(-1)
        if not os.path.isdir(test_dir):
            logging.error("Test folder not found in dataset_folder=" + dataset_folder)
            sys.exit(-1)

        train_ds = ImageFolderTriplet(train_dir, train=True, num_triplets=num_triplets,
                                      workers=workers, in_memory=in_memory)
        val_ds = ImageFolderTriplet(val_dir, train=False, num_triplets=num_triplets,
                                    workers=workers, in_memory=in_memory)
        test_ds = ImageFolderTriplet(test_dir, train=False, num_triplets=num_triplets,
                                     workers=workers, in_memory=in_memory)
        return train_ds, val_ds, test_ds
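
# Usage sketch (illustrative, not part of this module): the dataset root below is a
# hypothetical path, and each split follows the torchvision ImageFolder convention
# of one sub-folder per class, e.g.
#
#   dataset_folder/train/class_0/img_001.png
#   dataset_folder/train/class_1/img_042.png
#   dataset_folder/val/...    dataset_folder/test/...
#
#   train_ds, val_ds, test_ds = load_dataset('/path/to/dataset', num_triplets=10000,
#                                            in_memory=False, workers=4)
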
class ImageFolderTriplet(data.Dataset):
    """
    This class loads the data provided and stores it entirely in memory as a dataset.

    Additionally, triplets will be generated in the format of [a, p, n] and their
    file names stored in memory.
    """

    def __init__(self, path, train=None, num_triplets=None, in_memory=None,
                 transform=None, target_transform=None, workers=None):
        """
        Load the data in memory and prepare it as a dataset.

        Parameters
        ----------
        path : string
            Path to the dataset on the file system
        train : bool
            Denotes whether this dataset will be used for training. This matters
            because for validation and test there are no triplets, only
            (image, label) pairs to evaluate similarity matching.
        num_triplets : int
            Number of triplets [a, p, n] to generate on dataset creation
        in_memory : boolean
            Load the whole dataset in memory. If False, only file names are stored and
            images are loaded on demand. This is slower than storing everything in memory.
        transform : torchvision.transforms
            Transformation to apply on the data
        target_transform : torchvision.transforms
            Transformation to apply on the labels
        workers : int
            Number of workers to use when loading the dataset into memory
        """
        self.dataset_folder = os.path.expanduser(path)
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.num_triplets = num_triplets
        self.in_memory = in_memory

        dataset = torchvision.datasets.ImageFolder(path)

        # Shuffle the data once (otherwise you get clusters of samples of the same
        # class in each batch for val and test)
        np.random.shuffle(dataset.imgs)

        # Extract the actual file names and labels as entries
        self.file_names = np.asarray([item[0] for item in dataset.imgs])
        self.labels = np.asarray([item[1] for item in dataset.imgs])

        # Set expected class attributes
        self.classes = np.unique(self.labels)

        if self.train:
            self.triplets = self.generate_triplets()

        if self.in_memory:
            # Load all samples
            with Pool(workers) as pool:
                self.data = pool.map(pil_loader, self.file_names)
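
    # Construction sketch (hypothetical path and values; the transform is a plain
    # torchvision transform, not something this module provides):
    #
    #   from torchvision import transforms
    #   train_ds = ImageFolderTriplet('/path/to/dataset/train', train=True,
    #                                 num_triplets=10000, in_memory=False, workers=4,
    #                                 transform=transforms.Compose([transforms.Resize((224, 224)),
    #                                                               transforms.ToTensor()]))
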
    def generate_triplets(self):
        """
        Generate triplets for training. Triplets have format [anchor, positive, negative].

        Returns
        -------
        triplets : list of [int, int, int]
            Index triples into self.file_names / self.labels
        """
        logging.info('Begin generating triplets')
        triplets = []
        for _ in trange(self.num_triplets, leave=False):
            # Select two different classes, c1 and c2. The upper bound of randint is
            # exclusive, so sample over the number of classes to keep the last class
            # reachable (np.max(self.labels) alone would never select it).
            c1 = np.random.randint(0, len(self.classes))
            c2 = np.random.randint(0, len(self.classes))
            while c1 == c2:
                c2 = np.random.randint(0, len(self.classes))

            # Select two different objects of class c1, a and p
            c1_items = np.where(self.labels == c1)[0]
            a = random.choice(c1_items)
            p = random.choice(c1_items)
            while a == p:
                p = random.choice(c1_items)

            # Select an item from class c2, n
            c2_items = np.where(self.labels == c2)[0]
            n = random.choice(c2_items)

            # Add the triplet to the list as we now have a, p, n
            triplets.append([a, p, n])
        logging.info('Finished generating {} triplets'.format(self.num_triplets))
        return triplets
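
    # The generated triplets are index triples into self.file_names / self.labels,
    # so for any triplet the following invariant holds (illustrative check on a
    # hypothetical dataset `ds` constructed with train=True):
    #
    #   a, p, n = ds.triplets[0]
    #   assert ds.labels[a] == ds.labels[p]   # anchor and positive share a class
    #   assert ds.labels[a] != ds.labels[n]   # negative comes from another class
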
    def __getitem__(self, index):
        """
        Retrieve a sample by index.

        Parameters
        ----------
        index : int

        Returns
        -------
        For training:
        img_a : FloatTensor
            Anchor image
        img_p : FloatTensor
            Positive image (same class as anchor)
        img_n : FloatTensor
            Negative image (different class from anchor)

        For validation and test:
        img_a : FloatTensor
            The image
        label : int
            Label of the image
        """
        if not self.train:
            label = self.labels[index]
            if self.in_memory:
                img_a = self.data[index]
            else:
                img_a = pil_loader(self.file_names[index])
            if self.transform is not None:
                img_a = self.transform(img_a)
            return img_a, label

        a, p, n = self.triplets[index]
        if self.in_memory:
            img_a = self.data[a]
            img_p = self.data[p]
            img_n = self.data[n]
        else:
            img_a = pil_loader(self.file_names[a])
            img_p = pil_loader(self.file_names[p])
            img_n = pil_loader(self.file_names[n])
        if self.transform is not None:
            img_a = self.transform(img_a)
            img_p = self.transform(img_p)
            img_n = self.transform(img_n)
        return img_a, img_p, img_n

    def __len__(self):
        if self.train:
            return len(self.triplets)
        else:
            return len(self.file_names)
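

if __name__ == '__main__':
    # A minimal end-to-end sketch, not part of the module's API: the dataset path is
    # a placeholder, and the embedding network is a deliberately trivial stand-in for
    # whatever model would normally be trained. TripletMarginLoss is the standard
    # PyTorch triplet loss matching the (anchor, positive, negative) batches this
    # dataset yields.
    import torch
    import torch.nn as nn
    from torchvision import transforms

    train_ds, _, _ = load_dataset('/path/to/dataset', num_triplets=1000, workers=4)
    train_ds.transform = transforms.Compose([transforms.Resize((224, 224)),
                                             transforms.ToTensor()])
    loader = data.DataLoader(train_ds, batch_size=32, shuffle=True)

    # Toy embedding network: flattens each image and projects it to a 128-d vector
    embedding_net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 128))
    criterion = nn.TripletMarginLoss(margin=1.0)

    for img_a, img_p, img_n in loader:
        loss = criterion(embedding_net(img_a), embedding_net(img_p), embedding_net(img_n))
        print('triplet loss: {:.4f}'.format(loss.item()))
        break  # one batch is enough for the sketch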