Source code for template.runner.image_classification.image_classification

"""
This file is the template for the boilerplate of train/test of a DNN for image classification

There are a lot of parameter which can be specified to modify the behaviour and they should be used 
instead of hard-coding stuff.
"""

import logging
import sys
import os

# Utils
import numpy as np

# DeepDIVA
import models
# Delegated
from template.runner.image_classification.train import train
from template.runner.image_classification.evaluate import evaluate
from template.setup import set_up_model, set_up_dataloaders
from util.misc import checkpoint, adjust_learning_rate


class ImageClassification:
    @classmethod
    def single_run(cls, **kwargs):
        """
        This is the main routine where train(), validate() and test() are called.

        Returns
        -------
        train_value : ndarray[floats] of size (1, `epochs`)
            Accuracy values for train split
        val_value : ndarray[floats] of size (1, `epochs`+1)
            Accuracy values for validation split
        test_value : float
            Accuracy value for test split
        """
        # Prepare the data, optimizer and criterion
        model, num_classes, best_value, train_loader, val_loader, test_loader, optimizer, criterion = cls.prepare(**kwargs)

        # Train routine
        train_value, val_value = cls.train_routine(model=model, best_value=best_value,
                                                   optimizer=optimizer, criterion=criterion,
                                                   train_loader=train_loader, val_loader=val_loader,
                                                   **kwargs)

        # Test routine
        test_value = cls.test_routine(criterion=criterion, num_classes=num_classes,
                                      test_loader=test_loader, **kwargs)

        return train_value, val_value, test_value
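    # Usage sketch (illustrative): in DeepDIVA this routine is normally driven
    # by the framework's entry point, which forwards the parsed CLI arguments
    # as kwargs. Called directly it would look roughly like this; everything
    # beyond model_name/epochs is an assumption about what the downstream
    # setup functions expect, not a documented interface:
    #
    #     train_acc, val_acc, test_acc = ImageClassification.single_run(
    #         model_name='CNN_basic', epochs=20, ...)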
    ################################################################################################################
    @classmethod
    def prepare(cls, model_name, **kwargs):
        """
        Loads and prepares the data, the optimizer and the criterion

        Parameters
        ----------
        model_name : str
            Name of the model. Used for loading the model.
        kwargs : dict
            Any additional arguments.

        Returns
        -------
        model : DataParallel
            The model to train
        num_classes : int
            How many different classes there are in our problem. Used for loading the model.
        best_value : float
            Best value of the model so far. Non-zero only in case of --resume being used
        train_loader : torch.utils.data.dataloader.DataLoader
            Training dataloader
        val_loader : torch.utils.data.dataloader.DataLoader
            Validation dataloader
        test_loader : torch.utils.data.dataloader.DataLoader
            Test set dataloader
        optimizer : torch.optim
            Optimizer to use during training, e.g. SGD
        criterion : torch.nn.modules.loss
            Loss function to use, e.g. cross-entropy
        """
        # Get the selected model's expected input size
        model_expected_input_size = models.__dict__[model_name]().expected_input_size
        if type(model_expected_input_size) is not tuple or len(model_expected_input_size) != 2:
            logging.error('Model {model_name} expected input size is not a tuple. '
                          'Received: {model_expected_input_size}'
                          .format(model_name=model_name,
                                  model_expected_input_size=model_expected_input_size))
            sys.exit(-1)
        logging.info('Model {} expects input size of {}'.format(model_name, model_expected_input_size))

        # Setting up the dataloaders
        train_loader, val_loader, test_loader, num_classes = set_up_dataloaders(model_expected_input_size, **kwargs)

        # Setting up model, optimizer, criterion
        model, criterion, optimizer, best_value = set_up_model(model_name=model_name,
                                                               num_classes=num_classes,
                                                               **kwargs)
        return model, num_classes, best_value, train_loader, val_loader, test_loader, optimizer, criterion
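    # Note on the input-size check in prepare(): every entry in ``models`` is
    # expected to expose ``expected_input_size`` as a 2-tuple (height, width).
    # An illustrative sketch of a conforming model (``TinyCNN`` is a
    # hypothetical example, not part of DeepDIVA):
    #
    #     import torch.nn as nn
    #
    #     class TinyCNN(nn.Module):
    #         expected_input_size = (32, 32)  # (height, width) in pixels
    #
    #         def __init__(self):
    #             super().__init__()
    #             self.conv = nn.Conv2d(3, 8, kernel_size=3)
    #
    # Anything else (an int, a 3-tuple, ...) makes prepare() abort the run.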
    @classmethod
    def train_routine(cls, best_value, decay_lr, validation_interval, start_epoch, epochs,
                      checkpoint_all_epochs, current_log_folder, **kwargs):
        """
        Performs the training and validation routines

        Parameters
        ----------
        best_value : float
            Best value of the model so far. Non-zero only in case of --resume being used
        decay_lr : int or None
            Decay the learning rate every `decay_lr` epochs (None disables the decay)
        validation_interval : int
            Run evaluation on validation set every N epochs
        start_epoch : int
            Int to initialize the starting epoch. Non-zero only in case of --resume being used
        epochs : int
            Number of epochs to train
        checkpoint_all_epochs : bool
            Save a checkpoint at each epoch
        current_log_folder : string
            Path to where logs/checkpoints are saved
        kwargs : dict
            Any additional arguments.

        Returns
        -------
        train_value : ndarray[floats] of size (1, `epochs`)
            Accuracy values for train split
        val_value : ndarray[floats] of size (1, `epochs`+1)
            Accuracy values for validation split
        """
        logging.info('Begin training')
        val_value = np.zeros((epochs + 1 - start_epoch))
        train_value = np.zeros((epochs - start_epoch))

        # Validate before training (result stored in the spare last slot of val_value)
        val_value[-1] = cls._validate(epoch=-1, **kwargs)
        for epoch in range(start_epoch, epochs):
            # Train (indices are relative to start_epoch so resumed runs stay in bounds)
            train_value[epoch - start_epoch] = cls._train(epoch=epoch, **kwargs)
            # Validate
            if epoch % validation_interval == 0:
                val_value[epoch - start_epoch] = cls._validate(epoch=epoch, **kwargs)
            if decay_lr is not None:
                adjust_learning_rate(epoch=epoch, decay_lr_epochs=decay_lr, **kwargs)
            # Checkpoint
            best_value = checkpoint(epoch=epoch, new_value=val_value[epoch - start_epoch],
                                    best_value=best_value, log_dir=current_log_folder,
                                    checkpoint_all_epochs=checkpoint_all_epochs, **kwargs)
        logging.info('Training done')
        return train_value, val_value
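    # The decay schedule itself lives in util.misc.adjust_learning_rate, which
    # receives ``decay_lr`` as ``decay_lr_epochs``. A typical step decay of
    # this kind (an illustrative sketch, not necessarily the exact rule used
    # there) would be:
    #
    #     lr = initial_lr * (0.1 ** (epoch // decay_lr_epochs))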
    @classmethod
    def test_routine(cls, model_name, num_classes, criterion, epochs, current_log_folder,
                     writer, **kwargs):
        """
        Load the best model according to the validation score (early stopping) and run the
        test routine.

        Parameters
        ----------
        model_name : str
            Name of the model. Used for loading the model.
        num_classes : int
            How many different classes there are in our problem. Used for loading the model.
        criterion : torch.nn.modules.loss
            Loss function to use, e.g. cross-entropy
        epochs : int
            After how many epochs are we testing
        current_log_folder : string
            Path to where logs/checkpoints are saved
        writer : tensorboardX.SummaryWriter
            Responsible for writing logs in Tensorboard compatible format.
        kwargs : dict
            Any additional arguments.

        Returns
        -------
        test_value : float
            Accuracy value for test split
        """
        # Load the best model before evaluating on the test set (early stopping)
        logging.info('Loading the best model before evaluating on the test set.')
        if os.path.exists(os.path.join(current_log_folder, 'model_best.pth.tar')):
            kwargs["load_model"] = os.path.join(current_log_folder, 'model_best.pth.tar')
        else:
            logging.warning('File model_best.pth.tar not found in {}'.format(current_log_folder))
            logging.warning('Using checkpoint.pth.tar instead')
            kwargs["load_model"] = os.path.join(current_log_folder, 'checkpoint.pth.tar')
        model, _, _, _ = set_up_model(num_classes=num_classes, model_name=model_name, **kwargs)

        # Test
        test_value = cls._test(model=model, criterion=criterion, writer=writer,
                               epoch=epochs - 1, **kwargs)
        logging.info('Testing completed')
        return test_value
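    # Checkpoint layout assumed by test_routine() above (the standard PyTorch
    # pattern; the exact files are written by util.misc.checkpoint):
    #
    #     <current_log_folder>/checkpoint.pth.tar   # most recent epoch
    #     <current_log_folder>/model_best.pth.tar   # best validation score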
    ################################################################################################################
    """
    These methods delegate their function to other classes in this package.
    It is useful because sub-classes can selectively change the logic of certain parts only.
    """

    @classmethod
    def _train(cls, **kwargs):
        return train(**kwargs)

    @classmethod
    def _validate(cls, **kwargs):
        return evaluate(data_loader=kwargs['val_loader'], logging_label='val', **kwargs)

    @classmethod
    def _test(cls, **kwargs):
        return evaluate(data_loader=kwargs['test_loader'], logging_label='test', **kwargs)
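
# Illustrative sketch (not part of the original module): because the private
# steps above only delegate, a sub-class can swap out a single stage of the
# pipeline while reusing the rest of the boilerplate unchanged.
# ``VerboseImageClassification`` is a hypothetical example class.
class VerboseImageClassification(ImageClassification):
    """Identical pipeline, but logs every validation call before running it."""

    @classmethod
    def _validate(cls, **kwargs):
        # ``epoch`` arrives as a keyword argument from train_routine()
        logging.info('Running validation at epoch {}'.format(kwargs.get('epoch')))
        return super()._validate(**kwargs)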