Source code for template.runner.triplet.transforms

import numbers

import numpy as np
import torch

class MultiCrop:
    """
    Transform that crops a PIL Image into several random crops of a fixed size.

    Parameters
    ----------
    size : tuple or int
        Desired output size of each crop as (h, w). If ``size`` is an ``int``
        instead of a sequence like (h, w), a square crop of size
        (size, size) is made.
    n_crops : int
        The number of crops to be generated from a page.

    Calling an instance with an image returns whatever ``multi_crop``
    returns for that image, the stored size and the stored number of crops.

    Example
    -------
    >>> transform = transforms.Compose([
    ...     MultiCrop(size=model_expected_input_size, n_crops=n_crops),  # list of PIL Images
    ...     transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),  # 4D tensor
    ...     transforms.Lambda(lambda items: torch.stack([transforms.Normalize(mean=mean, std=std)(item) for item in items])),
    ... ])
    >>> # In your test loop you can do the following:
    >>> input, target = batch  # input is a 5d tensor, target is 2d
    >>> bs, ncrops, c, h, w = input.size()
    >>> result = model(input.view(-1, c, h, w))  # fuse batch size and ncrops
    >>> result_avg = result.view(bs, ncrops, -1).mean(1)  # avg over crops
    """

    def __init__(self, size, n_crops):
        # TODO: DOES NOT PLAY WELL WITH SEEDS. Figure out why!
        # NOTE(review): multi_crop draws offsets from NumPy's global RNG;
        # presumably the seeding issue comes from how that global state is
        # (re)seeded where this transform runs (e.g. loader workers) - confirm.
        if not isinstance(size, numbers.Number):
            # A sequence is stored as given, but it must describe exactly (h, w).
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size
        else:
            side = int(size)
            self.size = (side, side)
        self.n_crops = n_crops

    def __call__(self, img):
        return multi_crop(img, self.size, self.n_crops)
def multi_crop(img, size, n_crops):
    """
    Crop the given PIL Image into multiple random crops.

    Each crop's top-left corner is drawn uniformly from every position where
    the crop still fits inside the image, using NumPy's global random state
    (``np.random``). Crops are drawn independently of each other.

    Parameters
    ----------
    img : PIL.Image
        The Image to be processed. Only ``img.size`` (read as ``(w, h)``) and
        ``img.crop(box)`` are used.
    size : tuple or int
        Desired output size of the crop as (h, w). If size is an ``int``
        instead of sequence like (h, w), a square crop of size (size, size)
        is made.
    n_crops : int
        The number of crops to be generated from a page.

    Returns
    -------
    crops : list of PIL.Images
        A list of ``n_crops`` PIL.Images which are the crops from the page.

    Raises
    ------
    ValueError
        If the requested crop is larger than the image in either dimension.
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    w, h = img.size
    crop_h, crop_w = size
    if crop_w > w or crop_h > h:
        raise ValueError("Requested crop size {} is bigger than input size {}".format(size, (h, w)))

    # np.random.randint's upper bound is exclusive. The "+ 1" makes the last
    # valid offset (crop flush with the right/bottom edge) reachable. It also
    # keeps a crop exactly as large as the image from failing with "low >= high".
    x_stop = w - crop_w + 1
    y_stop = h - crop_h + 1

    crops = []
    for _ in range(n_crops):
        x = np.random.randint(0, x_stop)
        y = np.random.randint(0, y_stop)
        # Box is (left, upper, right, lower); right/lower are exclusive, so
        # x + crop_w == w is still inside the image.
        crops.append(img.crop((x, y, x + crop_w, y + crop_h)))
    return crops