qualitative.py

# adapted from http://cs231n.github.io/assignments2019/assignment3/
import random
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from scipy.ndimage import gaussian_filter1d
from torch import nn

import params
from pipeline import BuildDataset


class QualitativeAnalysis:
    def __init__(self,
                 model: nn.Module,
                 img_size: int,
                 videos_and_frames: Dict[str, List[int]],
                 class_names: Dict[int, str]):
        self.class_names = class_names
        self.img_size = img_size
        self.model = model
        self.videos_and_frames = videos_and_frames
        self.features = {
            vid: BuildDataset.one_video_extract_audio_and_stills(vid)
            for vid in self.videos_and_frames}
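        # Added note (inferred from how self.features is indexed below): each
        # entry is assumed to be an (audio, stills) pair, where each audio
        # window has shape (1, 96, 64) and stills are transformed image tensors.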
    # method is adapted from stanford cs231n assignment 3 available at:
    # http://cs231n.github.io/assignments2019/assignment3/
    @staticmethod
    def _compute_saliency_maps(A, I, y, model):
        """
        Compute a class saliency map using the model for audio A, images I,
        and labels y.

        Input:
        - A: Input audio; Tensor of shape (N, 1, 96, 64)
        - I: Input images; Tensor of shape (N, 3, H, W)
        - y: Labels for the inputs; LongTensor of shape (N,)
        - model: A pretrained CNN that will be used to compute the saliency map.

        Returns:
        - saliency: A Tensor of shape (N, H, W) giving the saliency maps for
          the input images.
        """
        # Make sure the model is in "test" mode
        model.eval()
        # Make the image input require gradient
        # A.requires_grad_()
        I.requires_grad_()
        # Select each example's score for its own label and backprop those
        # scores to get gradients with respect to the input images.
        scores = model(A, I).gather(1, y.view(-1, 1)).squeeze()
        scores.backward(torch.ones(scores.size()))
        # Saliency is the max absolute gradient across the color channels.
        saliency, _ = torch.max(I.grad.abs(), dim=1)
        return saliency
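    # Worked illustration of the gather step above (added note, not part of the
    # original): if the model outputs scores of shape (N, 2) and y = [1, 0, 1],
    # then scores.gather(1, y.view(-1, 1)) picks out scores[0, 1], scores[1, 0]
    # and scores[2, 1], i.e. each example's score for its own label.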
    # also adapted from cs231n assignment 3
    def _show_saliency_maps(self, A, I, y):
        # Convert the audio, image, and label inputs to Torch Tensors
        I_tensor = torch.cat([
            BuildDataset.transformer(self.img_size)(Image.fromarray(i)).unsqueeze(0)
            for i in I], dim=0)
        A_tensor = torch.cat([a.unsqueeze(0) for a in A])
        y_tensor = torch.LongTensor(y)

        # Compute saliency maps for the given images
        saliency = self._compute_saliency_maps(A_tensor, I_tensor, y_tensor, self.model)

        # Convert the saliency map from Torch Tensor to numpy array and show
        # images and saliency maps together.
        saliency = saliency.numpy()
        N = len(I)
        for i in range(N):
            plt.subplot(2, N, i + 1)
            plt.imshow(I[i])
            plt.axis('off')
            plt.title(self.class_names[y[i]])
            plt.subplot(2, N, N + i + 1)
            plt.imshow(saliency[i], cmap=plt.cm.hot)
            plt.axis('off')
            plt.gcf().set_size_inches(12, 5)
        plt.show()

    @staticmethod
    def _img_transform_reverse_to_np(x: torch.Tensor) -> np.ndarray:
        rev = BuildDataset.transform_reverse(x)
        return np.array(rev)

    def saliency_maps(self):
        for vid, indices in self.videos_and_frames.items():
            A = [self.features[vid][0][idx] for idx in indices]
            I = [self._img_transform_reverse_to_np(self.features[vid][1][idx])
                 for idx in indices]
            # Label is 1 for "kissing" clips and 0 otherwise, read off the video id
            y = [1 if 'kissing' in vid else 0] * len(A)

            self._show_saliency_maps(A, I, y)
            print('=' * 10)
    # next few methods taken from cs231n
    @staticmethod
    def jitter(X, ox, oy):
        """
        Helper function to randomly jitter an image.

        Inputs
        - X: PyTorch Tensor of shape (N, C, H, W)
        - ox, oy: Integers giving number of pixels to jitter along W and H axes

        Returns: A new PyTorch Tensor of shape (N, C, H, W)
        """
        if ox != 0:
            left = X[:, :, :, :-ox]
            right = X[:, :, :, -ox:]
            X = torch.cat([right, left], dim=3)
        if oy != 0:
            top = X[:, :, :-oy]
            bottom = X[:, :, -oy:]
            X = torch.cat([bottom, top], dim=2)
        return X
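    # Added note: for 0 < ox < W and 0 < oy < H the jitter above is a circular
    # shift, equivalent to torch.roll(X, shifts=(oy, ox), dims=(2, 3)).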
    @staticmethod
    def _blur_image(X, sigma=1):
        # Blur along the H and W axes with a 1-D Gaussian in each direction,
        # writing the result back into X in place.
        X_np = X.cpu().clone().numpy()
        X_np = gaussian_filter1d(X_np, sigma, axis=2)
        X_np = gaussian_filter1d(X_np, sigma, axis=3)
        X.copy_(torch.Tensor(X_np).type_as(X))
        return X
    def create_class_visualization(self, target_y, model, dtype, a, **kwargs):
        """
        Generate an image to maximize the score of target_y under a pretrained model.

        Inputs:
        - target_y: Integer giving the index of the class to visualize
        - model: A pretrained CNN that will be used to generate the image
        - dtype: Torch datatype to use for computations
        - a: Audio input passed to the model alongside the generated image

        Keyword arguments:
        - l2_reg: Strength of L2 regularization on the image
        - learning_rate: How big of a step to take
        - num_iterations: How many iterations to use
        - blur_every: How often to blur the image as an implicit regularizer
        - max_jitter: How much to jitter the image as an implicit regularizer
        - show_every: How often to show the intermediate result
        """
        def deprocess(x):
            return BuildDataset.transform_reverse(x.squeeze(0))

        model.type(dtype)
        l2_reg = kwargs.pop('l2_reg', 1e-3)
        learning_rate = kwargs.pop('learning_rate', 25)
        num_iterations = kwargs.pop('num_iterations', 100)
        blur_every = kwargs.pop('blur_every', 10)
        max_jitter = kwargs.pop('max_jitter', 16)
        show_every = kwargs.pop('show_every', 25)

        # Randomly initialize the image as a PyTorch Tensor and make it require gradient.
        img = torch.randn(1, 3, 224, 224).mul_(1.0).type(dtype).requires_grad_()

        for t in range(num_iterations):
            # Randomly jitter the image a bit; this gives slightly nicer results
            ox, oy = random.randint(0, max_jitter), random.randint(0, max_jitter)
            img.data.copy_(self.jitter(img.data, ox, oy))

            # Gradient ascent on the target class score, with L2 regularization
            # on the image
            target = model(a, img)[0, target_y]
            target.backward()
            g = img.grad.data
            g -= 2 * l2_reg * img.data
            img.data += learning_rate * (g / g.norm())
            img.grad.zero_()

            # Undo the random jitter
            img.data.copy_(self.jitter(img.data, -ox, -oy))

            # As a regularizer, clamp and periodically blur the image
            for c in range(3):
                lo = float(-params.mean[c] / params.std[c])
                hi = float((1.0 - params.mean[c]) / params.std[c])
                img.data[:, c].clamp_(min=lo, max=hi)
            if t % blur_every == 0:
                self._blur_image(img.data, sigma=0.5)

            # Periodically show the image
            if t == 0 or (t + 1) % show_every == 0 or t == num_iterations - 1:
                plt.imshow(deprocess(img.data.clone().cpu()))
                class_name = self.class_names[target_y]
                plt.title('%s\nIteration %d / %d' % (class_name, t + 1, num_iterations))
                plt.gcf().set_size_inches(4, 4)
                plt.axis('off')
                plt.show()
        return deprocess(img.data.cpu())
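

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration, not part of the original
# module). It assumes a trained audio+image model compatible with the rest of
# the pipeline; the checkpoint path, video ids, and frame indices below are
# hypothetical placeholders.
if __name__ == '__main__':
    model = torch.load('checkpoints/model.pt', map_location='cpu')  # hypothetical path
    analysis = QualitativeAnalysis(
        model=model,
        img_size=224,
        videos_and_frames={'kissing_clip_01': [0, 10, 20],       # hypothetical ids
                           'not_kissing_clip_01': [0, 10, 20]},
        class_names={0: 'not kissing', 1: 'kissing'})

    # Saliency maps for the selected frames of each listed video.
    analysis.saliency_maps()

    # Class visualization for the "kissing" class, conditioned on one audio
    # window from the extracted features (assumes windows of shape (1, 96, 64),
    # batched here to (1, 1, 96, 64)).
    audio = analysis.features['kissing_clip_01'][0][0].unsqueeze(0)
    analysis.create_class_visualization(
        target_y=1, model=model, dtype=torch.FloatTensor, a=audio)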