qualitative.py

import random
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch import nn

from pipeline import BuildDataset


class QualitativeAnalysis:
    def __init__(self,
                 model: nn.Module,
                 img_size: int,
                 videos_and_frames: Dict[str, List[int]],
                 class_names: Dict[int, str]):
        self.class_names = class_names
        self.img_size = img_size
        self.model = model
        self.videos_and_frames = videos_and_frames

        self.features = {
            vid: BuildDataset.one_video_extract_audio_and_stills(vid)
            for vid in self.videos_and_frames}
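
    # Illustrative sketch of the expected constructor arguments; the video IDs,
    # frame indices and class names below are hypothetical, not part of this
    # module. Note that saliency_maps() derives each clip's label from the
    # substring 'kissing' in its video ID:
    #
    #     videos_and_frames = {'kissing_scene_01': [0, 5, 10],
    #                          'other_scene_02': [3, 7]}
    #     class_names = {0: 'not kissing', 1: 'kissing'}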

    # method is adapted from stanford cs231n assignment 3, available at:
    # http://cs231n.github.io/assignments2019/assignment3/
    @staticmethod
    def _compute_saliency_maps(A, I, y, model):
        """
        Compute a class saliency map using the model for audio/image inputs and labels y.

        Input:
        - A: Input audio; Tensor of shape (N, 1, 96, 64)
        - I: Input images; Tensor of shape (N, 3, H, W)
        - y: Labels for the inputs; LongTensor of shape (N,)
        - model: A pretrained CNN that will be used to compute the saliency map.

        Returns:
        - saliency: A Tensor of shape (N, H, W) giving the saliency maps for the
          input images.
        """
        # Make sure the model is in "test" mode
        model.eval()

        # Make the image tensor require gradient (audio saliency is not computed here)
        # A.requires_grad_()
        I.requires_grad_()

        # Forward pass, pick out each example's score for its own label, then
        # backprop to get gradients on the input images
        scores = model(A, I).gather(1, y.view(-1, 1)).squeeze()
        scores.backward(torch.ones(scores.size()))

        # Saliency is the max absolute gradient across the colour channels
        saliency, _ = torch.max(I.grad.abs(), dim=1)
        return saliency
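
    # For intuition (hypothetical values): with N = 2 examples and 2 classes,
    # model(A, I) might return [[0.2, 1.3], [0.9, -0.4]]; gather(1, y.view(-1, 1))
    # with y = [1, 0] selects [[1.3], [0.9]], i.e. each example's score for its
    # own label, which is what gets backpropagated to the image pixels.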

    # also adapted from cs231n assignment 3
    def _show_saliency_maps(self, A, I, y):
        # Convert the raw audio, image and label lists to batched Torch Tensors
        I_tensor = torch.cat([
            BuildDataset.transformer(self.img_size)(Image.fromarray(i)).unsqueeze(0)
            for i in I], dim=0)
        A_tensor = torch.cat([a.unsqueeze(0) for a in A])
        y_tensor = torch.LongTensor(y)

        # Compute saliency maps for the batch
        saliency = self._compute_saliency_maps(A_tensor, I_tensor, y_tensor, self.model)

        # Convert the saliency maps from Torch Tensor to numpy array and show the
        # images and saliency maps together.
        saliency = saliency.numpy()
        N = len(I)
        for i in range(N):
            plt.subplot(2, N, i + 1)
            plt.imshow(I[i])
            plt.axis('off')
            plt.title(self.class_names[y[i]])
            plt.subplot(2, N, N + i + 1)
            plt.imshow(saliency[i], cmap=plt.cm.hot)
            plt.axis('off')
            plt.gcf().set_size_inches(12, 5)
        plt.show()

    @staticmethod
    def _img_transform_reverse_to_np(x: torch.Tensor) -> np.ndarray:
        rev = BuildDataset.transform_reverse(x)
        return np.array(rev)

    def saliency_maps(self):
        for vid, indices in self.videos_and_frames.items():
            A = [self.features[vid][0][idx] for idx in indices]
            I = [self._img_transform_reverse_to_np(self.features[vid][1][idx])
                 for idx in indices]
            # Label is derived from the video ID: 1 for a kissing clip, else 0
            y = [1 if 'kissing' in vid else 0] * len(A)

            self._show_saliency_maps(A, I, y)
            print('=' * 10)

    # next few methods taken from cs231n
    @staticmethod
    def jitter(X, ox, oy):
        """
        Helper function to randomly jitter an image.

        Inputs
        - X: PyTorch Tensor of shape (N, C, H, W)
        - ox, oy: Integers giving number of pixels to jitter along W and H axes

        Returns: A new PyTorch Tensor of shape (N, C, H, W)
        """
        if ox != 0:
            left = X[:, :, :, :-ox]
            right = X[:, :, :, -ox:]
            X = torch.cat([right, left], dim=3)
        if oy != 0:
            top = X[:, :, :-oy]
            bottom = X[:, :, -oy:]
            X = torch.cat([bottom, top], dim=2)
        return X
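
    # For intuition (hypothetical values): with W = 4 and ox = 1, columns
    # [0, 1, 2, 3] become [3, 0, 1, 2] -- a circular shift, so no pixels are lost
    # and the jitter can be undone later by calling jitter(X, -ox, -oy).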

    def create_class_visualization(self, target_y, model, dtype, **kwargs):
        """
        Generate an image to maximize the score of target_y under a pretrained model.

        Inputs:
        - target_y: Integer giving the index of the target class
        - model: A pretrained CNN that will be used to generate the image
        - dtype: Torch datatype to use for computations

        Keyword arguments:
        - l2_reg: Strength of L2 regularization on the image
        - learning_rate: How big of a step to take
        - num_iterations: How many iterations to use
        - blur_every: How often to blur the image as an implicit regularizer
        - max_jitter: How much to jitter the image as an implicit regularizer
        - show_every: How often to show the intermediate result
        """
        # NOTE: SQUEEZENET_MEAN, SQUEEZENET_STD, blur_image and deprocess are
        # helpers from the cs231n assignment code and are assumed to be available
        # in this scope. The model passed here is also assumed to take a single
        # image tensor, unlike the audio+image model used for the saliency maps.
        model.type(dtype)
        l2_reg = kwargs.pop('l2_reg', 1e-3)
        learning_rate = kwargs.pop('learning_rate', 25)
        num_iterations = kwargs.pop('num_iterations', 100)
        blur_every = kwargs.pop('blur_every', 10)
        max_jitter = kwargs.pop('max_jitter', 16)
        show_every = kwargs.pop('show_every', 25)

        # Randomly initialize the image as a PyTorch Tensor that requires gradient.
        img = torch.randn(1, 3, 224, 224).mul_(1.0).type(dtype).requires_grad_()

        for t in range(num_iterations):
            # Randomly jitter the image a bit; this gives slightly nicer results
            ox, oy = random.randint(0, max_jitter), random.randint(0, max_jitter)
            img.data.copy_(self.jitter(img.data, ox, oy))

            # Compute the gradient of the target class score with respect to the
            # image pixels, add the L2 regularization term, and take a normalized
            # gradient ascent step on the image.
            target = model(img)[0, target_y]
            target.backward()
            g = img.grad.data
            g -= 2 * l2_reg * img.data
            img.data += learning_rate * (g / g.norm())
            img.grad.zero_()

            # Undo the random jitter
            img.data.copy_(self.jitter(img.data, -ox, -oy))

            # As a regularizer, clamp and periodically blur the image
            for c in range(3):
                lo = float(-SQUEEZENET_MEAN[c] / SQUEEZENET_STD[c])
                hi = float((1.0 - SQUEEZENET_MEAN[c]) / SQUEEZENET_STD[c])
                img.data[:, c].clamp_(min=lo, max=hi)
            if t % blur_every == 0:
                blur_image(img.data, sigma=0.5)

            # Periodically show the intermediate image
            if t == 0 or (t + 1) % show_every == 0 or t == num_iterations - 1:
                plt.imshow(deprocess(img.data.clone().cpu()))
                class_name = self.class_names[target_y]
                plt.title('%s\nIteration %d / %d' % (class_name, t + 1, num_iterations))
                plt.gcf().set_size_inches(4, 4)
                plt.axis('off')
                plt.show()

        return deprocess(img.data.cpu())
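

# Minimal usage sketch (illustrative only): the model, video ID, frame indices
# and class names below are hypothetical placeholders; loading the trained
# audio+image model is assumed to happen elsewhere.
#
#     model = ...  # trained audio+image nn.Module
#     qa = QualitativeAnalysis(
#         model=model,
#         img_size=224,
#         videos_and_frames={'kissing_scene_01': [0, 5, 10]},
#         class_names={0: 'not kissing', 1: 'kissing'})
#     qa.saliency_maps()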