pipeline.py

import math
import os
import pickle
import shutil
from typing import List, Tuple

import cv2
import numpy as np
import torch
from PIL import Image
from moviepy.editor import VideoFileClip
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
from torchvision import transforms

import params
import vggish_input

# VGGish consumes audio in ~0.96 s examples, so stills are sampled on the same grid
VGGISH_FRAME_RATE = 0.96


def slice_clips(segments, root, fps=2):
    for path, classes in segments.items():
        for cls, ts in classes.items():
            for i, (t1, t2) in enumerate(ts):
                # randomly assign each clip to train or val (2:1 split)
                set_ = np.random.choice(['train', 'val'], p=[2 / 3, 1 / 3])
                # cut the [t1, t2] segment out of the source video
                file_name, ext = path.rsplit('.', 1)
                target = f"{root}{file_name}_{cls}_{i + 1}.{ext}"
                print(f'target: {target}')
                ffmpeg_extract_subclip(f'{root}{path}', t1, t2, targetname=target)
                # get all the still frames
                vidcap = cv2.VideoCapture(target)
                # note: many OpenCV backends ignore CAP_PROP_FPS on capture,
                # so frames may still be read at the clip's native rate
                vidcap.set(cv2.CAP_PROP_FPS, fps)
                print(vidcap.get(cv2.CAP_PROP_FPS))
                success, image = vidcap.read()
                count = 0
                while success:
                    frame_path = f'{root}casino/{set_}/{cls}/{file_name}_{i}_{count + 1}.jpg'
                    cv2.imwrite(frame_path, image)  # save frame as JPEG file
                    success, image = vidcap.read()
                    count += 1
                vidcap.release()
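
# Example usage of slice_clips (a sketch: the file name, class names, and
# timestamps below are hypothetical, and `root` is assumed to end in a slash;
# the casino/{train,val}/{class} output directories must already exist):
#
#     segments = {
#         'table1.mp4': {
#             'blackjack': [(10.0, 25.0), (60.0, 90.0)],
#             'roulette': [(120.0, 150.0)],
#         },
#     }
#     slice_clips(segments, root='/data/videos/', fps=2)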


class BuildDataset:
    def __init__(self,
                 base_path: str,
                 videos_and_labels: List[Tuple[str, str]],
                 output_path: str,
                 n_augment: int = 1,
                 test_size: float = 1 / 3):
        assert 0 < test_size < 1
        self.base_path = base_path
        self.videos_and_labels = videos_and_labels
        self.output_path = output_path
        self.n_augment = n_augment
        self.test_size = test_size
        self.sets = ['train', 'val']

    def _get_set(self):
        # assign a sample to train or val at random, holding out test_size for val
        return np.random.choice(self.sets, p=[1 - self.test_size, self.test_size])

    def build_dataset(self):
        # wipe any previous output and recreate the train/val directories
        for set_ in self.sets:
            path = f'{self.output_path}/{set_}'
            try:
                shutil.rmtree(path)
            except FileNotFoundError:
                pass
            os.makedirs(path)
        # extract one (audio, images, label) pickle per source video
        for file_name, label in self.videos_and_labels:
            name, _ = file_name.rsplit('.', 1)
            path = f'{self.base_path}/{file_name}'
            audio, images = self.one_video_extract_audio_and_stills(path)
            set_ = self._get_set()
            target = f"{self.output_path}/{set_}/{label}_{name}.pkl"
            with open(target, 'wb') as f:
                pickle.dump((audio, images, label), f)

    @staticmethod
    def transform_reverse(img: torch.Tensor) -> Image.Image:
        # undo Normalize(params.mean, params.std): multiply by std, then add
        # back the mean, and convert the tensor to a viewable PIL image
        return transforms.Compose([
            transforms.Normalize(mean=[0, 0, 0], std=(1.0 / params.std).tolist()),
            transforms.Normalize(mean=(-params.mean).tolist(), std=[1, 1, 1]),
            transforms.ToPILImage()])(img)
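
    # e.g. given a normalized tensor `x` from the transformer below,
    # BuildDataset.transform_reverse(x).show() should display the original crop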

    @staticmethod
    def one_video_extract_audio_and_stills(path_video: str,
                                           img_size: int = 224) -> Tuple[List[torch.Tensor],
                                                                         List[torch.Tensor]]:
        # return aligned lists of audio and image tensors for one video
        cap = cv2.VideoCapture(path_video)
        frame_rate = cap.get(cv2.CAP_PROP_FPS)
        images = []
        transformer = transforms.Compose([
            transforms.RandomResizedCrop(img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(params.mean, params.std)
        ])
        # process the images: keep one still per ~0.96 s VGGish window
        while cap.isOpened():
            frame_id = cap.get(cv2.CAP_PROP_POS_FRAMES)
            success, frame = cap.read()
            if not success:
                break  # end of stream (or a read error)
            if frame_id % math.floor(frame_rate * VGGISH_FRAME_RATE) == 0:
                # OpenCV reads BGR; convert to RGB before building the PIL image
                frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                images.append(transformer(frame_pil))
                # images += [transformer(frame_pil) for _ in range(self.n_augment)]
        cap.release()
        # process the audio
        # TODO: hack to get around OpenMP error
        os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
        tmp_audio_file = 'tmp.wav'
        VideoFileClip(path_video).audio.write_audiofile(tmp_audio_file)
        # TODO: fix if n_augment > 1 by duplicating each sample n_augment times
        audio = vggish_input.wavfile_to_examples(tmp_audio_file)
        # truncate both modalities to the same length, then turn each audio
        # example into a float tensor with a dummy "channel" dimension
        min_size = min(audio.shape[0], len(images))
        audio = [torch.from_numpy(audio[idx][None, :, :]).float() for idx in range(min_size)]
        images = images[:min_size]
        return audio, images
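

# Minimal usage sketch: the paths, file names, and labels below are
# hypothetical placeholders, not files that ship with this repo.
if __name__ == '__main__':
    builder = BuildDataset(
        base_path='/data/videos',
        videos_and_labels=[('table1.mp4', 'blackjack'),
                           ('table2.mp4', 'roulette')],
        output_path='/data/dataset',
    )
    builder.build_dataset()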