add experiment runner

Amir Ziai, 4 years ago
commit efcbf4099a
6 changed files with 149 additions and 4 deletions:

  1. experiments.py      (+73, -0)
  2. kissing_detector.py (+2, -2)
  3. params.py           (+27, -0)
  4. requirements.txt    (+2, -0)
  5. train.py            (+5, -2)
  6. utils.py            (+40, -0)

+ 73 - 0
experiments.py

@@ -0,0 +1,73 @@
+import os
+from datetime import datetime
+from itertools import product
+from typing import Dict, Set, List, Any, Tuple
+
+import pandas as pd
+from joblib import Parallel, delayed
+
+import params
+from train import train_kd, ExperimentResults
+from utils import log, merge_dicts, pickle_object, unpickle, hash_dict
+
+ParamSet = Dict[str, Any]
+ParamGrid = List[ParamSet]
+RunnerUUID = str
+
+
+class ExperimentRunner:
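+    """Run train_kd over the cartesian product of experiment parameter values.
+
+    Results for each parameter set are pickled under results/, keyed by a hash
+    of the parameter set, so repeated runs reuse cached results.
+    """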
+    def __init__(self, experiment_parameters: Dict[str, Set[Any]], n_jobs: int):
+        self.experiment_parameters = experiment_parameters
+        self.n_jobs = n_jobs
+        self.timestamp: str = datetime.now().strftime('%Y%m%d%H%M%S')
+
+        os.makedirs('results/', exist_ok=True)
+
+    @staticmethod
+    def _get_param_grid(parameters: Dict[str, Set[Any]]) -> ParamGrid:
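+        # Expand e.g. {'a': {1, 2}, 'b': {3}} into [{'a': 1, 'b': 3}, {'a': 2, 'b': 3}].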
+        return [dict(zip(parameters.keys(), t)) for t in product(*parameters.values())]
+
+    @staticmethod
+    def _file_path_experiment_results(runner_uuid: RunnerUUID) -> str:
+        return f'results/{runner_uuid}_experiment_results.pkl'
+
+    def _experiment_result_exists(self, runner_uuid: RunnerUUID) -> bool:
+        return os.path.isfile(self._file_path_experiment_results(runner_uuid))
+
+    def _param_run(self, param_set: ParamSet) -> Tuple[ExperimentResults, RunnerUUID]:
+        log(f'Running param set: {param_set}')
+
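+        # The cache key is a deterministic hash of the parameter set (see utils.hash_dict).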
+        uuid = hash_dict(param_set)
+
+        if self._experiment_result_exists(uuid):
+            log('Loading experiment results from cache')
+            experiment_results = unpickle(self._file_path_experiment_results(uuid))
+        else:
+            experiment_results = train_kd(**param_set)
+            pickle_object(experiment_results, self._file_path_experiment_results(uuid))
+
+        return experiment_results, uuid
+
+    @staticmethod
+    def _get_dict_from_results(results: ExperimentResults) -> Dict:
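+        # Keep the best validation accuracy and F1 observed over the training history.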
+        _, accs, f1s = results
+        return {'val_acc': max(accs), 'val_f1': max(f1s)}
+
+    def run(self):
+        param_grid = self._get_param_grid(self.experiment_parameters)
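+        # Fan out across joblib workers when n_jobs > 1; otherwise run serially.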
+        if self.n_jobs > 1:
+            run_output = Parallel(n_jobs=self.n_jobs)(delayed(self._param_run)(param) for param in param_grid)
+        else:
+            run_output = [self._param_run(param) for param in param_grid]
+        results_enriched = [
+            merge_dicts(self._get_dict_from_results(result), param_set,
+                        {'runner_uuid': runner_uuid},
+                        {'experiment_uuid': self.timestamp})
+            for (result, runner_uuid), param_set in zip(run_output, param_grid)
+        ]
+        pd.DataFrame(results_enriched).to_csv(f'results/results_{self.timestamp}.csv', index=False)
+
+
+if __name__ == '__main__':
+    experiment1 = ExperimentRunner(params.experiment1_test, n_jobs=params.n_jobs)
+    experiment1.run()

+ 2 - 2
kissing_detector.py

@@ -39,9 +39,9 @@ class KissingDetector(nn.Module):
         a = self.vggish(audio) if self.vggish else None
         c = self.conv(image) if self.conv else None
 
-        if a and c:
+        if a is not None and c is not None:
             combined = torch.cat((c.view(c.size(0), -1), a.view(a.size(0), -1)), dim=1)
         else:
-            combined = a if a else c
+            combined = a if a is not None else c
 
         return self.combined(combined)

+ 27 - 0
params.py

@@ -0,0 +1,27 @@
+seed = 0
+n_jobs = 1
+
+data_path_base = '/Users/aziai/Downloads/vtest_new2'
+
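+# Each key maps to a *set* of candidate values; ExperimentRunner takes the
+# cartesian product of these sets to build the experiment grid.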
+# test end-to-end
+experiment1_test = {
+    'data_path_base': {data_path_base},
+    'conv_model_name': {'resnet'},
+    'num_epochs': {10},
+    'feature_extract': {True},
+    'batch_size': {64},
+    'lr': {0.001},
+    'use_vggish': {False},
+    'momentum': {0.9}
+}
+
+experiment1 = {
+    'data_path_base': {data_path_base},
+    'conv_model_name': {'resnet', 'vgg'},
+    'num_epochs': {10},
+    'feature_extract': {True, False},
+    'batch_size': {64},
+    'lr': {0.001},
+    'use_vggish': {False, True},
+    'momentum': {0.9}
+}

+ 2 - 0
requirements.txt

@@ -6,3 +6,5 @@ Pillow
 numpy
 moviepy
 opencv-python
+joblib
+pandas

+ 5 - 2
train.py

@@ -9,6 +9,8 @@ from torch import nn
 from data import AudioVideo
 from kissing_detector import KissingDetector
 
+ExperimentResults = Tuple[nn.Module, List[float], List[float]]
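+# ExperimentResults bundles (trained model, validation accuracy history,
+# validation F1 history); experiments.py unpacks it in _get_dict_from_results.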
+
 
 def _get_params_to_update(model: nn.Module,
                           feature_extract: bool) -> List[nn.parameter.Parameter]:
@@ -30,12 +32,13 @@ def train_kd(data_path_base: str,
              num_epochs: int,
              feature_extract: bool,
              batch_size: int,
+             use_vggish: bool = True,
              num_workers: int = 4,
              shuffle: bool = True,
              lr: float = 0.001,
-             momentum: float = 0.9) -> Tuple[nn.Module, List[float], List[float]]:
+             momentum: float = 0.9) -> ExperimentResults:
     num_classes = 2
-    kd = KissingDetector(conv_model_name, num_classes, feature_extract)
+    kd = KissingDetector(conv_model_name, num_classes, feature_extract, use_vggish=use_vggish)
     params_to_update = _get_params_to_update(kd, feature_extract)
 
     datasets = {set_: AudioVideo(f'{data_path_base}/{set_}') for set_ in ['train', 'val']}

+ 40 - 0
utils.py

@@ -0,0 +1,40 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import hashlib
+import pickle
+import sys
+from functools import reduce
+from typing import TypeVar, List, Tuple, Dict, Any
+from uuid import UUID
+
+T = TypeVar('T')
+
+
+def unzip(xs: List[Tuple[List[T], List[T]]]) -> Tuple[List[List[T]], List[List[T]]]:
+    # zip(*xs) yields tuples; materialize the two halves as lists to match the annotation.
+    first, second = zip(*xs)
+    return list(first), list(second)
+
+
+def log(msg: str) -> None:
+    print(msg, file=sys.stderr)
+
+
+def merge_dicts(*args: Dict) -> Dict[Any, Any]:
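+    # Later arguments override earlier ones on key collisions.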
+    return reduce(lambda x, y: {**x, **y}, args)
+
+
+def uuid_to_str(uuid: UUID) -> str:
+    return str(uuid).replace('-', '')
+
+
+def hash_dict(d: Dict) -> str:
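+    # Sort keys so the hash is independent of dict insertion order.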
+    dict_str_rep = '_'.join([f'{key}_{d[key]}' for key in sorted(d.keys())])
+    return hashlib.sha224(bytearray(dict_str_rep, 'utf8')).hexdigest()
+
+
+def pickle_object(obj: object, path: str) -> None:
+    with open(path, 'wb') as f:
+        pickle.dump(obj, f)
+
+
+def unpickle(path: str) -> Any:
+    with open(path, 'rb') as f:
+        return pickle.load(f)