# pipeline.py

import math
import os
import pickle
import shutil
from typing import List, Tuple

import cv2
import numpy as np
import torch
from PIL import Image
from moviepy.editor import VideoFileClip
from torchvision import transforms

import params
import vggish_input

class BuildDataset:
    def __init__(self,
                 base_path: str,
                 videos_and_labels: List[Tuple[str, str]],
                 output_path: str,
                 n_augment: int = 1,
                 test_size: float = 1 / 3):
        assert 0 < test_size < 1
        self.videos_and_labels = videos_and_labels
        self.test_size = test_size
        self.output_path = output_path
        self.base_path = base_path
        self.n_augment = n_augment
        self.sets = ['train', 'val']

    def _get_set(self):
        return np.random.choice(self.sets, p=[1 - self.test_size, self.test_size])

    def build_dataset(self):
        # wipe
        for set_ in self.sets:
            path = f'{self.output_path}/{set_}'
            try:
                shutil.rmtree(path)
            except FileNotFoundError:
                pass
            os.makedirs(path)

        for file_name, label in self.videos_and_labels:
            name, _ = file_name.split('.')
            path = f'{self.base_path}/{file_name}'
            audio, images = self.one_video_extract_audio_and_stills(path)
            set_ = self._get_set()
            target = f"{self.output_path}/{set_}/{label}_{name}.pkl"
            pickle.dump((audio, images, label), open(target, 'wb'))

    @staticmethod
    def transform_reverse(img: torch.Tensor) -> Image.Image:
        # undo the normalization applied by `transformer` and convert back to a PIL image
        return transforms.Compose([
            transforms.Normalize(mean=[0, 0, 0], std=(1.0 / params.std).tolist()),
            transforms.Normalize(mean=(-params.mean).tolist(), std=[1, 1, 1]),
            transforms.ToPILImage()])(img)

    @staticmethod
    def transformer(img_size: int):
        return transforms.Compose([
            transforms.RandomResizedCrop(img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(params.mean, params.std)
        ])

    @classmethod
    def one_video_extract_audio_and_stills(cls,
                                           path_video: str,
                                           img_size: int = 224) -> Tuple[List[torch.Tensor],
                                                                         List[torch.Tensor]]:
        # return a list of audio tensors and a list of image tensors
        cap = cv2.VideoCapture(path_video)
        frame_rate = cap.get(cv2.CAP_PROP_FPS)
        images = []
        transformer = cls.transformer(img_size)

        # process the images
        while cap.isOpened():
            frame_id = cap.get(cv2.CAP_PROP_POS_FRAMES)
            success, frame = cap.read()
            if not success:
                # cap.read() returns False at the end of the stream (or on a read error)
                break
            if frame_id % math.floor(frame_rate * params.vggish_frame_rate) == 0:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes frames as BGR
                frame_pil = Image.fromarray(frame, mode='RGB')
                images.append(transformer(frame_pil))
                # images += [transformer(frame_pil) for _ in range(self.n_augment)]
        cap.release()

        # process the audio
        # TODO: hack to get around OpenMP error
        os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
        tmp_audio_file = 'tmp.wav'
        VideoFileClip(path_video).audio.write_audiofile(tmp_audio_file)
        # fix if n_augment > 1 by duplicating each sample n_augment times
        audio = vggish_input.wavfile_to_examples(tmp_audio_file)
        # audio = audio[:, None, :, :]  # add dummy dimension for "channel"
        # audio = torch.from_numpy(audio).float()  # Convert input example to float
        min_sizes = min(audio.shape[0], len(images))
        audio = [torch.from_numpy(audio[idx][None, :, :]).float() for idx in range(min_sizes)]
        images = images[:min_sizes]
        # images = [torch.from_numpy(img).permute((2, 0, 1)) for img in images[:min_sizes]]
        return audio, images
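

# Minimal usage sketch (not part of the original module): drives build_dataset()
# on a couple of labelled clips. The directory layout, file names and label
# strings below are hypothetical placeholders; params and vggish_input are the
# repo's own modules imported above.
if __name__ == '__main__':
    videos_and_labels = [
        ('clip_smile.mp4', 'positive'),   # hypothetical example files
        ('clip_frown.mp4', 'negative'),
    ]
    builder = BuildDataset(base_path='data/raw',           # assumed input folder
                           videos_and_labels=videos_and_labels,
                           output_path='data/processed',   # assumed output folder
                           test_size=1 / 3)
    builder.build_dataset()
    # Each output pickle holds (audio, images, label): two aligned lists of
    # torch tensors plus the label string, ready to be loaded by a Dataset class.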