Source code for template.runner.apply_model.setup

# Utils
import logging
import os

# Torch related stuff
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms

# DeepDIVA
from datasets.image_folder_dataset import ImageFolderApply
from template.runner.triplet.transforms import MultiCrop
from template.setup import _load_mean_std_from_file


[docs]def set_up_dataloader(model_expected_input_size, dataset_folder, batch_size, workers, inmem, multi_crop, classify, **kwargs): """ Set up the dataloaders for the specified datasets. Parameters ---------- :param model_expected_input_size: tuple Specify the height and width that the model expects. :param dataset_folder: string Path string that points to the three folder train/val/test. Example: ~/../../data/svhn :param batch_size: int Number of datapoints to process at once :param workers: int Number of workers to use for the dataloaders :param inmem: boolean Flag: if False, the dataset is loaded in an online fashion i.e. only file names are stored and images are loaded on demand. This is slower than storing everything in memory. :param multi_crop: int if None, the MultiCrop transform is not applied to the data. Otherwise, multi_crop contains an integer which specifies how many crops to make from each image. :param classify : boolean Specifies whether to generate a classification report for the data or not. :param kwargs: dict Any additional arguments. :return: dataloader, dataloader, dataloader, int Three dataloaders for train, val and test. Number of classes for the model. """ # Recover dataset name dataset = os.path.basename(os.path.normpath(dataset_folder)) logging.info('Loading {} from:{}'.format(dataset, dataset_folder)) # Load the dataset as images apply_ds = ImageFolderApply(path=dataset_folder, classify=classify) # Loads the analytics csv and extract mean and std try: mean, std = _load_mean_std_from_file(dataset_folder, inmem, workers, kwargs['runner_class']) except: logging.error('analytics.csv not found in folder. Please copy the one generated in the ' 'training folder to this folder.') logging.error('Currently normalizing with 0.5 for all channels for mean and std.') mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] # Set up dataset transforms logging.debug('Setting up dataset transforms') if multi_crop == None: apply_ds.transform = transforms.Compose([ transforms.Resize(model_expected_input_size), transforms.ToTensor(), transforms.Normalize(mean=mean, std=std) ]) else: apply_ds.transform = transforms.Compose([ MultiCrop(size=model_expected_input_size, n_crops=multi_crop), transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])), transforms.Lambda(lambda crops: torch.stack([transforms.Normalize(mean=mean, std=std)(crop) for crop in crops])), ]) apply_loader = torch.utils.data.DataLoader(apply_ds, shuffle=False, batch_size=batch_size, num_workers=workers, pin_memory=True) return apply_loader, len(apply_ds.classes)
# # def set_up_model(output_channels, model_name, pretrained, optimizer_name, no_cuda, resume, load_model, start_epoch, # train_loader, disable_databalancing, dataset_folder, inmem, workers, num_classes=None, **kwargs): # """ # Instantiate model, optimizer, criterion. Load a pretrained model or resume from a checkpoint. # # Parameters # ---------- # output_channels : int # Specify shape of final layer of network. Only used if num_classes is not specified. # # model_name : string # Name of the model # # pretrained : bool # Specify whether to load a pretrained model or not # # optimizer_name : string # Name of the optimizer # # lr: float # Value for learning rate # # no_cuda : bool # Specify whether to use the GPU or not # # resume : string # Path to a saved checkpoint # # load_model : string # Path to a saved model # # start_epoch : int # Epoch from which to resume training. If if not resuming a previous experiment the value is 0 # # num_classes: int # Number of classes for the model # # kwargs: dict # Any additional arguments. # # Returns # ------- # model, criterion, optimizer, best_value, start_epoch # """ # # # Initialize the model # logging.info('Setting up model {}'.format(model_name)) # # output_channels = output_channels if num_classes == None else num_classes # model = models.__dict__[model_name](output_channels=output_channels, pretrained=pretrained) # # # Get the optimizer created with the specified parameters in kwargs (such as lr, momentum, ... ) # optimizer = _get_optimizer(optimizer_name, model, **kwargs) # # # Get the criterion # if disable_databalancing: # criterion = nn.CrossEntropyLoss() # else: # try: # weights = _load_class_frequencies_weights_from_file(dataset_folder, inmem, workers) # criterion = nn.CrossEntropyLoss(weight=torch.from_numpy(weights).type(torch.FloatTensor)) # logging.info('Loading weights for data balancing') # except: # logging.warning('Unable to load information for data balancing. Using normal criterion') # criterion = nn.CrossEntropyLoss() # # # Transfer model to GPU (if desired) # if not no_cuda: # logging.info('Transfer model to GPU') # model = torch.nn.DataParallel(model).cuda() # criterion = criterion.cuda() # cudnn.benchmark = True # # # Load saved model # if load_model: # if os.path.isfile(load_model): # model_dict = torch.load(load_model) # logging.info('Loading a saved model') # try: # model.load_state_dict(model_dict['state_dict'], strict=False) # except Exception as exp: # logging.warning(exp) # else: # logging.error("No model dict found at '{}'".format(load_model)) # sys.exit(-1) # # # Resume from checkpoint # if resume: # if os.path.isfile(resume): # logging.info("Loading checkpoint '{}'".format(resume)) # checkpoint = torch.load(resume) # start_epoch = checkpoint['epoch'] # best_value = checkpoint['best_value'] # model.load_state_dict(checkpoint['state_dict']) # optimizer.load_state_dict(checkpoint['optimizer']) # # val_losses = [checkpoint['val_loss']] #not used? # logging.info("Loaded checkpoint '{}' (epoch {})".format(resume, checkpoint['epoch'])) # else: # logging.error("No checkpoint found at '{}'".format(resume)) # sys.exit(-1) # else: # best_value = 0.0 # # return model, criterion, optimizer, best_value, start_epoch