# Utils
import logging
import time
# Torch related stuff
import torch
from tqdm import tqdm
# DeepDIVA
from util.misc import AverageMeter
from util.evaluation.metrics import accuracy
[docs]def train(train_loader, model, criterion, optimizer, writer, epoch, no_cuda=False, log_interval=25,
**kwargs):
"""
Training routine
Parameters
----------
train_loader : torch.utils.data.DataLoader
The dataloader of the train set.
model : torch.nn.module
The network model being used.
criterion : torch.nn.loss
The loss function used to compute the loss of the model.
optimizer : torch.optim
The optimizer used to perform the weight update.
writer : tensorboardX.writer.SummaryWriter
The tensorboard writer object. Used to log values on file for the tensorboard visualization.
epoch : int
Number of the epoch (for logging purposes).
no_cuda : boolean
Specifies whether the GPU should be used or not. A value of 'True' means the CPU will be used.
log_interval : int
Interval limiting the logging of mini-batches. Default value of 10.
Returns
----------
top1.avg : float
Accuracy of the model of the evaluated split
"""
multi_run = kwargs['run'] if 'run' in kwargs else None
# Instantiate the counters
batch_time = AverageMeter()
loss_meter = AverageMeter()
acc_meter = AverageMeter()
data_time = AverageMeter()
# Switch to train mode (turn on dropout & stuff)
model.train()
# Iterate over whole training set
end = time.time()
pbar = tqdm(enumerate(train_loader), total=len(train_loader), unit='batch', ncols=150, leave=False)
for batch_idx, (input, target) in pbar:
# Measure data loading time
data_time.update(time.time() - end)
# Moving data to GPU
if not no_cuda:
input = input.cuda(non_blocking=True)
target = target.cuda(non_blocking=True)
# Convert the input and its labels to Torch Variables
input_var = torch.autograd.Variable(input)
target_var = torch.autograd.Variable(target)
acc, loss = train_one_mini_batch(model, criterion, optimizer, input_var, target_var, loss_meter, acc_meter)
# Add loss and accuracy to Tensorboard
if multi_run is None:
writer.add_scalar('train/mb_loss', loss.item(), epoch * len(train_loader) + batch_idx)
writer.add_scalar('train/mb_accuracy', acc.cpu().numpy(), epoch * len(train_loader) + batch_idx)
else:
writer.add_scalar('train/mb_loss_{}'.format(multi_run), loss.item(),
epoch * len(train_loader) + batch_idx)
writer.add_scalar('train/mb_accuracy_{}'.format(multi_run), acc.cpu().numpy(),
epoch * len(train_loader) + batch_idx)
# Measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
# Log to console
if batch_idx % log_interval == 0:
pbar.set_description('train epoch [{0}][{1}/{2}]\t'.format(epoch, batch_idx, len(train_loader)))
pbar.set_postfix(Time='{batch_time.avg:.3f}\t'.format(batch_time=batch_time),
Loss='{loss.avg:.4f}\t'.format(loss=loss_meter),
Acc1='{acc_meter.avg:.3f}\t'.format(acc_meter=acc_meter),
Data='{data_time.avg:.3f}\t'.format(data_time=data_time))
# Logging the epoch-wise accuracy
if multi_run is None:
writer.add_scalar('train/accuracy', acc_meter.avg, epoch)
else:
writer.add_scalar('train/accuracy_{}'.format(multi_run), acc_meter.avg, epoch)
logging.debug('Train epoch[{}]: '
'Acc@1={acc_meter.avg:.3f}\t'
'Loss={loss.avg:.4f}\t'
'Batch time={batch_time.avg:.3f} ({data_time.avg:.3f} to load data)'
.format(epoch, batch_time=batch_time, data_time=data_time, loss=loss_meter, acc_meter=acc_meter))
return acc_meter.avg
[docs]def train_one_mini_batch(model, criterion, optimizer, input_var, target_var, loss_meter, acc_meter):
"""
This routing train the model passed as parameter for one mini-batch
Parameters
----------
model : torch.nn.module
The network model being used.
criterion : torch.nn.loss
The loss function used to compute the loss of the model.
optimizer : torch.optim
The optimizer used to perform the weight update.
input_var : torch.autograd.Variable
The input data for the mini-batch
target_var : torch.autograd.Variable
The target data (labels) for the mini-batch
loss_meter : AverageMeter
Tracker for the overall loss
acc_meter : AverageMeter
Tracker for the overall accuracy
Returns
-------
acc : float
Accuracy for this mini-batch
loss : float
Loss for this mini-batch
"""
# Compute output
output = model(input_var)
# Compute and record the loss
loss = criterion(output, target_var)
loss_meter.update(loss.item(), len(input_var))
# Compute and record the accuracy
acc = accuracy(output.data, target_var.data, topk=(1,))[0]
acc_meter.update(acc[0], len(input_var))
# Reset gradient
optimizer.zero_grad()
# Compute gradients
loss.backward()
# Perform a step by updating the weights
optimizer.step()
return acc, loss