spatial_transforms.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. import random
  2. import math
  3. import numbers
  4. import collections
  5. import numpy as np
  6. import torch
  7. from PIL import Image, ImageOps
  8. try:
  9. import accimage
  10. except ImportError:
  11. accimage = None
  class Compose(object):
      """Chain several transforms into a single callable.

      The transforms are applied in list order; the output of each one is
      passed as the input of the next.

      Args:
          transforms (list of ``Transform`` objects): transforms to apply, in order.

      Example:
          >>> transforms.Compose([
          >>>     transforms.CenterCrop(10),
          >>>     transforms.ToTensor(),
          >>> ])
      """

      def __init__(self, transforms):
          self.transforms = transforms

      def __call__(self, img):
          """Run ``img`` through every transform and return the final result."""
          result = img
          for transform in self.transforms:
              result = transform(result)
          return result
  class ToTensor(object):
      """Convert a ``PIL.Image``, ``accimage.Image`` or ``numpy.ndarray`` to a tensor.

      An (H x W x C) numpy array or PIL image is returned as a (C x H x W)
      tensor.

      Unlike ``torchvision.transforms.ToTensor``, values are NOT rescaled by
      default: 8-bit input becomes a ``torch.FloatTensor`` that keeps its
      original [0, 255] range. Pass ``norm_value=255`` to obtain [0.0, 1.0].

      Args:
          norm_value (number, optional): divisor applied to numpy input and to
              8-bit PIL images after their conversion to float. It is not
              applied to ``accimage`` images or to PIL modes 'I', 'I;16' and
              'F'. Default: 1 (values unchanged).
      """

      def __init__(self, norm_value=1):
          self.norm_value = norm_value

      def __call__(self, pic):
          """
          Args:
              pic (PIL.Image, accimage.Image or numpy.ndarray): Image to be
                  converted to tensor. A numpy array must be 3-D (H x W x C).

          Returns:
              Tensor: Converted image of shape (C x H x W). Float for numpy,
              accimage, byte-based PIL and mode 'F' input; int32 for mode
              'I'; int16 for mode 'I;16'.
          """
          if isinstance(pic, np.ndarray):
              # handle numpy array: HWC -> CHW
              img = torch.from_numpy(pic.transpose((2, 0, 1)))
              # backward compatibility: always return a float tensor
              return img.float().div(self.norm_value)

          if accimage is not None and isinstance(pic, accimage.Image):
              nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
              pic.copyto(nppic)
              return torch.from_numpy(nppic)

          # handle PIL Image.
          # np.array(...) copies by default. The former ``copy=False`` means
          # "never copy" under NumPy >= 2.0 and raises ValueError whenever a
          # dtype cast is needed; a copy also keeps the returned tensor from
          # aliasing a read-only buffer.
          if pic.mode == 'I':
              img = torch.from_numpy(np.array(pic, np.int32))
          elif pic.mode == 'I;16':
              # NOTE(review): PIL documents 'I;16' as unsigned 16-bit; int16
              # wraps values above 32767. Kept as-is for backward compatibility.
              img = torch.from_numpy(np.array(pic, np.int16))
          elif pic.mode == 'F':
              # 32-bit float pixels need their own path: the raw-bytes branch
              # below assumes one byte per channel value.
              img = torch.from_numpy(np.array(pic, np.float32))
          else:
              img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))

          # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
          if pic.mode == 'YCbCr':
              nchannel = 3
          elif pic.mode == 'I;16':
              nchannel = 1
          else:
              nchannel = len(pic.mode)
          # pic.size is (width, height); the tensor is laid out as (H, W, C)
          img = img.view(pic.size[1], pic.size[0], nchannel)
          # put it from HWC to CHW format
          # yikes, this transpose takes 80% of the loading time/CPU
          img = img.transpose(0, 1).transpose(0, 2).contiguous()
          if isinstance(img, torch.ByteTensor):
              return img.float().div(self.norm_value)
          else:
              return img
  class Normalize(object):
      """Normalize a tensor image channel by channel.

      For every channel ``c`` this computes
      ``tensor[c] = (tensor[c] - mean[c]) / std[c]``.

      The input tensor is modified in place and is also returned. Channels are
      paired with ``mean``/``std`` entries via ``zip``, so only the first
      ``min(C, len(mean), len(std))`` channels are normalized.

      Args:
          mean (sequence): Per-channel means, e.g. for R, G, B respectively.
          std (sequence): Per-channel standard deviations, e.g. for R, G, B
              respectively.
      """

      def __init__(self, mean, std):
          self.mean = mean
          self.std = std

      def __call__(self, tensor):
          """
          Args:
              tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

          Returns:
              Tensor: The same tensor object, normalized in place.
          """
          # TODO: make efficient (one broadcasted op instead of a per-channel loop)
          for channel, mu, sigma in zip(tensor, self.mean, self.std):
              channel.sub_(mu)
              channel.div_(sigma)
          return tensor
  class Scale(object):
      """Rescale the input PIL.Image to the given size.

      Args:
          size (sequence or int): Desired output size. If size is a sequence
              like (w, h), the output size will be matched to it exactly. If
              size is an int, the smaller edge of the image will be matched to
              this number and the aspect ratio kept, i.e. if height > width
              the image is rescaled to (w, h) = (size, int(size * height / width)).
              An image whose smaller edge already equals ``size`` is returned
              unchanged.
          interpolation (int, optional): Desired interpolation. Default is
              ``PIL.Image.BILINEAR``
      """

      def __init__(self, size, interpolation=Image.BILINEAR):
          # ``collections.Iterable`` was removed in Python 3.10; the ABC lives
          # in ``collections.abc`` (Python 3.3+). Fallback kept for Python 2.
          try:
              from collections.abc import Iterable
          except ImportError:
              from collections import Iterable
          assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2), \
              'size must be an int or a (w, h) pair'
          self.size = size
          self.interpolation = interpolation

      def __call__(self, img):
          """
          Args:
              img (PIL.Image): Image to be scaled.

          Returns:
              PIL.Image: Rescaled image.
          """
          if not isinstance(self.size, int):
              # explicit (w, h) target
              return img.resize(self.size, self.interpolation)

          w, h = img.size
          if (w <= h and w == self.size) or (h <= w and h == self.size):
              # shorter edge already has the requested length
              return img
          if w < h:
              ow = self.size
              oh = int(self.size * h / w)
          else:
              oh = self.size
              ow = int(self.size * w / h)
          return img.resize((ow, oh), self.interpolation)
  class CenterCrop(object):
      """Crop the given PIL.Image around its center.

      Args:
          size (sequence or int): Desired output size of the crop as (h, w).
              A single number produces a square (size, size) crop; it is
              converted with ``int``.
      """

      def __init__(self, size):
          self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size

      def __call__(self, img):
          """
          Args:
              img (PIL.Image): Image to be cropped.

          Returns:
              PIL.Image: Cropped image.
          """
          width, height = img.size
          crop_h, crop_w = self.size
          # offsets are rounded with Python's round() before truncation to int
          left = int(round((width - crop_w) / 2.))
          top = int(round((height - crop_h) / 2.))
          box = (left, top, left + crop_w, top + crop_h)
          return img.crop(box)