Browse Source

3d conv seems to be running

Amir Ziai 5 years ago
parent
commit
cd87f34290
4 changed files with 89 additions and 30 deletions
  1. 63 23
      data.py
  2. 2 1
      kissing_detector.py
  3. 13 1
      params.py
  4. 11 5
      train.py

+ 63 - 23
data.py

@@ -4,13 +4,75 @@ import json
 import os
 import pickle
 from glob import glob
-from typing import Tuple
+from typing import Tuple, List
 
 import torch
 import torch.utils.data as data
 from PIL import Image
 
 
+class AV(data.Dataset):
+    def __init__(self, path: str):
+        self.path = path
+        self.data = []
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
+        return self.data[idx]
+
+
+class AudioVideo(AV):
+    def __init__(self, path: str):
+        # output format:
+        # return (
+        #     torch.rand((1, 96, 64)),
+        #     torch.rand((3, 224, 224)),
+        #     np.random.choice([0, 1])
+        # )
+        super().__init__(path)
+
+        for file_path in glob(f'{path}/*.pkl'):
+            audios, images, label = pickle.load(open(file_path, 'rb'))
+            self.data += [(audios[i], images[i], label) for i in range(len(audios))]
+
+
+class AudioVideo3D(AV):
+    def __init__(self, path: str):
+        # output format:
+        # return (
+        #     torch.rand((1, 96, 64)),
+        #     torch.rand((3, 16, 224, 224)),
+        #     np.random.choice([0, 1])
+        # )
+        super().__init__(path)
+        frames = 16
+
+        for file_path in glob(f'{path}/*.pkl'):
+            audios, images, label = pickle.load(open(file_path, 'rb'))
+            images_temporal = self._process_temporal_tensor(images, frames)
+            self.data += [(audios[i], images_temporal[i], label) for i in range(len(audios))]
+
+    @staticmethod
+    def _process_temporal_tensor(images: List[torch.Tensor],
+                                 frames: int) -> List[torch.Tensor]:
+        out = []
+
+        for i in range(len(images)):
+            e = torch.zeros((frames, 3, 224, 224))
+            e[-1] = images[0]
+            for j in range(min(i, frames)):
+                e[-1 - j] = images[j]
+                # try:
+                #     e[-1 - j] = images[j]
+                # except:
+                #     raise ValueError(f"trying to get {i} from images with len = {len(images)}")
+            ee = e.permute((1, 0, 2, 3))
+            out.append(ee)
+        return out
+
+
 def pil_loader(path):
     # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
     with open(path, 'rb') as f:
@@ -139,25 +201,3 @@ class Video(data.Dataset):
 
     def __len__(self):
         return len(self.data)
-
-
-class AudioVideo(data.Dataset):
-    def __init__(self, path: str):
-        self.path = path
-        self.data = []
-
-        for file_path in glob(f'{path}/*.pkl'):
-            audios, images, label = pickle.load(open(file_path, 'rb'))
-            self.data += [(audios[i], images[i], label) for i in range(len(audios))]
-
-    def __len__(self):
-        return len(self.data)
-
-    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
-        # output format:
-        # return (
-        #     torch.rand((1, 96, 64)),
-        #     torch.rand((3, 224, 224)),
-        #     np.random.choice([0, 1])
-        # )
-        return self.data[idx]

+ 2 - 1
kissing_detector.py

@@ -4,7 +4,7 @@ import torch
 from torch import nn
 
 import vggish
-from conv import convnet_init
+from conv import convnet_init, set_parameter_requires_grad
 import conv3d
 
 
@@ -67,6 +67,7 @@ class KissingDetector3DConv(nn.Module):
             sample_size=224,
             sample_duration=10
         )
+        set_parameter_requires_grad(conv, feature_extract)
         conv.fc = nn.Identity()
 
         if use_vggish:

+ 13 - 1
params.py

@@ -16,10 +16,22 @@ experiment_test = {
     'feature_extract': {True},
     'batch_size': {64},
     'lr': {0.001},
-    'use_vggish': {False},
+    'use_vggish': {True},
     'momentum': {0.9}
 }
 
+experiment_test_3d = {
+    'data_path_base': {data_path_base},
+    'conv_model_name': {'resnet'},
+    'num_epochs': {10},
+    'feature_extract': {True},
+    'batch_size': {64},
+    'lr': {0.001},
+    'use_vggish': {True},
+    'momentum': {0.9},
+    'use_3d': {True}
+}
+
 experiments = {
     'data_path_base': {data_path_base},
     'conv_model_name': {'resnet', None},  # vgg

+ 11 - 5
train.py

@@ -6,8 +6,8 @@ import torch
 import torch.optim as optim
 from torch import nn
 
-from data import AudioVideo
-from kissing_detector import KissingDetector
+from data import AudioVideo, AudioVideo3D
+from kissing_detector import KissingDetector, KissingDetector3DConv
 
 ExperimentResults = Tuple[Optional[nn.Module], List[float], List[float]]
 
@@ -36,17 +36,23 @@ def train_kd(data_path_base: str,
              num_workers: int = 4,
              shuffle: bool = True,
              lr: float = 0.001,
-             momentum: float = 0.9) -> ExperimentResults:
+             momentum: float = 0.9,
+             use_3d: bool = False) -> ExperimentResults:
     num_classes = 2
     try:
-        kd = KissingDetector(conv_model_name, num_classes, feature_extract, use_vggish=use_vggish)
+        if use_3d:
+            kd = KissingDetector3DConv(num_classes, feature_extract, use_vggish)
+        else:
+            kd = KissingDetector(conv_model_name, num_classes, feature_extract, use_vggish=use_vggish)
     except ValueError:
         # if the combination is not valid
         return None, [-1.0], [-1.0]
 
     params_to_update = _get_params_to_update(kd, feature_extract)
 
-    datasets = {set_: AudioVideo(f'{data_path_base}/{set_}') for set_ in ['train', 'val']}
+    av = AudioVideo3D if use_3d else AudioVideo
+
+    datasets = {set_: av(f'{data_path_base}/{set_}') for set_ in ['train', 'val']}
     dataloaders_dict = {x: torch.utils.data.DataLoader(datasets[x],
                                                        batch_size=batch_size,
                                                        shuffle=shuffle, num_workers=num_workers)