@@ -0,0 +1,175 @@
+import random
+import math
+import numbers
+import collections.abc
+import numpy as np
+import torch
+from PIL import Image, ImageOps
+try:
+    import accimage
+except ImportError:
+    accimage = None
+
+
+class Compose(object):
+    """Composes several transforms together.
+
+    Args:
+        transforms (list of ``Transform`` objects): list of transforms to compose.
+
+    Example:
+        >>> transforms.Compose([
+        >>>     transforms.CenterCrop(10),
+        >>>     transforms.ToTensor(),
+        >>> ])
+    """
+
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, img):
+        for t in self.transforms:
+            img = t(img)
+        return img
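+
+# A minimal usage sketch (illustrative, not part of the module): ``Compose``
+# simply chains callables, so any function taking and returning an image fits
+# the pipeline. The image path below is hypothetical.
+#
+#     >>> from PIL import Image
+#     >>> transform = Compose([Scale(256), CenterCrop(224), ToTensor()])
+#     >>> tensor = transform(Image.open('example.jpg'))  # (3, 224, 224) FloatTensor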
+
+
+class ToTensor(object):
+    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.
+
+    Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
+    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
+    """
+
+    def __call__(self, pic):
+        """
+        Args:
+            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.
+
+        Returns:
+            Tensor: Converted image.
+        """
+        if isinstance(pic, np.ndarray):
+            # handle numpy array
+            img = torch.from_numpy(pic.transpose((2, 0, 1)))
+            # backward compatibility: scale from [0, 255] into [0.0, 1.0]
+            return img.float().div(255)
+
+        if accimage is not None and isinstance(pic, accimage.Image):
+            nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
+            pic.copyto(nppic)
+            return torch.from_numpy(nppic)
+
+        # handle PIL Image
+        if pic.mode == 'I':
+            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
+        elif pic.mode == 'I;16':
+            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
+        else:
+            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
+        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
+        if pic.mode == 'YCbCr':
+            nchannel = 3
+        elif pic.mode == 'I;16':
+            nchannel = 1
+        else:
+            nchannel = len(pic.mode)
+        img = img.view(pic.size[1], pic.size[0], nchannel)
+        # put it from HWC to CHW format
+        # yikes, this transpose takes 80% of the loading time/CPU
+        img = img.transpose(0, 1).transpose(0, 2).contiguous()
+        if isinstance(img, torch.ByteTensor):
+            return img.float().div(255)
+        else:
+            return img
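+
+# A minimal usage sketch (illustrative, not part of the module; the random
+# uint8 array stands in for any H x W x C image):
+#
+#     >>> import numpy as np
+#     >>> arr = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
+#     >>> t = ToTensor()(arr)                       # FloatTensor of shape (3, 32, 32)
+#     >>> 0.0 <= float(t.min()) <= float(t.max()) <= 1.0
+#     True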
+
+
+class Normalize(object):
+    """Normalize a tensor image with mean and standard deviation.
+
+    Given mean: (R, G, B) and std: (R, G, B),
+    will normalize each channel of the torch.*Tensor, i.e.
+    channel = (channel - mean) / std
+
+    Args:
+        mean (sequence): Sequence of means for R, G, B channels respectively.
+        std (sequence): Sequence of standard deviations for R, G, B channels
+            respectively.
+    """
+
+    def __init__(self, mean, std):
+        self.mean = mean
+        self.std = std
+
+    def __call__(self, tensor):
+        """
+        Args:
+            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
+
+        Returns:
+            Tensor: Normalized image.
+        """
+        # TODO: make efficient
+        for t, m, s in zip(tensor, self.mean, self.std):
+            t.sub_(m).div_(s)
+        return tensor
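+
+# A minimal usage sketch (illustrative, not part of the module; the mean/std
+# values are arbitrary). Note that ``Normalize`` mutates its input in place:
+#
+#     >>> t = torch.ones(3, 4, 4)
+#     >>> out = Normalize(mean=[0.5, 0.5, 0.5], std=[0.25, 0.25, 0.25])(t)
+#     >>> float(out[0, 0, 0])                       # (1.0 - 0.5) / 0.25
+#     2.0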
+
+
+class Scale(object):
+    """Rescale the input PIL.Image to the given size.
+
+    Args:
+        size (sequence or int): Desired output size. If size is a sequence like
+            (w, h), output size will be matched to this. If size is an int,
+            the smaller edge of the image will be matched to this number.
+            i.e., if height > width, then the image will be rescaled to
+            (size, size * height / width)
+        interpolation (int, optional): Desired interpolation. Default is
+            ``PIL.Image.BILINEAR``
+    """
+
+    def __init__(self, size, interpolation=Image.BILINEAR):
+        assert isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)
+        self.size = size
+        self.interpolation = interpolation
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL.Image): Image to be scaled.
+
+        Returns:
+            PIL.Image: Rescaled image.
+        """
+        if isinstance(self.size, int):
+            w, h = img.size
+            if (w <= h and w == self.size) or (h <= w and h == self.size):
+                return img
+            if w < h:
+                ow = self.size
+                oh = int(self.size * h / w)
+                return img.resize((ow, oh), self.interpolation)
+            else:
+                oh = self.size
+                ow = int(self.size * w / h)
+                return img.resize((ow, oh), self.interpolation)
+        else:
+            return img.resize(self.size, self.interpolation)
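+
+# A minimal usage sketch (illustrative, not part of the module): with an int
+# size, only the smaller edge is matched, preserving the aspect ratio.
+#
+#     >>> from PIL import Image
+#     >>> img = Image.new('RGB', (400, 200))        # PIL sizes are (w, h)
+#     >>> Scale(100)(img).size                      # smaller edge (h=200) -> 100
+#     (200, 100)
+#     >>> Scale((50, 60))(img).size                 # explicit (w, h)
+#     (50, 60)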
+
+
+class CenterCrop(object):
+    """Crops the given PIL.Image at the center.
+
+    Args:
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made.
+    """
+
+    def __init__(self, size):
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        else:
+            self.size = size
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL.Image): Image to be cropped.
+
+        Returns:
+            PIL.Image: Cropped image.
+        """
+        w, h = img.size
+        th, tw = self.size
+        x1 = int(round((w - tw) / 2.))
+        y1 = int(round((h - th) / 2.))
+        return img.crop((x1, y1, x1 + tw, y1 + th))
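+
+# A minimal usage sketch (illustrative, not part of the module): the crop box
+# is centered, so a (h, w) = (100, 100) crop of a 400 x 200 image starts at
+# x1 = (400 - 100) / 2 = 150 and y1 = (200 - 100) / 2 = 50.
+#
+#     >>> from PIL import Image
+#     >>> img = Image.new('RGB', (400, 200))
+#     >>> CenterCrop(100)(img).size
+#     (100, 100)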