# pipeline.py

import math
import os
import pickle
import shutil
from typing import List, Tuple

import cv2
import numpy as np
import torch
from moviepy.editor import VideoFileClip
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
from torchvision import transforms
from PIL import Image

import vggish_input

# Each VGGish audio example covers 0.96 s, so still frames are sampled on the same grid.
VGGISH_FRAME_RATE = 0.96


def slice_clips(segments, root, fps=2):
    """Cut labeled sub-clips out of each source video and dump their still frames
    into train/val class folders under ``{root}casino/``."""
    for path, classes in segments.items():
        for cls, ts in classes.items():
            for i, (t1, t2) in enumerate(ts):
                set_ = np.random.choice(['train', 'val'], p=[2 / 3, 1 / 3])
                # cut the sub-clip out of the source video
                file_name, ext = path.rsplit('.', 1)
                target = f"{root}{file_name}_{cls}_{i + 1}.{ext}"
                print(f'target: {target}')
                ffmpeg_extract_subclip(f'{root}{path}', t1, t2, targetname=target)
                # grab all the still frames from the sub-clip
                vidcap = cv2.VideoCapture(target)
                vidcap.set(cv2.CAP_PROP_FPS, fps)  # request a lower rate; file backends may ignore this
                print(f'fps: {vidcap.get(cv2.CAP_PROP_FPS)}')
                success, image = vidcap.read()
                count = 0
                while success:
                    frame_path = f'{root}casino/{set_}/{cls}/{file_name}_{i}_{count + 1}.jpg'
                    cv2.imwrite(frame_path, image)  # save frame as JPEG file
                    success, image = vidcap.read()
                    count += 1
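
# A minimal usage sketch for slice_clips. The paths and timestamps below are
# hypothetical, and it assumes the {root}casino/{train,val}/{class} folders
# already exist (cv2.imwrite does not create directories):
#
#     segments = {
#         'table_01.mp4': {
#             'blackjack': [(12.0, 19.5), (40.0, 47.0)],
#             'roulette': [(62.0, 70.0)],
#         },
#     }
#     slice_clips(segments, root='/data/videos/', fps=2)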


class BuildDataset:
    def __init__(self,
                 base_path: str,
                 videos_and_labels: List[Tuple[str, str]],
                 output_path: str,
                 n_augment: int = 1,
                 test_size: float = 1 / 3):
        assert 0 < test_size < 1
        self.videos_and_labels = videos_and_labels
        self.test_size = test_size
        self.output_path = output_path
        self.base_path = base_path
        self.n_augment = n_augment
        self.sets = ['train', 'val']
        self.img_size = 224
        self.transformer = transforms.Compose([
            transforms.RandomResizedCrop(self.img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # ImageNet channel means/stds expected by torchvision's pretrained models
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def _get_set(self):
        # randomly assign a sample to train or val according to test_size
        return np.random.choice(self.sets, p=[1 - self.test_size, self.test_size])

    def build_dataset(self):
        """Recreate the train/val output folders, then pickle one
        (audio, images, label) tuple per input video."""
        # wipe any previous output
        for set_ in self.sets:
            path = f'{self.output_path}/{set_}'
            try:
                shutil.rmtree(path)
            except FileNotFoundError:
                pass
            os.makedirs(path)

        for file_name, label in self.videos_and_labels:
            name, _ = file_name.rsplit('.', 1)
            path = f'{self.base_path}/{file_name}'
            audio, images = self.one_video_extract_audio_and_stills(path)
            set_ = self._get_set()
            target = f"{self.output_path}/{set_}/{label}_{name}.pkl"
            with open(target, 'wb') as f:
                pickle.dump((audio, images, label), f)

    def one_video_extract_audio_and_stills(self, path_video: str) -> Tuple[List[torch.Tensor],
                                                                            List[torch.Tensor]]:
        """Return (audio, images): lists of audio example tensors and still-frame tensors."""
        cap = cv2.VideoCapture(path_video)
        frame_rate = cap.get(cv2.CAP_PROP_FPS)
        # sample one still per VGGish example (~0.96 s); guard against a zero frame rate
        sample_every = max(1, math.floor(frame_rate * VGGISH_FRAME_RATE))
        images = []

        # process the images
        while cap.isOpened():
            frame_id = cap.get(cv2.CAP_PROP_POS_FRAMES)
            success, frame = cap.read()
            if not success:
                # end of stream (or a read error)
                break
            if frame_id % sample_every == 0:
                # OpenCV frames are BGR; convert before handing them to PIL
                frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                images += [self.transformer(frame_pil) for _ in range(self.n_augment)]
        cap.release()

        # process the audio
        # TODO: hack to get around OpenMP error
        os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
        tmp_audio_file = 'tmp.wav'
        VideoFileClip(path_video).audio.write_audiofile(tmp_audio_file)
        # TODO: fix if n_augment > 1 by duplicating each sample n_augment times
        audio = vggish_input.wavfile_to_examples(tmp_audio_file)
        os.remove(tmp_audio_file)  # clean up the temporary wav
        # truncate to the shorter of the two streams, add a channel dimension to each
        # audio example, and convert to float tensors (images are already tensors)
        min_sizes = min(audio.shape[0], len(images))
        audio = [torch.from_numpy(audio[idx][None, :, :]).float() for idx in range(min_sizes)]
        images = images[:min_sizes]
        return audio, images
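
# A minimal usage sketch for BuildDataset. File names and labels here are
# hypothetical; the vggish_input import assumes the module from the TensorFlow
# models AudioSet/VGGish code is on the path, and moviepy needs a working ffmpeg:
#
#     builder = BuildDataset(
#         base_path='/data/videos',
#         videos_and_labels=[('table_01.mp4', 'blackjack'),
#                            ('table_02.mp4', 'roulette')],
#         output_path='/data/dataset',
#         n_augment=1,
#     )
#     builder.build_dataset()
#
# Each resulting pickle holds (audio, images, label): lists of 1x96x64 audio
# tensors and 3x224x224 image tensors, roughly one pair per 0.96 s of video.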