qualitative.py

from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch import nn

from pipeline import BuildDataset

class QualitativeAnalysis:
    def __init__(self,
                 model: nn.Module,
                 img_size: int,
                 videos_and_frames: Dict[str, List[int]],
                 class_names: Dict[int, str]):
        self.class_names = class_names
        self.img_size = img_size
        self.model = model
        self.videos_and_frames = videos_and_frames

        # Pre-extract audio features and still frames for every requested video
        self.features = {
            vid: BuildDataset.one_video_extract_audio_and_stills(vid)
            for vid in self.videos_and_frames}
    # method is adapted from Stanford cs231n assignment 3, available at:
    # http://cs231n.github.io/assignments2019/assignment3/
    @staticmethod
    def _compute_saliency_maps(A, I, y, model):
        """
        Compute a class saliency map using the model for audio A, images I, and labels y.

        Input:
        - A: Input audio; Tensor of shape (N, 1, 96, 64)
        - I: Input images; Tensor of shape (N, 3, H, W)
        - y: Labels for the inputs; LongTensor of shape (N,)
        - model: A pretrained CNN that will be used to compute the saliency maps

        Returns:
        - saliency: A Tensor of shape (N, H, W) giving the saliency maps for the
          input images.
        """
        # Make sure the model is in "test" mode
        model.eval()

        # Make the image tensor require gradients so they flow back to the pixels
        # A.requires_grad_()
        I.requires_grad_()

        # Score of the correct class for each example, then backprop to the inputs
        scores = model(A, I).gather(1, y.view(-1, 1)).squeeze()
        scores.backward(torch.ones(scores.size()))

        # Saliency is the per-pixel maximum of the absolute gradient over the channels
        saliency, _ = torch.max(I.grad.abs(), dim=1)
        return saliency
    # also adapted from cs231n assignment 3
    def _show_saliency_maps(self, A, I, y):
        # Convert A, I, and y from numpy arrays to Torch tensors
        I_tensor = torch.cat([
            BuildDataset.transformer(self.img_size)(Image.fromarray(i)).unsqueeze(0)
            for i in I], dim=0)
        A_tensor = torch.cat([a.unsqueeze(0) for a in A])
        y_tensor = torch.LongTensor(y)

        # Compute saliency maps for the selected frames
        saliency = self._compute_saliency_maps(A_tensor, I_tensor, y_tensor, self.model)

        # Convert the saliency maps from Torch tensor to numpy array and show the
        # images and saliency maps together.
        saliency = saliency.numpy()
        N = len(I)
        for i in range(N):
            plt.subplot(2, N, i + 1)
            plt.imshow(I[i])
            plt.axis('off')
            plt.title(self.class_names[y[i]])
            plt.subplot(2, N, N + i + 1)
            plt.imshow(saliency[i], cmap=plt.cm.hot)
            plt.axis('off')
            plt.gcf().set_size_inches(12, 5)
        plt.show()
    @staticmethod
    def _img_transform_reverse_to_np(x: torch.Tensor) -> np.ndarray:
        # Undo the dataset's normalization and return the image as a numpy array
        rev = BuildDataset.transform_reverse(x)
        return np.array(rev)
    def saliency_maps(self):
        for vid, indices in self.videos_and_frames.items():
            A = [self.features[vid][0][idx] for idx in indices]
            I = [self._img_transform_reverse_to_np(self.features[vid][1][idx])
                 for idx in indices]
            # Label is 1 ("kissing") when the video name contains the substring 'kissing'
            y = [1 if 'kissing' in vid else 0] * len(A)

            self._show_saliency_maps(A, I, y)
            print('=' * 10)
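

# Example usage: a minimal, hypothetical sketch, not part of the original module.
# `trained_model`, the image size, the video paths, and the frame indices below are
# placeholders; substitute the trained audio+image network and the videos processed
# by pipeline.BuildDataset in your own setup.
if __name__ == '__main__':
    trained_model: nn.Module = ...  # load or build the trained audio+image model here
    analysis = QualitativeAnalysis(
        model=trained_model,
        img_size=224,  # assumed input resolution
        videos_and_frames={
            'data/kissing_clip.mp4': [0, 10, 20],   # hypothetical paths and frame indices
            'data/dancing_clip.mp4': [5, 15, 25],
        },
        class_names={0: 'not kissing', 1: 'kissing'})  # 1 must be the 'kissing' class
    # Plots each selected frame above its gradient-based saliency map, one video at a time
    analysis.saliency_maps()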