pipeline.py

import math
import os
import pickle
import shutil
from typing import List, Tuple

import cv2
import numpy as np
import torch
from moviepy.editor import VideoFileClip
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip

import vggish_input

# VGGish produces one log-mel example per 0.96 s of audio; stills are sampled
# at the same period so the two modalities stay aligned.
VGGISH_FRAME_RATE = 0.96


def slice_clips(segments, root, fps=2):
    for path, classes in segments.items():
        for cls, ts in classes.items():
            for i, (t1, t2) in enumerate(ts):
                # randomly assign this clip to the train or validation split
                set_ = np.random.choice(['train', 'val'], p=[2 / 3, 1 / 3])
                # cut the (t1, t2) subclip out of the source video
                file_name, ext = path.split('.')
                target = f"{root}{file_name}_{cls}_{i + 1}.{ext}"
                print(f'target: {target}')
                ffmpeg_extract_subclip(f'{root}{path}', t1, t2, targetname=target)
                # read the subclip back and dump its still frames
                vidcap = cv2.VideoCapture(target)
                vidcap.set(cv2.CAP_PROP_FPS, fps)  # note: often ignored for video files
                print(vidcap.get(cv2.CAP_PROP_FPS))
                success, image = vidcap.read()
                count = 0
                while success:
                    frame_path = f'{root}casino/{set_}/{cls}/{file_name}_{i}_{count + 1}.jpg'
                    cv2.imwrite(frame_path, image)  # save frame as JPEG file
                    success, image = vidcap.read()
                    count += 1
                vidcap.release()
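

# A minimal usage sketch for slice_clips. The file names, class names, and
# timestamps below are hypothetical: `segments` maps each source video to
# per-class lists of (start, end) pairs in seconds, and `root` is the
# directory holding the videos (it must end with a path separator).
def _demo_slice_clips():
    segments = {
        'poker.mp4': {
            'dealing': [(0.0, 4.5), (61.0, 66.2)],
            'betting': [(10.0, 15.0)],
        },
    }
    slice_clips(segments, root='/data/videos/', fps=2)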


class BuildDataset:
    """Extract per-video audio and still tensors and pickle them into train/val splits."""

    def __init__(self,
                 base_path: str,
                 videos_and_labels: List[Tuple[str, str]],
                 output_path: str,
                 test_size: float = 1 / 3):
        assert 0 < test_size < 1
        self.videos_and_labels = videos_and_labels
        self.test_size = test_size
        self.output_path = output_path
        self.base_path = base_path
        self.sets = ['train', 'val']

    def _get_set(self):
        # sample 'train' or 'val' with probabilities (1 - test_size, test_size)
        return np.random.choice(self.sets, p=[1 - self.test_size, self.test_size])

    def build_dataset(self):
        # wipe any previous output and recreate the split directories
        for set_ in self.sets:
            path = f'{self.output_path}/{set_}'
            try:
                shutil.rmtree(path)
            except FileNotFoundError:
                pass
            os.makedirs(path)

        for file_name, label in self.videos_and_labels:
            name, _ = file_name.split('.')
            path = f'{self.base_path}/{file_name}'
            audio, images = self.one_video_extract_audio_and_stills(path)
            set_ = self._get_set()
            target = f"{self.output_path}/{set_}/{label}_{name}.pkl"
            with open(target, 'wb') as f:
                pickle.dump((audio, images, label), f)
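
    # A minimal sketch of reading one sample back (the path argument is
    # hypothetical): each pickle written by build_dataset holds the tuple
    # (audio, images, label).
    @staticmethod
    def _demo_load_sample(path_pickle: str):
        with open(path_pickle, 'rb') as f:
            audio, images, label = pickle.load(f)
        print(label, len(audio), len(images), audio[0].shape, images[0].shape)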

    @staticmethod
    def one_video_extract_audio_and_stills(path_video: str) -> Tuple[List[torch.Tensor],
                                                                     List[torch.Tensor]]:
        # return a list of audio tensors and a list of image tensors for one video
        cap = cv2.VideoCapture(path_video)
        frame_rate = cap.get(cv2.CAP_PROP_FPS)
        images = []

        # grab one still per VGGish window, i.e. every frame_rate * 0.96 frames
        while cap.isOpened():
            frame_id = cap.get(cv2.CAP_PROP_POS_FRAMES)
            success, frame = cap.read()
            if not success:
                break  # end of stream (or a read error)
            if frame_id % math.floor(frame_rate * VGGISH_FRAME_RATE) == 0:
                images.append(frame)
        cap.release()

        # process the audio
        os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'  # TODO: hack to get around OpenMP error
        tmp_audio_file = 'tmp.wav'
        VideoFileClip(path_video).audio.write_audiofile(tmp_audio_file)
        audio = vggish_input.wavfile_to_examples(tmp_audio_file)
        os.remove(tmp_audio_file)

        # truncate to the shorter modality so audio and stills stay paired,
        # add a dummy "channel" dimension to each audio example and convert it
        # to float, and turn each HWC frame into a CWH tensor
        min_sizes = min(audio.shape[0], len(images))
        audio = [torch.from_numpy(audio[idx][None, :, :]).float() for idx in range(min_sizes)]
        images = [torch.from_numpy(img).permute((2, 1, 0)) for img in images[:min_sizes]]
        return audio, images
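

if __name__ == '__main__':
    # Hypothetical end-to-end usage: the paths, file names, and labels are
    # assumptions for illustration, not the original configuration.
    builder = BuildDataset(
        base_path='/data/videos',
        videos_and_labels=[('poker1.mp4', 'dealing'), ('poker2.mp4', 'betting')],
        output_path='/data/casino_dataset',
        test_size=1 / 3,
    )
    builder.build_dataset()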