"""
This file is the template for the boilerplate of train/test of a triplet network.
This code has initially been adapted to our purposes from
http://www.iis.ee.ic.ac.uk/%7Evbalnt/shallow_descr/TFeat_paper.pdf
"""
# Utils
from __future__ import print_function
import logging
import sys
import numpy as np
import torch
import torch.nn as nn
# DeepDIVA
import models
from torch.nn import init
from template.runner.triplet.setup import setup_dataloaders
from template.setup import set_up_model
from util.misc import adjust_learning_rate, checkpoint
# Delegated
from template.runner.triplet import train, evaluate
#######################################################################################################################
class Triplet:
@staticmethod
def single_run(writer, current_log_folder, model_name, epochs, lr, decay_lr,
margin, anchor_swap, validation_interval, regenerate_every,
checkpoint_all_epochs, only_evaluate, **kwargs):
"""
This is the main routine where train(), validate() and test() are called.
Parameters
----------
writer : Tensorboard SummaryWriter
Responsible for writing logs in Tensorboard compatible format.
current_log_folder : string
Path to where logs/checkpoints are saved
model_name : string
Name of the model
epochs : int
Number of epochs to train
lr : float
Value for learning rate
decay_lr : boolean
Decay the learning rate flag
margin : float
The margin value for the triplet loss function
anchor_swap : boolean
Turns on anchor swapping in the triplet loss
validation_interval : int
Run evaluation on validation set every N epochs
regenerate_every : int
Re-generate triplets every N epochs
checkpoint_all_epochs : bool
If enabled, save checkpoint after every epoch.
only_evaluate : boolean
Flag: if True, skip training and load only the test set.
Returns
-------
train_value, val_value, test_value
Mean Average Precision values for the train, validation and test splits.
"""
# Sanity check on parameters
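# (--output-channels presumably controls the dimensionality of the descriptor/embedding
# produced by the network; it must be set explicitly for the triplet runner)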
if kwargs["output_channels"] is None:
logging.error("Using triplet class but --output-channels is not specified.")
sys.exit(-1)
# Get the selected model input size
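# (models.__dict__ maps the model name string to its class; the model is instantiated
# here only to read its expected_input_size attribute)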
model_expected_input_size = models.__dict__[model_name]().expected_input_size
Triplet._validate_model_input_size(model_expected_input_size, model_name)
logging.info('Model {} expects input size of {}'.format(model_name, model_expected_input_size))
# Setting up the dataloaders
if only_evaluate:
_, _, test_loader = setup_dataloaders(
model_expected_input_size=model_expected_input_size,
only_evaluate=only_evaluate,
**kwargs)
else:
train_loader, val_loader, test_loader = setup_dataloaders(
model_expected_input_size=model_expected_input_size,
**kwargs)
# Setting up model, optimizer, criterion
model, _, optimizer, best_value = set_up_model(model_name=model_name,
lr=lr,
**kwargs)
# Set the special criterion for triplets
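# nn.TripletMarginLoss computes, for an anchor a, positive p and negative n:
#     loss(a, p, n) = max(d(a, p) - d(a, n) + margin, 0)
# with d the pairwise (default L2) distance. With swap=True the "anchor swap" from the
# TFeat paper is applied: d(a, n) is replaced by min(d(a, n), d(p, n)).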
criterion = nn.TripletMarginLoss(margin=margin, swap=anchor_swap)
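# Per-epoch mAP buffers; val_value entries stay 0.0 for epochs where validation is skipped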
train_value = np.zeros(epochs)
val_value = np.zeros(epochs)
if not only_evaluate:
# Core routine
logging.info('Begin training')
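# Evaluate the freshly initialized model once (epoch -1) to log a pre-training baseline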
Triplet._validate(val_loader, model, writer, -1, **kwargs)
for epoch in range(epochs):
# Train
train_value[epoch] = Triplet._train(train_loader=train_loader,
model=model,
criterion=criterion,
optimizer=optimizer,
writer=writer,
epoch=epoch,
**kwargs)
# Validate
if epoch % validation_interval == 0:
val_value[epoch] = Triplet._validate(val_loader=val_loader,
model=model,
writer=writer,
epoch=epoch,
**kwargs)
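# Optionally anneal the learning rate according to the current epoch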
if decay_lr is not None:
adjust_learning_rate(lr, optimizer, epoch, epochs)
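# Track the best validation value seen so far and checkpoint the model accordingly;
# checkpoint_all_epochs additionally saves a snapshot after every epoch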
best_value = checkpoint(epoch=epoch,
new_value=val_value[epoch],
best_value=best_value,
model=model,
optimizer=optimizer,
log_dir=current_log_folder,
invert_best=False,
checkpoint_all_epochs=checkpoint_all_epochs)
# Generate new triplets every N epochs
if epoch % regenerate_every == 0:
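# (resampling from the dataset exposes the network to fresh anchor/positive/negative combinations)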
train_loader.dataset.triplets = train_loader.dataset.generate_triplets()
logging.info('Training completed')
# Test
test_value = Triplet._test(test_loader=test_loader,
model=model,
writer=writer,
epoch=(epochs - 1),
**kwargs)
return train_value, val_value, test_value
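# Minimal usage sketch (hypothetical values, not the project's actual CLI defaults):
# the writer is assumed to be a tensorboardX SummaryWriter, and additional keyword
# arguments (e.g. output_channels, dataset paths, batch size, ...) are forwarded to
# setup_dataloaders() / set_up_model() and would normally be required as well.
#
#   from tensorboardX import SummaryWriter
#   writer = SummaryWriter(log_dir='./log')
#   train_mAP, val_mAP, test_mAP = Triplet.single_run(
#       writer=writer, current_log_folder='./log', model_name='<a_model_in_models>',
#       epochs=10, lr=0.1, decay_lr=None, margin=2.0, anchor_swap=True,
#       validation_interval=1, regenerate_every=5, checkpoint_all_epochs=False,
#       only_evaluate=False, output_channels=128)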
####################################################################################################################
@staticmethod
def _validate_model_input_size(model_expected_input_size, model_name):
"""
This method verifies that the model expected input size is a tuple of 2 elements.
This is necessary to avoid confusion with models which run on other types of data.
Parameters
----------
model_expected_input_size
The item retrieved from the model which corresponds to the expected input size
model_name : string
Name of the model (for logging purposes only)
Returns
-------
None
"""
if type(model_expected_input_size) is not tuple or len(model_expected_input_size) != 2:
logging.error('Model {model_name} expected input size is not a tuple of 2 elements. '
'Received: {model_expected_input_size}'
.format(model_name=model_name,
model_expected_input_size=model_expected_input_size))
sys.exit(-1)
####################################################################################################################
# These methods delegate their function to other classes in this package.
# It is useful because sub-classes can selectively change the logic of certain parts only.
@classmethod
def _train(cls, train_loader, model, criterion, optimizer, writer, epoch, **kwargs):
return train.train(train_loader, model, criterion, optimizer, writer, epoch, **kwargs)
@classmethod
def _validate(cls, val_loader, model, writer, epoch, **kwargs):
return evaluate.validate(val_loader, model, writer, epoch, **kwargs)
@classmethod
def _test(cls, test_loader, model, writer, epoch, **kwargs):
return evaluate.test(test_loader, model, writer, epoch, **kwargs)