123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175 |
- import random
- import math
- import numbers
- import collections
- import numpy as np
- import torch
- from PIL import Image, ImageOps
- try:
- import accimage
- except ImportError:
- accimage = None
class Compose(object):
    """Chain several transforms into a single callable.

    Args:
        transforms (list of ``Transform`` objects): callables applied in
            list order; each one receives the previous one's output.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        """Run ``img`` through every stored transform and return the result."""
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result
class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to a tensor.

    A ``numpy.ndarray`` of shape (H x W x C), or (H x W) for a single
    channel, and a ``PIL.Image`` are both returned as a tensor of shape
    (C x H x W).

    Pixel values are NOT rescaled on the ndarray and PIL paths: the data is
    only cast, so a [0, 255] input stays in [0, 255].

    NOTE(review): the value range written by ``accimage``'s ``copyto`` is
    not visible from this code -- confirm it agrees with the unscaled paths.
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image in (C x H x W) layout. ndarray input and
            byte-backed PIL modes come back as float tensors; PIL modes
            ``'I'`` and ``'I;16'`` keep their int32 / int16 dtype.
        """
        if isinstance(pic, np.ndarray):
            if pic.ndim == 2:
                # Grayscale arrays carry no channel axis; add one so the
                # HWC -> CHW transpose below applies instead of raising.
                pic = pic[:, :, None]
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            # backward compatibility: always hand back a float tensor
            return img.float()

        if accimage is not None and isinstance(pic, accimage.Image):
            nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
            pic.copyto(nppic)
            return torch.from_numpy(nppic)

        # handle PIL Image
        if pic.mode == 'I':
            # copy=True: NumPy 2 raises for copy=False when a copy is needed,
            # and torch.from_numpy wants a writable buffer anyway.
            img = torch.from_numpy(np.array(pic, np.int32, copy=True))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=True))
        else:
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))

        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            # for the remaining modes the channel count is taken from the
            # length of the mode string
            nchannel = len(pic.mode)

        # pic.size is (width, height); the raw buffer is laid out H x W x C
        img = img.view(pic.size[1], pic.size[0], nchannel)
        # put it from HWC to CHW format (the original author noted this
        # reordering dominates loading time)
        img = img.permute(2, 0, 1).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float()
        return img
class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.

    Given mean: (R, G, B) and std: (R, G, B), each channel of the
    torch.*Tensor is normalized as ``channel = (channel - mean) / std``.

    Args:
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: The same tensor object, normalized in place.
        """
        # TODO: make efficient
        # zip stops at the shortest of (channels, mean, std); channels beyond
        # that are left untouched.
        for channel, channel_mean, channel_std in zip(tensor, self.mean, self.std):
            channel.sub_(channel_mean)
            channel.div_(channel_std)
        return tensor
class Scale(object):
    """Rescale the input PIL.Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (w, h), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # ``collections.Iterable`` was removed in Python 3.10; the ABC lives
        # in ``collections.abc``. Fall back for interpreters that lack it.
        try:
            from collections.abc import Iterable
        except ImportError:
            from collections import Iterable
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.

        Returns:
            PIL.Image: Rescaled image. The input object itself is returned
            when its smaller edge already equals an int ``size``.
        """
        if not isinstance(self.size, int):
            # explicit two-element size: passed straight through to resize
            return img.resize(self.size, self.interpolation)

        w, h = img.size
        if (w <= h and w == self.size) or (h <= w and h == self.size):
            # smaller edge already matches; nothing to do
            return img

        # Pin the smaller edge to ``size`` and scale the other edge by the
        # same ratio, truncated to an int.
        if w < h:
            ow = self.size
            oh = int(self.size * h / w)
        else:
            oh = self.size
            ow = int(self.size * w / h)
        return img.resize((ow, oh), self.interpolation)
class CenterCrop(object):
    """Crop the given PIL.Image around its center.

    Args:
        size (sequence or int): Desired output size of the crop. A sequence
            is read as (h, w); a single number yields a square crop
            (size, size).
    """

    def __init__(self, size):
        is_scalar = isinstance(size, numbers.Number)
        self.size = (int(size), int(size)) if is_scalar else size

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        width, height = img.size
        crop_h, crop_w = self.size
        # top-left corner that centers the crop window in the image
        left = int(round((width - crop_w) / 2.))
        top = int(round((height - crop_h) / 2.))
        right = left + crop_w
        bottom = top + crop_h
        return img.crop((left, top, right, bottom))
|