
more cleanup

Amir Ziai 4 years ago
parent commit ac2fe34279
6 changed files with 112 additions and 166 deletions
  1. README.md (+26 -1)
  2. data.py (+0 -135)
  3. examples/detector.ipynb (+82 -0)
  4. kissing_detector.py (+1 -1)
  5. params.py (+1 -0)
  6. pipeline.py (+2 -29)

+ 26 - 1
README.md

@@ -9,6 +9,9 @@ Use Python 3.6+
 python3 experiments.py
 ```
 
+## Requirements
+This is a PyTorch project. See `requirements.txt` for the full list of dependencies.
+
 This runs the experiments specified by the `experiments` dictionary in `params.py`.
 
 ## Build dataset
@@ -40,6 +43,28 @@ builder = BuildDataset(base_path='path/to/movies',
 builder.build_dataset()
 ```
 
+## Detect kissing segments in a given video
+```python
+from segmentor import Segmentor
+import utils
+
+# download model.pkl from https://drive.google.com/file/d/1RlvvdInTXtJikGv_ZbHcKoblCypN1Z0A/view?usp=sharing
+# or train your own
+model = utils.unpickle('model.pkl')  # pickled PyTorch model 
+s = Segmentor(model, min_frames=10, threshold=0.7)
+
+# For YouTube clip Hot Summer Nights - Kiss Scene (Maika Monroe and Timothee Chalamet)
+# at https://www.youtube.com/watch?v=GG5HmLQ_Fx0
+# the v= query parameter in the URL is the YouTube ID; pass it here
+s.visualize_segments_youtube('GG5HmLQ_Fx0')
+
+# alternatively you can provide a path to a local mp4 file
+s.visualize_segments('path/to/file.mp4')
+```
+
+See examples in [examples/detector.ipynb](examples/detector.ipynb).
+
 ## Heavily used resources
 - [Video Classification Using 3D ResNet](https://github.com/kenshohara/video-classification-3d-cnn-pytorch)
-- [CS231N assignment 3](http://cs231n.github.io/assignments2019/assignment3/)
+- [AudioSet](https://research.google.com/audioset/download.html)
+- [CS231N Saliency maps and class viz PyTorch code](http://cs231n.github.io/assignments2019/assignment3/)
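
Note on the `Segmentor` example added above: `min_frames` and `threshold` are not documented in this diff, so the following is only a plausible reading, not the actual implementation. The likely idea is that the model scores frames, and runs of consecutive frames scoring at or above `threshold` that last at least `min_frames` frames are reported as kissing segments. A minimal sketch under that assumption (`merge_segments` is a hypothetical helper, not part of the repo):

```python
from typing import List, Tuple

def merge_segments(probs: List[float],
                   threshold: float = 0.7,
                   min_frames: int = 10) -> List[Tuple[int, int]]:
    """Collapse per-frame kissing probabilities into (start, end) runs."""
    segments, start = [], None
    for i, p in enumerate(probs):
        if p >= threshold and start is None:
            start = i                       # a positive run begins
        elif p < threshold and start is not None:
            if i - start >= min_frames:     # drop runs shorter than min_frames
                segments.append((start, i))
            start = None
    if start is not None and len(probs) - start >= min_frames:
        segments.append((start, len(probs)))
    return segments
```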

+ 0 - 135
data.py

@@ -1,14 +1,9 @@
-import copy
-import functools
-import json
-import os
 import pickle
 from glob import glob
 from typing import Tuple, List
 
 import torch
 import torch.utils.data as data
-from PIL import Image
 
 
 class AV(data.Dataset):
@@ -71,133 +66,3 @@ class AudioVideo3D(AV):
             ee = e.permute((1, 0, 2, 3))
             out.append(ee)
         return out
-
-
-def pil_loader(path):
-    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
-    with open(path, 'rb') as f:
-        with Image.open(f) as img:
-            return img.convert('RGB')
-
-
-def accimage_loader(path):
-    # try:
-    #     return accimage.Image(path)
-    # except IOError:
-    #     # Potentially a decoding problem, fall back to PIL.Image
-    #     return pil_loader(path)
-    return pil_loader(path)
-
-
-def get_default_image_loader():
-    from torchvision import get_image_backend
-    if get_image_backend() == 'accimage':
-        return accimage_loader
-    else:
-        return pil_loader
-
-
-def video_loader(video_dir_path, frame_indices, image_loader):
-    video = []
-    for i in frame_indices:
-        image_path = os.path.join(video_dir_path, 'image_{:05d}.jpg'.format(i))
-        if os.path.exists(image_path):
-            video.append(image_loader(image_path))
-        else:
-            return video
-
-    return video
-
-
-def get_default_video_loader():
-    image_loader = get_default_image_loader()
-    return functools.partial(video_loader, image_loader=image_loader)
-
-
-def load_annotation_data(data_file_path):
-    with open(data_file_path, 'r') as data_file:
-        return json.load(data_file)
-
-
-def get_class_labels(data):
-    class_labels_map = {}
-    index = 0
-    for class_label in data['labels']:
-        class_labels_map[class_label] = index
-        index += 1
-    return class_labels_map
-
-
-def get_video_names_and_annotations(data, subset):
-    video_names = []
-    annotations = []
-
-    for key, value in data['database'].items():
-        this_subset = value['subset']
-        if this_subset == subset:
-            if subset == 'testing':
-                video_names.append('test/{}'.format(key))
-            else:
-                label = value['annotations']['label']
-                video_names.append('{}/{}'.format(label, key))
-                annotations.append(value['annotations'])
-
-    return video_names, annotations
-
-
-def make_dataset(video_path, sample_duration):
-    dataset = []
-
-    n_frames = len(os.listdir(video_path))
-
-    begin_t = 1
-    end_t = n_frames
-    sample = {
-        'video': video_path,
-        'segment': [begin_t, end_t],
-        'n_frames': n_frames,
-    }
-
-    step = sample_duration
-    for i in range(1, (n_frames - sample_duration + 1), step):
-        sample_i = copy.deepcopy(sample)
-        sample_i['frame_indices'] = list(range(i, i + sample_duration))
-        sample_i['segment'] = torch.IntTensor([i, i + sample_duration - 1])
-        dataset.append(sample_i)
-
-    return dataset
-
-
-class Video(data.Dataset):
-    def __init__(self, video_path,
-                 spatial_transform=None, temporal_transform=None,
-                 sample_duration=16, get_loader=get_default_video_loader):
-        self.data = make_dataset(video_path, sample_duration)
-
-        self.spatial_transform = spatial_transform
-        self.temporal_transform = temporal_transform
-        self.loader = get_loader()
-
-    def __getitem__(self, index):
-        """
-        Args:
-            index (int): Index
-        Returns:
-            tuple: (image, target) where target is class_index of the target class.
-        """
-        path = self.data[index]['video']
-
-        frame_indices = self.data[index]['frame_indices']
-        if self.temporal_transform is not None:
-            frame_indices = self.temporal_transform(frame_indices)
-        clip = self.loader(path, frame_indices)
-        if self.spatial_transform is not None:
-            clip = [self.spatial_transform(img) for img in clip]
-        clip = torch.stack(clip, 0).permute(1, 0, 2, 3)
-
-        target = self.data[index]['segment']
-
-        return clip, target
-
-    def __len__(self):
-        return len(self.data)
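
The `permute((1, 0, 2, 3))` kept in `AudioVideo3D` above reorders each clip from `(T, C, H, W)` — a stack of per-frame image tensors — into the `(C, T, H, W)` layout that PyTorch 3D convolutions expect. A minimal illustration, with shapes chosen to match the 224-pixel, 16-frame settings used elsewhere in this commit:

```python
import torch

clip = torch.randn(16, 3, 224, 224)    # (T, C, H, W): 16 stacked RGB frames
clip_3d = clip.permute(1, 0, 2, 3)     # (C, T, H, W): channels before time
assert clip_3d.shape == (3, 16, 224, 224)

batch = clip_3d.unsqueeze(0)           # (N, C, T, H, W) input for nn.Conv3d
```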

File diff suppressed because it is too large
+ 82 - 0
examples/detector.ipynb


+ 1 - 1
kissing_detector.py

@@ -65,7 +65,7 @@ class KissingDetector3DConv(nn.Module):
             num_classes=num_classes,
             shortcut_type='B',
             sample_size=224,
-            sample_duration=10
+            sample_duration=16
         )
         set_parameter_requires_grad(conv, feature_extract)
         conv.fc = nn.Identity()
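
The `sample_duration` bump from 10 to 16 brings the 3D ResNet in line with the 16-frame clips assumed elsewhere (the `Video` dataset deleted from `data.py` above also defaulted to `sample_duration=16`, and the pretrained 3D ResNets from the linked video-classification repo are typically trained on 16-frame inputs). A quick shape check under those settings:

```python
import torch

SAMPLE_DURATION = 16   # frames per clip, matching the change above
SAMPLE_SIZE = 224      # spatial side, matching sample_size=224

# One clip in (N, C, T, H, W) layout; with conv.fc = nn.Identity(),
# conv(dummy) would return features rather than class scores.
dummy = torch.randn(1, 3, SAMPLE_DURATION, SAMPLE_SIZE, SAMPLE_SIZE)
```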

+ 1 - 0
params.py

@@ -7,6 +7,7 @@ data_path_base = 'vtest_new2'
 
 mean = np.array([0.485, 0.456, 0.406])
 std = np.array([0.229, 0.224, 0.225])
+vggish_frame_rate = 0.96
 
 # test end-to-end
 experiment_test = {
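
For context on the new constant: VGGish (the AudioSet model referenced in the README) produces one embedding per 0.96-second log-mel example, which is why 0.96 acts as the audio "frame rate" here. Roughly:

```python
vggish_frame_rate = 0.96  # seconds of audio per VGGish example

def n_audio_examples(duration_s: float) -> int:
    # Approximate number of VGGish embeddings for a clip of this length.
    return int(duration_s / vggish_frame_rate)

n_audio_examples(60.0)  # -> 62 embeddings for a one-minute clip
```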

+ 2 - 29
pipeline.py

@@ -9,39 +9,12 @@ import numpy as np
 import torch
 from PIL import Image
 from moviepy.editor import VideoFileClip
-from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
+
 from torchvision import transforms
 
 import params
 import vggish_input
 
-VGGISH_FRAME_RATE = 0.96
-
-
-def slice_clips(segments, root, fps=2):
-    for path, classes in segments.items():
-
-        for cls, ts in classes.items():
-            for i, (t1, t2) in enumerate(ts):
-                set_ = np.random.choice(['train', 'val'], p=[2 / 3, 1 / 3])
-                # get all the still frames
-                file_name, ext = path.split('.')
-                target = f"{root}{file_name}_{cls}_{i + 1}.{ext}"
-                print(f'target: {target}')
-                ffmpeg_extract_subclip(f'{root}{path}', t1, t2, targetname=target)
-                vidcap = cv2.VideoCapture(target)
-                vidcap.set(cv2.CAP_PROP_FPS, fps)
-                print(cv2.CAP_PROP_FPS)
-                success, image = vidcap.read()
-                count = 0
-                while success:
-                    frame_path = f'{root}casino/{set_}/{cls}/{file_name}_{i}_{count + 1}.jpg'
-                    # print(frame_path)
-                    cv2.imwrite(frame_path, image)  # save frame as JPEG file
-                    success, image = vidcap.read()
-                    # print('Read a new frame: ', success)
-                    count += 1
-
 
 class BuildDataset:
     def __init__(self,
@@ -117,7 +90,7 @@ class BuildDataset:
                 print('Something went wrong!')
                 break
 
-            if frame_id % math.floor(frame_rate * VGGISH_FRAME_RATE) == 0:
+            if frame_id % math.floor(frame_rate * params.vggish_frame_rate) == 0:
                 frame_pil = Image.fromarray(frame, mode='RGB')
                 images.append(transformer(frame_pil))
                 # images += [transformer(frame_pil) for _ in range(self.n_augment)]

Some files were not shown because too many files changed in this diff