Source code for vlkit.transforms.compose

import torch
import numpy as np
from PIL import Image


class RandomChoice(torch.nn.Module):
    """Randomly select transforms from a list and apply them in sequence.

    On each call, with probability ``p``, ``n`` transforms are drawn (with
    replacement) from ``transforms`` and applied to the input one after
    another. Otherwise the input is returned unchanged.

    Args:
        transforms (list | None): candidate transforms (callables). ``None``
            or an empty list makes the module a no-op.
        p (float): probability that any transforms are applied in a call.
        n (int): number of transforms drawn in each call.
        p_choice (list | tuple | numpy.ndarray | None): non-negative weights,
            one per transform, giving the relative chance of each transform
            being drawn. They are normalised to sum to 1. ``None`` means
            uniform selection.
    """

    def __init__(self, transforms, p=0.5, n=1, p_choice=None):
        super().__init__()
        assert isinstance(transforms, (list, type(None)))
        self.transforms = transforms
        self.p = p
        self.n = n
        self.p_choice = p_choice
        if isinstance(self.p_choice, (list, tuple, np.ndarray)):
            num_transforms = 0 if self.transforms is None else len(self.transforms)
            # Force a float copy: integer weights such as [1, 3] would
            # otherwise yield an int array that cannot hold the normalised
            # values, and the caller's own array must not be mutated.
            weights = np.array(self.p_choice, dtype=float)
            assert len(weights) == num_transforms
            assert weights.min() >= 0
            # All-zero weights would normalise to NaN.
            assert weights.sum() > 0
            self.p_choice = weights / weights.sum()

    def forward(self, x: "Image.Image") -> "Image.Image":
        """Apply ``n`` randomly drawn transforms to ``x`` with probability ``p``.

        Args:
            x: the input passed through the selected transforms.

        Returns:
            The transformed input, or ``x`` itself when nothing is applied.
        """
        if not self.transforms:
            # Nothing to choose from (None or empty list).
            return x
        if np.random.uniform() < self.p:
            # Sample indices rather than the transforms themselves, so that
            # numpy never tries to coerce sequence-like transform objects
            # into a multi-dimensional array.
            picks = np.random.choice(
                len(self.transforms), size=self.n, p=self.p_choice)
            for idx in picks:
                x = self.transforms[idx](x)
        return x

    def __repr__(self):
        return self.__class__.__name__ + '(transforms={0}, p={1}, n={2}, p_choice={3})'.format(
            self.transforms, self.p, self.n, self.p_choice)