
saliency maps

Amir Ziai committed 5 years ago
commit fe6d68924e
3 changed files with 335 additions and 10 deletions
  1. dev6.ipynb (+226, -2)
  2. pipeline.py (+13, -8)
  3. qualitative.py (+96, -0)

dev6.ipynb (+226, -2)
File diff suppressed because it is too large.


pipeline.py (+13, -8)

@@ -88,7 +88,17 @@ class BuildDataset:
             transforms.ToPILImage()])(img)
 
     @staticmethod
-    def one_video_extract_audio_and_stills(path_video: str,
+    def transformer(img_size: int):
+        return transforms.Compose([
+            transforms.RandomResizedCrop(img_size),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize(params.mean, params.std)
+        ])
+
+    @classmethod
+    def one_video_extract_audio_and_stills(cls,
+                                           path_video: str,
                                            img_size: int = 224) -> Tuple[List[torch.Tensor],
                                                                          List[torch.Tensor]]:
         # return a list of image(s), audio tensors
@@ -96,12 +106,7 @@ class BuildDataset:
         frame_rate = cap.get(5)
         images = []
 
-        transformer = transforms.Compose([
-            transforms.RandomResizedCrop(img_size),
-            transforms.RandomHorizontalFlip(),
-            transforms.ToTensor(),
-            transforms.Normalize(params.mean, params.std)
-        ])
+        transformer = cls.transformer(img_size)
 
         # process the image
         while cap.isOpened():
@@ -125,7 +130,7 @@ class BuildDataset:
 
         tmp_audio_file = 'tmp.wav'
         VideoFileClip(path_video).audio.write_audiofile(tmp_audio_file)
-        # TODO: fix if n_augment > 1 by duplicating each sample n_augment times
+        # fix if n_augment > 1 by duplicating each sample n_augment times
         audio = vggish_input.wavfile_to_examples(tmp_audio_file)
         # audio = audio[:, None, :, :]  # add dummy dimension for "channel"
         # audio = torch.from_numpy(audio).float()  # Convert input example to float
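Taken together, the pipeline.py hunks pull the torchvision preprocessing out of one_video_extract_audio_and_stills into a reusable BuildDataset.transformer(img_size) static method, and turn the extraction routine into a classmethod so it can call that shared transformer. This is what lets qualitative.py (below) apply the exact same crop/normalize pipeline to individual frames. A minimal usage sketch, assuming pipeline.py is importable; the dummy frame and 'some_video.mp4' path are placeholders, not assets from this repository:

import numpy as np
from PIL import Image

from pipeline import BuildDataset

# Reuse the shared preprocessing on a single RGB frame (any HxWx3 uint8 array works).
transform = BuildDataset.transformer(img_size=224)
frame = Image.fromarray(np.zeros((360, 640, 3), dtype=np.uint8))  # placeholder frame
tensor = transform(frame)  # normalized tensor of shape (3, 224, 224)

# The extraction routine now builds the same transformer internally via cls.transformer().
features = BuildDataset.one_video_extract_audio_and_stills('some_video.mp4')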

qualitative.py (+96, -0)

@@ -0,0 +1,96 @@
+from typing import Dict, List
+
+import matplotlib.pyplot as plt
+import torch
+from PIL import Image
+from torch import nn
+import numpy as np
+
+from pipeline import BuildDataset
+
+
+class QualitativeAnalysis:
+    def __init__(self,
+                 model: nn.Module,
+                 img_size: int,
+                 videos_and_frames: Dict[str, List[int]],
+                 class_names: Dict[int, str]):
+        self.class_names = class_names
+        self.img_size = img_size
+        self.model = model
+        self.videos_and_frames = videos_and_frames
+
+        self.features = {
+            vid: BuildDataset.one_video_extract_audio_and_stills(vid)
+            for vid in self.videos_and_frames}
+
+    # method is adapted from stanford cs231n assignment 3 available at:
+    # http://cs231n.github.io/assignments2019/assignment3/
+    @staticmethod
+    def _compute_saliency_maps(A, I, y, model):
+        """
+        Compute a class saliency map using the model for audio A, images I, and labels y.
+
+        Input:
+        - A: Input audio; Tensor of shape (N, 1, 96, 64)
+        - I: Input images; Tensor of shape (N, 3, H, W)
+        - y: Labels for the inputs; LongTensor of shape (N,)
+        - model: A pretrained CNN that will be used to compute the saliency map.
+
+        Returns:
+        - saliency: A Tensor of shape (N, H, W) giving the saliency maps for the input
+        images.
+        """
+        # Make sure the model is in "test" mode
+        model.eval()
+
+        # Make input tensor require gradient
+        # A.requires_grad_()
+        I.requires_grad_()
+
+        scores = model(A, I).gather(1, y.view(-1, 1)).squeeze()
+        scores.backward(torch.ones(scores.size()))
+        saliency, _ = torch.max(I.grad.abs(), dim=1)
+
+        return saliency
+
+    # also adapted from cs231n assignment 3
+    def _show_saliency_maps(self, A, I, y):
+        # Convert the lists A, I, and y to batched Torch Tensors
+        I_tensor = torch.cat([
+            BuildDataset.transformer(self.img_size)(Image.fromarray(i)).unsqueeze(0)
+            for i in I], dim=0)
+        A_tensor = torch.cat([a.unsqueeze(0) for a in A])
+        y_tensor = torch.LongTensor(y)
+
+        # Compute saliency maps for the images in I
+        saliency = self._compute_saliency_maps(A_tensor, I_tensor, y_tensor, self.model)
+
+        # Convert the saliency map from Torch Tensor to numpy array and show images
+        # and saliency maps together.
+        saliency = saliency.numpy()
+        N = len(I)
+        for i in range(N):
+            plt.subplot(2, N, i + 1)
+            plt.imshow(I[i])
+            plt.axis('off')
+            plt.title(self.class_names[y[i]])
+            plt.subplot(2, N, N + i + 1)
+            plt.imshow(saliency[i], cmap=plt.cm.hot)
+            plt.axis('off')
+            plt.gcf().set_size_inches(12, 5)
+        plt.show()
+
+    @staticmethod
+    def _img_transform_reverse_to_np(x: torch.Tensor) -> np.ndarray:
+        rev = BuildDataset.transform_reverse(x)
+        return np.array(rev)
+
+    def saliency_maps(self):
+        for vid, indices in self.videos_and_frames.items():
+            A = [self.features[vid][0][idx] for idx in indices]
+            I = [self._img_transform_reverse_to_np(self.features[vid][1][idx])
+                 for idx in indices]
+            y = [1 if 'kissing' in vid else 0] * len(A)
+            self._show_saliency_maps(A, I, y)
+            print('=' * 10)
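The new QualitativeAnalysis class ties the commit together: it caches the audio and image tensors extracted per video, computes vanilla gradient saliency for the correct class (the channel-wise maximum of the absolute image gradient of the class score, per the CS231n recipe cited in the code), and plots each frame above its saliency heat map. A hedged usage sketch; the checkpoint file, clip paths, frame indices, and class names below are placeholder assumptions, not values from this repository:

import torch

from qualitative import QualitativeAnalysis

# Hypothetical trained model that takes (audio, image) batches, as the saliency code expects.
model = torch.load('kissing_model.pt', map_location='cpu')

analysis = QualitativeAnalysis(
    model=model,
    img_size=224,
    videos_and_frames={
        'clips/kissing_001.mp4': [0, 5, 10],   # label 1: path contains 'kissing'
        'clips/dancing_001.mp4': [0, 5, 10],   # label 0: any other path
    },
    class_names={0: 'not kissing', 1: 'kissing'},
)

# One matplotlib figure per video: frames on the top row, saliency maps below.
analysis.saliency_maps()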
