
runs, added f1

Amir Ziai committed 5 years ago
commit f0c0c2bfa6
3 changed files with 817 additions and 12 deletions
  1. dev3.ipynb (+773 -0)
  2. pipeline.py (+18 -4)
  3. train.py (+26 -8)

+ 773 - 0
dev3.ipynb
(file diff suppressed because it is too large)


+ 18 - 4
pipeline.py

@@ -9,6 +9,8 @@ import numpy as np
 import torch
 from moviepy.editor import VideoFileClip
 from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
+from torchvision import transforms
+from PIL import Image
 
 import vggish_input
 
@@ -45,14 +47,25 @@ class BuildDataset:
                  base_path: str,
                  videos_and_labels: List[Tuple[str, str]],
                  output_path: str,
+                 n_augment: int=1,
                  test_size: float = 1 / 3):
         assert 0 < test_size < 1
         self.videos_and_labels = videos_and_labels
         self.test_size = test_size
         self.output_path = output_path
         self.base_path = base_path
+        self.n_augment = n_augment
 
         self.sets = ['train', 'val']
+        self.img_size = 224
+
+        self.transformer = transforms.Compose([
+            transforms.RandomResizedCrop(self.img_size),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            # ImageNet mean/std expected by pretrained torchvision models
+            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+        ])
 
     def _get_set(self):
         return np.random.choice(self.sets, p=[1 - self.test_size, self.test_size])
@@ -75,8 +88,7 @@ class BuildDataset:
             target = f"{self.output_path}/{set_}/{label}_{name}.pkl"
             pickle.dump((audio, images, label), open(target, 'wb'))
 
-    @staticmethod
-    def one_video_extract_audio_and_stills(path_video: str) -> Tuple[List[torch.Tensor],
+    def one_video_extract_audio_and_stills(self, path_video: str) -> Tuple[List[torch.Tensor],
                                                                      List[torch.Tensor]]:
         # return a list of image(s), audio tensors
         cap = cv2.VideoCapture(path_video)
@@ -93,7 +105,8 @@ class BuildDataset:
                 break
 
             if frame_id % math.floor(frame_rate * VGGISH_FRAME_RATE) == 0:
-                images.append(frame)
+                frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # cv2 frames are BGR
+                images += [self.transformer(frame_pil) for _ in range(self.n_augment)]
 
         cap.release()
 
@@ -103,12 +116,13 @@ class BuildDataset:
 
         tmp_audio_file = 'tmp.wav'
         VideoFileClip(path_video).audio.write_audiofile(tmp_audio_file)
+        # TODO: fix if n_augment > 1 by duplicating each sample n_augment times
         audio = vggish_input.wavfile_to_examples(tmp_audio_file)
         # audio = audio[:, None, :, :]  # add dummy dimension for "channel"
         # audio = torch.from_numpy(audio).float()  # Convert input example to float
 
         min_sizes = min(audio.shape[0], len(images))
         audio = [torch.from_numpy(audio[idx][None, :, :]).float() for idx in range(min_sizes)]
-        images = [torch.from_numpy(img).permute((2, 1, 0)) for img in images[:min_sizes]]
+        # images = [torch.from_numpy(img).permute((2, 0, 1)) for img in images[:min_sizes]]
 
         return audio, images
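
For reference, a minimal standalone sketch (not part of this commit) of the per-frame augmentation path added above: one OpenCV frame becomes n_augment independently augmented tensors. The helper name augment_frame is hypothetical, and the explicit BGR-to-RGB conversion is there because cv2.VideoCapture yields BGR arrays while the ImageNet normalization constants assume RGB input.

from typing import List

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

IMG_SIZE = 224

# ImageNet mean/std expected by pretrained torchvision backbones
transformer = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def augment_frame(frame_bgr: np.ndarray, n_augment: int = 1) -> List[torch.Tensor]:
    # cv2 frames are BGR; convert before building the PIL image
    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    frame_pil = Image.fromarray(frame_rgb)
    # each pass through the transformer yields a different random crop/flip
    return [transformer(frame_pil) for _ in range(n_augment)]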

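The TODO about n_augment > 1 could be handled by repeating each VGGish audio example once per augmented frame, so audio[i] keeps lining up with images[i]. A hypothetical sketch (align_audio_to_images is not a function from this repo), assuming wavfile_to_examples returns a 3-D numpy array of log-mel examples:

from typing import List

import numpy as np
import torch


def align_audio_to_images(audio_examples: np.ndarray, n_augment: int) -> List[torch.Tensor]:
    # audio_examples: (num_examples, num_frames, num_bands) log-mel patches
    aligned = []
    for example in audio_examples:
        tensor = torch.from_numpy(example[None, :, :]).float()  # add a channel dimension
        # one copy of the audio example per augmented still from the same timestamp
        aligned.extend(tensor.clone() for _ in range(n_augment))
    return aligned
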
+ 26 - 8
train.py

@@ -25,19 +25,20 @@ def _get_params_to_update(model: nn.Module,
     return params_to_update
 
 
-def train_kd(model_name: str,
+def train_kd(data_path_base: str,
+             model_name: str,
              num_epochs: int,
              feature_extract: bool,
              batch_size: int,
-             num_workers: int=4,
-             shuffle: bool=True,
-             lr: float=0.001,
-             momentum: float=0.9) -> Tuple[nn.Module, List[torch.Tensor]]:
+             num_workers: int = 4,
+             shuffle: bool = True,
+             lr: float = 0.001,
+             momentum: float = 0.9) -> Tuple[nn.Module, List[torch.Tensor]]:
     num_classes = 2
     kd = KissingDetector(model_name, num_classes, feature_extract)
     params_to_update = _get_params_to_update(kd, feature_extract)
 
-    datasets = {x: AudioVideo(x) for x in ['train', 'val']}
+    datasets = {set_: AudioVideo(f'{data_path_base}/{set_}') for set_ in ['train', 'val']}
     dataloaders_dict = {x: torch.utils.data.DataLoader(datasets[x],
                                                        batch_size=batch_size,
                                                        shuffle=shuffle, num_workers=num_workers)
@@ -60,6 +61,7 @@ def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_ince
 
     best_model_wts = copy.deepcopy(model.state_dict())
     best_acc = 0.0
+    best_f1 = 0.0
 
     # Detect if we have a GPU available
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -77,6 +79,9 @@ def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_ince
 
             running_loss = 0.0
             running_corrects = 0
+            running_tp = 0
+            running_fp = 0
+            running_fn = 0
 
             # Iterate over data.
             for a, v, labels in dataloaders[phase]:
@@ -114,15 +119,27 @@ def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_ince
                 # statistics
                 running_loss += loss.item() * a.size(0)
                 running_corrects += torch.sum(preds == labels.data)
+                running_tp += torch.sum((preds == labels.data)[labels.data == 1])
+                running_fp += torch.sum((preds != labels.data)[labels.data == 0])
+                running_fn += torch.sum((preds != labels.data)[labels.data == 1])
 
             epoch_loss = running_loss / len(dataloaders[phase].dataset)
-            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
+            n = len(dataloaders[phase].dataset)
+            epoch_acc = running_corrects.double() / n
+            tp = running_tp.double()
+            fp = running_fp.double()
+            fn = running_fn.double()
+            p = tp / (tp + fp)
+            r = tp / (tp + fn)
+            epoch_f1 = 2 * p * r / (p + r)
 
-            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
+            print('{} Loss: {:.4f} F1: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_f1, epoch_acc))
 
             # deep copy the model
             if phase == 'val' and epoch_acc > best_acc:
                 best_acc = epoch_acc
+            if phase == 'val' and epoch_f1 > best_f1:
+                best_f1 = epoch_f1
                 best_model_wts = copy.deepcopy(model.state_dict())
             if phase == 'val':
                 val_acc_history.append(epoch_acc)
@@ -131,6 +148,7 @@ def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_ince
 
     time_elapsed = time.time() - since
     print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
+    print('Best val F1 : {:4f}'.format(best_f1))
     print('Best val Acc: {:4f}'.format(best_acc))
 
     # load best model weights
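
For reference, a standalone sketch (not from this commit) of the F1 bookkeeping added above, treating class 1 as the positive class; batch_counts and f1_from_counts are illustrative names only:

from typing import Tuple

import torch


def batch_counts(preds: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # per-batch true positives, false positives, false negatives
    tp = torch.sum((preds == 1) & (labels == 1))
    fp = torch.sum((preds == 1) & (labels == 0))  # predicted positive, actually negative
    fn = torch.sum((preds == 0) & (labels == 1))  # predicted negative, actually positive
    return tp, fp, fn


def f1_from_counts(tp: torch.Tensor, fp: torch.Tensor, fn: torch.Tensor, eps: float = 1e-8) -> float:
    # F1 = 2 * precision * recall / (precision + recall), with eps guarding empty denominators
    precision = tp.double() / (tp + fp).double().clamp(min=eps)
    recall = tp.double() / (tp + fn).double().clamp(min=eps)
    return float(2 * precision * recall / (precision + recall).clamp(min=eps))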
