Source code for template.runner.image_classification.evaluate

# Utils
import logging
import time
import warnings

import numpy as np
# Torch related stuff
import torch
from sklearn.metrics import confusion_matrix, classification_report
from tqdm import tqdm

from util.evaluation.metrics import accuracy
# DeepDIVA
from util.misc import AverageMeter, _prettyprint_logging_label, save_image_and_log_to_tensorboard
from util.visualization.confusion_matrix_heatmap import make_heatmap


def evaluate(data_loader, model, criterion, writer, epoch, logging_label, no_cuda=False, log_interval=20, **kwargs):
    """
    The evaluation routine

    Parameters
    ----------
    data_loader : torch.utils.data.DataLoader
        The dataloader of the evaluation set
    model : torch.nn.Module
        The network model being used
    criterion : torch.nn.loss
        The loss function used to compute the loss of the model
    writer : tensorboardX.writer.SummaryWriter
        The tensorboard writer object. Used to log values on file for the tensorboard visualization.
    epoch : int
        Number of the epoch (for logging purposes)
    logging_label : string
        Label for logging purposes. Typically 'test' or 'valid'.
        It is prepended to the logging output path and messages.
    no_cuda : boolean
        Specifies whether the GPU should be used or not. A value of 'True' means the CPU will be used.
    log_interval : int
        Interval limiting the logging of mini-batches. Default value of 20.

    Returns
    -------
    top1.avg : float
        Accuracy of the model on the evaluated split
    """
    # 'run' is injected in kwargs at runtime IFF it is a multi-run event
    multi_run = kwargs['run'] if 'run' in kwargs else None

    # Instantiate the counters
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    data_time = AverageMeter()

    # Switch to evaluate mode (turn off dropout & such)
    model.eval()

    # Iterate over whole evaluation set
    end = time.time()

    # Empty lists to store the predictions and target values
    preds = []
    targets = []

    pbar = tqdm(enumerate(data_loader), total=len(data_loader), unit='batch', ncols=150, leave=False)
    with torch.no_grad():
        for batch_idx, (input, target) in pbar:
            # Measure data loading time
            data_time.update(time.time() - end)

            # Moving data to GPU
            if not no_cuda:
                input = input.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

            # Compute output
            output = model(input)

            # Compute and record the loss
            loss = criterion(output, target)
            losses.update(loss.item(), input.size(0))

            # Compute and record the accuracy
            acc1 = accuracy(output.data, target, topk=(1,))[0]
            top1.update(acc1[0], input.size(0))

            # Store the argmax predictions and the targets for the epoch-wise metrics
            preds.extend(np.argmax(item) for item in output.data.cpu().numpy())
            targets.extend(target.cpu().numpy())

            # Add loss and accuracy to Tensorboard
            if multi_run is None:
                writer.add_scalar(logging_label + '/mb_loss', loss.item(),
                                  epoch * len(data_loader) + batch_idx)
                writer.add_scalar(logging_label + '/mb_accuracy', acc1.cpu().numpy(),
                                  epoch * len(data_loader) + batch_idx)
            else:
                writer.add_scalar(logging_label + '/mb_loss_{}'.format(multi_run), loss.item(),
                                  epoch * len(data_loader) + batch_idx)
                writer.add_scalar(logging_label + '/mb_accuracy_{}'.format(multi_run), acc1.cpu().numpy(),
                                  epoch * len(data_loader) + batch_idx)

            # Measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % log_interval == 0:
                pbar.set_description(logging_label +
                                     ' epoch [{0}][{1}/{2}]\t'.format(epoch, batch_idx, len(data_loader)))
                pbar.set_postfix(Time='{batch_time.avg:.3f}\t'.format(batch_time=batch_time),
                                 Loss='{loss.avg:.4f}\t'.format(loss=losses),
                                 Acc1='{top1.avg:.3f}\t'.format(top1=top1),
                                 Data='{data_time.avg:.3f}\t'.format(data_time=data_time))

    # Make a confusion matrix
    try:
        cm = confusion_matrix(y_true=targets, y_pred=preds)
        confusion_matrix_heatmap = make_heatmap(cm, data_loader.dataset.classes)
    except ValueError:
        logging.warning('Confusion Matrix did not work as expected')
        confusion_matrix_heatmap = np.zeros((10, 10, 3))

    # Logging the epoch-wise accuracy and confusion matrix
    if multi_run is None:
        writer.add_scalar(logging_label + '/accuracy', top1.avg, epoch)
        save_image_and_log_to_tensorboard(writer, tag=logging_label + '/confusion_matrix',
                                          image=confusion_matrix_heatmap, global_step=epoch)
    else:
        writer.add_scalar(logging_label + '/accuracy_{}'.format(multi_run), top1.avg, epoch)
        save_image_and_log_to_tensorboard(writer, tag=logging_label + '/confusion_matrix_{}'.format(multi_run),
                                          image=confusion_matrix_heatmap, global_step=epoch)

    logging.info(_prettyprint_logging_label(logging_label) +
                 ' epoch[{}]: '
                 'Acc@1={top1.avg:.3f}\t'
                 'Loss={loss.avg:.4f}\t'
                 'Batch time={batch_time.avg:.3f} ({data_time.avg:.3f} to load data)'
                 .format(epoch, batch_time=batch_time, data_time=data_time, loss=losses, top1=top1))

    # Generate a classification report for each epoch
    _log_classification_report(data_loader, epoch, preds, targets, writer)

    return top1.avg.item()
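
A minimal usage sketch follows (not part of the original module). The toy model, tensors and log directory are illustrative placeholders, and dataset.classes is attached by hand because evaluate() reads that attribute when it builds the confusion matrix and classification report.

import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader, TensorDataset

# Toy evaluation set: 64 flattened "images", 10 placeholder classes
inputs = torch.randn(64, 784)
labels = torch.randint(0, 10, (64,))
dataset = TensorDataset(inputs, labels)
dataset.classes = [str(i) for i in range(10)]       # evaluate() expects this attribute

model = nn.Linear(784, 10)                          # stand-in for a real network
criterion = nn.CrossEntropyLoss()
writer = SummaryWriter(log_dir='/tmp/example_run')  # hypothetical log directory

val_acc = evaluate(DataLoader(dataset, batch_size=16), model, criterion,
                   writer, epoch=0, logging_label='valid', no_cuda=True)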

def _log_classification_report(data_loader, epoch, preds, targets, writer):
    """
    This routine computes and prints on Tensorboard TEXT a classification report with F1 score,
    Precision, Recall and similar metrics computed per-class.

    Parameters
    ----------
    data_loader : torch.utils.data.DataLoader
        The dataloader of the evaluation set
    epoch : int
        Number of the epoch (for logging purposes)
    preds : list
        List of all predictions of the model for this epoch
    targets : list
        List of all correct labels for this epoch
    writer : tensorboardX.writer.SummaryWriter
        The tensorboard writer object. Used to log values on file for the tensorboard visualization.

    Returns
    -------
    None
    """
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        classification_report_string = str(classification_report(y_true=targets,
                                                                 y_pred=preds,
                                                                 target_names=[str(item) for item in
                                                                               data_loader.dataset.classes]))

    # Fix for the TB writer. It is an ugly workaround to have it printed nicely in the TEXT section of TB.
    classification_report_string = classification_report_string.replace('\n ', '\n\n ')
    classification_report_string = classification_report_string.replace('precision', ' precision', 1)
    classification_report_string = classification_report_string.replace('avg', ' avg', 1)

    writer.add_text('Classification Report for epoch {}\n'.format(epoch),
                    '\n' + classification_report_string, epoch)
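
For reference, a standalone sketch of the same TEXT-section workaround (the labels are made up; the exact header spacing sklearn emits varies across versions, so the replacements simply mirror the ones applied above):

from sklearn.metrics import classification_report

# classification_report returns one plain string; without extra blank lines,
# Tensorboard's markdown rendering collapses its rows into a single paragraph.
report = str(classification_report(y_true=[0, 1, 1, 2], y_pred=[0, 1, 2, 2],
                                   target_names=['cat', 'dog', 'bird']))
report = report.replace('\n ', '\n\n ')                # blank line between rows
report = report.replace('precision', ' precision', 1)  # nudge the header over
print(report)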