Browse Source

latest params

Amir Ziai 5 years ago
parent
commit
ec1ba63cbc
3 changed files with 121 additions and 23 deletions
  1. 111 18
      dev4.ipynb
  2. 3 3
      params.py
  3. 7 2
      train.py

+ 111 - 18
dev4.ipynb

@@ -12,7 +12,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 48,
+   "execution_count": 51,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -25,25 +25,23 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 49,
+   "execution_count": 54,
    "metadata": {},
    "outputs": [],
    "source": [
-    "ex = ExperimentRunner(params.experiment1, n_jobs=1)"
+    "ex = ExperimentRunner(params.experiment_vggish_only, n_jobs=1)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 55,
    "metadata": {},
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Running param set: {'data_path_base': '/Users/aziai/Downloads/vtest_new2', 'conv_model_name': 'vgg', 'num_epochs': 10, 'feature_extract': False, 'batch_size': 64, 'lr': 0.001, 'use_vggish': False, 'momentum': 0.9}\n",
-      "Downloading: \"https://download.pytorch.org/models/vgg11_bn-6002323d.pth\" to /Users/aziai/.cache/torch/checkpoints/vgg11_bn-6002323d.pth\n",
-      "100%|██████████| 531503671/531503671 [00:40<00:00, 13079409.73it/s]\n"
+      "Running param set: {'data_path_base': '/Users/aziai/Downloads/vtest_new2', 'conv_model_name': None, 'num_epochs': 10, 'feature_extract': False, 'batch_size': 64, 'lr': 0.001, 'use_vggish': True, 'momentum': 0.9}\n"
      ]
     },
     {
@@ -53,31 +51,126 @@
       "Updating ALL params\n",
       "Epoch 0/9\n",
       "----------\n",
-      "train Loss: 1.0483 F1: 0.2951 Acc: 0.4557\n",
-      "val Loss: 0.5709 F1: 0.8189 Acc: 0.7294\n",
+      "train Loss: 0.8801 F1: 0.6833 Acc: 0.5190\n",
+      "val Loss: 0.5289 F1: 0.8742 Acc: 0.7765\n",
       "\n",
       "Epoch 1/9\n",
       "----------\n",
-      "train Loss: 0.5769 F1: 0.7872 Acc: 0.7468\n",
-      "val Loss: 0.3201 F1: 0.9353 Acc: 0.8941\n",
+      "train Loss: 0.8135 F1: 0.6833 Acc: 0.5190\n",
+      "val Loss: 0.5584 F1: 0.8742 Acc: 0.7765\n",
       "\n",
       "Epoch 2/9\n",
       "----------\n",
-      "train Loss: 0.3647 F1: 0.8247 Acc: 0.7848\n",
-      "val Loss: 0.4309 F1: 0.8333 Acc: 0.7647\n",
+      "train Loss: 0.7179 F1: 0.6833 Acc: 0.5190\n",
+      "val Loss: 0.6482 F1: 0.8667 Acc: 0.7647\n",
       "\n",
       "Epoch 3/9\n",
       "----------\n",
-      "train Loss: 0.2243 F1: 0.8571 Acc: 0.8608\n",
-      "val Loss: 0.7989 F1: 0.6796 Acc: 0.6118\n",
+      "train Loss: 0.6749 F1: 0.6833 Acc: 0.5190\n",
+      "val Loss: 0.6608 F1: 0.8725 Acc: 0.7765\n",
       "\n",
       "Epoch 4/9\n",
       "----------\n",
-      "train Loss: 0.1799 F1: 0.9231 Acc: 0.9241\n",
-      "val Loss: 0.8629 F1: 0.7407 Acc: 0.6706\n",
+      "train Loss: 0.6832 F1: 0.6842 Acc: 0.5443\n",
+      "val Loss: 0.6556 F1: 0.8467 Acc: 0.7529\n",
       "\n",
       "Epoch 5/9\n",
-      "----------\n"
+      "----------\n",
+      "train Loss: 0.6969 F1: 0.6667 Acc: 0.5823\n",
+      "val Loss: 0.6542 F1: 0.8413 Acc: 0.7647\n",
+      "\n",
+      "Epoch 6/9\n",
+      "----------\n",
+      "train Loss: 0.6960 F1: 0.5060 Acc: 0.4810\n",
+      "val Loss: 0.6501 F1: 0.8788 Acc: 0.8118\n",
+      "\n",
+      "Epoch 7/9\n",
+      "----------\n",
+      "train Loss: 0.6896 F1: 0.5882 Acc: 0.5570\n",
+      "val Loss: 0.6488 F1: 0.8837 Acc: 0.8235\n",
+      "\n",
+      "Epoch 8/9\n",
+      "----------\n",
+      "train Loss: 0.6764 F1: 0.6667 Acc: 0.6203\n",
+      "val Loss: 0.6576 F1: 0.8976 Acc: 0.8471\n",
+      "\n",
+      "Epoch 9/9\n",
+      "----------\n",
+      "train Loss: 0.6710 F1: 0.7253 Acc: 0.6835\n",
+      "val Loss: 0.6590 F1: 0.8710 Acc: 0.8118\n",
+      "\n",
+      "Training complete in 0m 56s\n",
+      "Best val F1  : 0.897638\n",
+      "Best val Acc : 0.847059\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Running param set: {'data_path_base': '/Users/aziai/Downloads/vtest_new2', 'conv_model_name': None, 'num_epochs': 10, 'feature_extract': True, 'batch_size': 64, 'lr': 0.001, 'use_vggish': True, 'momentum': 0.9}\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Params to update\n",
+      "* combined.weight\n",
+      "* combined.bias\n",
+      "Epoch 0/9\n",
+      "----------\n",
+      "train Loss: 0.7826 F1: 0.6833 Acc: 0.5190\n",
+      "val Loss: 0.5453 F1: 0.8742 Acc: 0.7765\n",
+      "\n",
+      "Epoch 1/9\n",
+      "----------\n",
+      "train Loss: 0.7525 F1: 0.6833 Acc: 0.5190\n",
+      "val Loss: 0.5790 F1: 0.8742 Acc: 0.7765\n",
+      "\n",
+      "Epoch 2/9\n",
+      "----------\n",
+      "train Loss: 0.7158 F1: 0.6833 Acc: 0.5190\n",
+      "val Loss: 0.6449 F1: 0.8742 Acc: 0.7765\n",
+      "\n",
+      "Epoch 3/9\n",
+      "----------\n",
+      "train Loss: 0.6936 F1: 0.6838 Acc: 0.5316\n",
+      "val Loss: 0.7350 F1: 0.0870 Acc: 0.2588\n",
+      "\n",
+      "Epoch 4/9\n",
+      "----------\n",
+      "train Loss: 0.7142 F1: nan Acc: 0.4430\n",
+      "val Loss: 0.8224 F1: nan Acc: 0.2235\n",
+      "\n",
+      "Epoch 5/9\n",
+      "----------\n",
+      "train Loss: 0.7382 F1: nan Acc: 0.4810\n",
+      "val Loss: 0.8526 F1: nan Acc: 0.2235\n",
+      "\n",
+      "Epoch 6/9\n",
+      "----------\n",
+      "train Loss: 0.7477 F1: nan Acc: 0.4810\n",
+      "val Loss: 0.8238 F1: nan Acc: 0.2235\n",
+      "\n",
+      "Epoch 7/9\n",
+      "----------\n",
+      "train Loss: 0.7361 F1: nan Acc: 0.4810\n",
+      "val Loss: 0.7583 F1: 0.0299 Acc: 0.2353\n",
+      "\n",
+      "Epoch 8/9\n",
+      "----------\n",
+      "train Loss: 0.7142 F1: nan Acc: 0.4304\n",
+      "val Loss: 0.6900 F1: 0.6491 Acc: 0.5294\n",
+      "\n",
+      "Epoch 9/9\n",
+      "----------\n",
+      "train Loss: 0.6995 F1: 0.5161 Acc: 0.4304\n",
+      "val Loss: 0.6337 F1: 0.8742 Acc: 0.7765\n",
+      "\n",
+      "Training complete in 0m 26s\n",
+      "Best val F1  : 0.874172\n",
+      "Best val Acc : 0.776471\n"
      ]
     }
    ],

+ 3 - 3
params.py

@@ -4,7 +4,7 @@ n_jobs = 1
 data_path_base = '/Users/aziai/Downloads/vtest_new2'
 
 # test end-to-end
-experiment1_test = {
+experiment_test = {
     'data_path_base': {data_path_base},
     'conv_model_name': {'resnet'},
     'num_epochs': {10},
@@ -15,9 +15,9 @@ experiment1_test = {
     'momentum': {0.9}
 }
 
-experiment1 = {
+experiments = {
     'data_path_base': {data_path_base},
-    'conv_model_name': {'resnet', 'vgg'},
+    'conv_model_name': {'resnet', None},  # vgg
     'num_epochs': {10},
     'feature_extract': {True, False},
     'batch_size': {64},

+ 7 - 2
train.py

@@ -9,7 +9,7 @@ from torch import nn
 from data import AudioVideo
 from kissing_detector import KissingDetector
 
-ExperimentResults = Tuple[nn.Module, List[float], List[float]]
+ExperimentResults = Tuple[Optional[nn.Module], List[float], List[float]]
 
 
 def _get_params_to_update(model: nn.Module,
@@ -38,7 +38,12 @@ def train_kd(data_path_base: str,
              lr: float = 0.001,
              momentum: float = 0.9) -> ExperimentResults:
     num_classes = 2
-    kd = KissingDetector(conv_model_name, num_classes, feature_extract, use_vggish=use_vggish)
+    try:
+        kd = KissingDetector(conv_model_name, num_classes, feature_extract, use_vggish=use_vggish)
+    except ValueError:
+        # if the combination is not valid
+        return None, [-1.0], [-1.0]
+
     params_to_update = _get_params_to_update(kd, feature_extract)
 
     datasets = {set_: AudioVideo(f'{data_path_base}/{set_}') for set_ in ['train', 'val']}