bicubic.py

import torch
from torch import nn
from torch.nn import functional as F


class BicubicDownSample(nn.Module):
    def bicubic_kernel(self, x, a=-0.50):
        """
        Piecewise bicubic kernel with free parameter a:
            W(x) = (a + 2)|x|^3 - (a + 3)|x|^2 + 1    for |x| <= 1
            W(x) = a|x|^3 - 5a|x|^2 + 8a|x| - 4a      for 1 < |x| < 2
            W(x) = 0                                  otherwise
        This equation is exactly copied from the website below:
        https://clouard.users.greyc.fr/Pantheon/experiments/rescaling/index-en.html#bicubic
        """
        abs_x = torch.abs(x)
        if abs_x <= 1.:
            return (a + 2.) * torch.pow(abs_x, 3.) - (a + 3.) * torch.pow(abs_x, 2.) + 1.
        elif 1. < abs_x < 2.:
            return a * torch.pow(abs_x, 3.) - 5. * a * torch.pow(abs_x, 2.) + 8. * a * abs_x - 4. * a
        else:
            return 0.0
    def __init__(self, factor=4, cuda=True, padding='reflect'):
        super().__init__()
        self.factor = factor
        size = factor * 4
        # Sample the bicubic kernel at `size` taps centred on the filter support,
        # then normalise so the weights sum to 1 (preserves mean intensity).
        k = torch.tensor([self.bicubic_kernel((i - torch.floor(torch.tensor(size / 2)) + 0.5) / factor)
                          for i in range(size)], dtype=torch.float32)
        k = k / torch.sum(k)
        # k = torch.einsum('i,j->ij', (k, k))
        # Separable filter: k1 convolves along the height, k2 along the width,
        # replicated once per RGB channel for a grouped convolution.
        k1 = torch.reshape(k, shape=(1, 1, size, 1))
        self.k1 = torch.cat([k1, k1, k1], dim=0)
        k2 = torch.reshape(k, shape=(1, 1, 1, size))
        self.k2 = torch.cat([k2, k2, k2], dim=0)
        # Use a distinct attribute name here: `self.cuda = ...` would shadow nn.Module.cuda().
        self.cuda_suffix = '.cuda' if cuda else ''
        self.padding = padding
        for param in self.parameters():
            param.requires_grad = False
    def forward(self, x, nhwc=False, clip_round=False, byte_output=False):
        # x = torch.from_numpy(x).type('torch.FloatTensor')
        filter_height = self.factor * 4
        filter_width = self.factor * 4
        stride = self.factor

        # "SAME"-style padding: total padding is filter size minus stride.
        pad_along_height = max(filter_height - stride, 0)
        pad_along_width = max(filter_width - stride, 0)
        filters1 = self.k1.type('torch{}.FloatTensor'.format(self.cuda_suffix))
        filters2 = self.k2.type('torch{}.FloatTensor'.format(self.cuda_suffix))

        # compute actual padding values for each side
        pad_top = pad_along_height // 2
        pad_bottom = pad_along_height - pad_top
        pad_left = pad_along_width // 2
        pad_right = pad_along_width - pad_left

        if nhwc:
            x = torch.transpose(torch.transpose(x, 2, 3), 1, 2)  # NHWC to NCHW

        # Downscaling performed by two 1-D convolutions with mirror padding:
        # first along the height, then along the width.
        x = F.pad(x, (0, 0, pad_top, pad_bottom), self.padding)
        x = F.conv2d(input=x, weight=filters1, stride=(stride, 1), groups=3)
        if clip_round:
            x = torch.clamp(torch.round(x), 0.0, 255.)
        x = F.pad(x, (pad_left, pad_right, 0, 0), self.padding)
        x = F.conv2d(input=x, weight=filters2, stride=(1, stride), groups=3)
        if clip_round:
            x = torch.clamp(torch.round(x), 0.0, 255.)

        if nhwc:
            x = torch.transpose(torch.transpose(x, 1, 3), 1, 2)  # NCHW back to NHWC
        if byte_output:
            # Fixed: the original format string lacked the '{}' placeholder for the device suffix.
            return x.type('torch{}.ByteTensor'.format(self.cuda_suffix))
        else:
            return x
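
A minimal usage sketch, assuming the file above is saved as bicubic.py and that inputs are NCHW RGB tensors with values in [0, 255]; the batch size, input resolution, and cuda=False flag below are illustrative choices, not taken from the source:

    import torch
    from bicubic import BicubicDownSample

    downsampler = BicubicDownSample(factor=4, cuda=False)
    x = torch.rand(1, 3, 1024, 1024) * 255.  # placeholder NCHW RGB image in [0, 255]
    y = downsampler(x, clip_round=True, byte_output=True)
    print(y.shape, y.dtype)  # torch.Size([1, 3, 256, 256]) torch.uint8

With filter size 4 * factor and stride equal to factor, the "SAME"-style padding makes the output exactly input_size / factor along each spatial dimension, so a 1024x1024 image comes out 256x256 at the default factor of 4.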