segmentor.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. from typing import List, Tuple
  2. import random
  3. import matplotlib.pyplot as plt
  4. import numpy as np
  5. import torch
  6. from PIL.Image import Image
  7. from matplotlib.pyplot import figure, imshow, axis
  8. from torch import nn
  9. from pipeline import BuildDataset
  # A segment is the list of PIL images it is made of, paired with its length.
  # The length is computed as a frame count; the original author labels it
  # "seconds", which presumably assumes one still per second — confirm against
  # BuildDataset.one_video_extract_audio_and_stills.
  _SegmentFrames = List[Image]
  Segment = Tuple[_SegmentFrames, int]
  class Segmentor:
      """Find segments of consecutive frames that ``model`` classifies as class ``1``.

      Every (audio, still) pair returned by
      ``BuildDataset.one_video_extract_audio_and_stills`` is classified on its
      own. The per-frame predictions are then grouped into non-overlapping
      windows that:

      * span at least ``min_frames`` frames,
      * start and end on a ``1``, and
      * have a mean prediction of at least ``threshold``.
      """

      def __init__(self,
                   model: nn.Module,
                   min_frames: int,
                   threshold: float):
          """
          Args:
              model: classifier called as ``model(audio, image)`` with both inputs
                  given a leading dimension of size 1. The predicted class is the
                  index returned by ``torch.max(output, 1)``.
              min_frames: minimum number of frames a segment must span.
              threshold: minimum mean of the predictions inside a segment.
          """
          self.model = model
          self.min_frames = min_frames
          self.threshold = threshold

      @staticmethod
      def _find_candidates(preds: List[int],
                           min_frames: int,
                           threshold: float) -> List[Tuple[int, int]]:
          """Return the longest qualifying ``(start, end)`` window for each start frame.

          A window ``preds[start:end + 1]`` qualifies when all of these hold:

          * its first and last predictions are ``1``,
          * it spans at least ``min_frames`` frames, and
          * the mean of its predictions is at least ``threshold``.

          The result is sorted by ``start`` and has at most one window per start.
          """
          n = len(preds)
          # prefix[k] == sum(preds[:k]), so each window mean costs O(1) instead of
          # re-averaging the whole slice for every candidate end.
          prefix = [0]
          for pred in preds:
              prefix.append(prefix[-1] + pred)

          candidates = []
          for start in range(n):
              if preds[start] != 1:
                  continue
              # Ends before ``start`` would describe empty windows, which never qualify.
              first_end = max(start + min_frames - 1, start)
              # Scan from the far end so the first qualifying end is the longest one.
              for end in range(n - 1, first_end - 1, -1):
                  if preds[end] != 1:
                      continue
                  length = end - start + 1
                  if (prefix[end + 1] - prefix[start]) / length >= threshold:
                      candidates.append((start, end))
                      break
          return candidates

      @staticmethod
      def _resolve_overlaps(candidates: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
          """Drop overlapping windows, keeping the longer window of each overlapping pair.

          ``candidates`` must be sorted by start with distinct starts, as produced
          by ``_find_candidates``. When two overlapping windows have equal length,
          the earlier one is kept.
          """
          kept: List[Tuple[int, int]] = []
          for start, end in candidates:
              discard = False
              # Kept windows are sorted and disjoint, so their ends increase. A new
              # window can therefore only overlap the most recently kept one.
              while kept and start <= kept[-1][1]:
                  prev_start, prev_end = kept[-1]
                  if end - start > prev_end - prev_start:
                      kept.pop()  # the new window is longer and replaces the kept one
                  else:
                      discard = True  # the kept window wins (ties favor the earlier one)
                      break
              if not discard:
                  kept.append((start, end))
          return kept

      @staticmethod
      def _segmentor(preds: List[int],
                     min_frames: int,
                     threshold: float) -> List[List[int]]:
          """Return the frame indices of each non-overlapping segment found in ``preds``."""
          candidates = Segmentor._find_candidates(preds, min_frames, threshold)
          return [list(range(start, end + 1))
                  for start, end in Segmentor._resolve_overlaps(candidates)]

      @staticmethod
      def _torch_img_to_pil(img: torch.Tensor) -> Image:
          """Convert a dataset image tensor back to a PIL image via ``BuildDataset.transform_reverse``."""
          return BuildDataset.transform_reverse(img)

      @staticmethod
      def _get_segment_len(indices: List[int]):
          """Return the number of frames spanned by ``indices`` (inclusive range)."""
          return max(indices) - min(indices) + 1

      def segmentor(self, preds: List[int], images: List[torch.Tensor]) -> List[Segment]:
          """Group ``preds`` into segments and pair each with its PIL images and length."""
          segments = []
          for indices in self._segmentor(preds, self.min_frames, self.threshold):
              pil_images = [self._torch_img_to_pil(images[idx]) for idx in indices]
              segments.append((pil_images, self._get_segment_len(indices)))
          return segments

      def _predict(self, audio: torch.Tensor, image: torch.Tensor) -> int:
          """Return the class index the model predicts for one (audio, image) pair."""
          # Inference only: don't build an autograd graph for every frame.
          with torch.no_grad():
              scores = self.model(audio.unsqueeze(0), image.unsqueeze(0))
          # torch.max(..., 1) returns (values, indices); take the index of the single sample.
          return int(torch.max(scores, 1)[1][0])

      def get_segments(self, path_video: str) -> List[Segment]:
          """Extract the stills of ``path_video``, classify each one and return its segments."""
          audio, images = BuildDataset.one_video_extract_audio_and_stills(path_video)
          # NOTE(review): assumes ``audio`` is indexable in step with ``images`` — confirm.
          preds = [self._predict(audio[idx], images[idx]) for idx in range(len(images))]
          return self.segmentor(preds, images)

      @staticmethod
      def show_images_horizontally(images: List[Image]) -> None:
          """Display ``images`` side by side in one row of a matplotlib figure."""
          # https://stackoverflow.com/questions/36006136/how-to-display-images-in-a-row-with-ipython-display
          if not images:
              return  # nothing to draw; avoid popping up an empty figure
          fig = figure(figsize=(20, 20))
          n_images = len(images)
          for position, image in enumerate(images, start=1):
              ax = fig.add_subplot(1, n_images, position)
              ax.imshow(image)
              ax.axis('off')
          plt.show()

      def visualize_segments(self, path_video: str, n_to_show: int = 10) -> None:
          """Print a summary of every segment in ``path_video`` and show sample frames.

          For each segment this shows the first frames, a random selection of
          frames, and the last frames, ``n_to_show`` of each. Fewer are shown
          when the segment is shorter than ``n_to_show``.
          """
          segments = self.get_segments(path_video)
          print(f'Found {len(segments)} segments')
          for i, (segment_images, segment_len) in enumerate(segments):
              # A segment is only guaranteed to have min_frames frames, and
              # random.sample raises ValueError when asked for more items than exist.
              k = max(0, min(n_to_show, len(segment_images)))
              print(f'Segment {i + 1}, {segment_len} seconds')
              print(f'First {k}')
              self.show_images_horizontally(segment_images[:k])
              print(f'{k} random shots')
              self.show_images_horizontally(random.sample(segment_images, k))
              print(f'Last {k}')
              # Slice from an explicit start: segment_images[-0:] would be the whole list.
              self.show_images_horizontally(segment_images[len(segment_images) - k:])
              print('=' * 10)