Browse Source

added segmentor, qual stuff next

Amir Ziai 4 years ago
parent
commit
b2fcb55f40
8 changed files with 719 additions and 372 deletions
  1. 7 4
      README.md
  2. 537 350
      dev4.ipynb
  3. 1 0
      experiments.py
  4. 7 2
      params.py
  5. 24 14
      pipeline.py
  6. 1 0
      requirements.txt
  7. 121 2
      segmentor.py
  8. 21 0
      test_segmentor.py

+ 7 - 4
README.md

@@ -49,7 +49,10 @@ builder.build_dataset()
 - Failure examples
 
 ## TODO
-- Define experiments
-- Saliency map, class viz
-- 
-- ...
+- Segmentor
+- Qual
+    - Saliency map
+    - class viz
+    - Error examples
+    - Audio?
+- 3DC

+ 537 - 350
dev4.ipynb

@@ -12,7 +12,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 63,
+   "execution_count": 85,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -25,7 +25,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 71,
+   "execution_count": 92,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -34,269 +34,115 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 93,
    "metadata": {},
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': 'densenet', 'num_epochs': 10, 'feature_extract': False, 'batch_size': 64, 'lr': 0.01, 'use_vggish': False, 'momentum': 0.9}\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Updating ALL params\n",
-      "Epoch 0/9\n",
-      "----------\n",
-      "train Loss: 0.8410 F1: 0.3704 Acc: 0.5696\n",
-      "val Loss: 0.6466 F1: 0.8667 Acc: 0.7647\n",
-      "\n",
-      "Epoch 1/9\n",
-      "----------\n",
-      "train Loss: 1.0247 F1: 0.6949 Acc: 0.5443\n",
-      "val Loss: 1.7111 F1: nan Acc: 0.2235\n",
-      "\n",
-      "Epoch 2/9\n",
-      "----------\n",
-      "train Loss: 0.4048 F1: 0.6557 Acc: 0.7342\n",
-      "val Loss: 1.5030 F1: 0.1370 Acc: 0.2588\n",
-      "\n",
-      "Epoch 3/9\n",
-      "----------\n",
-      "train Loss: 0.0870 F1: 0.9877 Acc: 0.9873\n",
-      "val Loss: 0.6543 F1: 0.7627 Acc: 0.6706\n",
-      "\n",
-      "Epoch 4/9\n",
-      "----------\n",
-      "train Loss: 0.0952 F1: 0.9535 Acc: 0.9494\n",
-      "val Loss: 0.6773 F1: 0.8358 Acc: 0.7412\n",
-      "\n",
-      "Epoch 5/9\n",
-      "----------\n",
-      "train Loss: 0.0807 F1: 0.9647 Acc: 0.9620\n",
-      "val Loss: 0.8060 F1: 0.8182 Acc: 0.7176\n",
-      "\n",
-      "Epoch 6/9\n",
-      "----------\n",
-      "train Loss: 0.0083 F1: 1.0000 Acc: 1.0000\n",
-      "val Loss: 1.2097 F1: 0.6667 Acc: 0.5529\n",
-      "\n",
-      "Epoch 7/9\n",
-      "----------\n",
-      "train Loss: 0.0021 F1: 1.0000 Acc: 1.0000\n",
-      "val Loss: 1.7171 F1: 0.5102 Acc: 0.4353\n",
-      "\n",
-      "Epoch 8/9\n",
-      "----------\n",
-      "train Loss: 0.0003 F1: 1.0000 Acc: 1.0000\n",
-      "val Loss: 2.0735 F1: 0.4421 Acc: 0.3765\n",
-      "\n",
-      "Epoch 9/9\n",
-      "----------\n",
-      "train Loss: 0.0002 F1: 1.0000 Acc: 1.0000\n",
-      "val Loss: 2.5907 F1: 0.3218 Acc: 0.3059\n",
-      "\n",
-      "Training complete in 6m 49s\n",
-      "Best val F1  : 0.866667\n",
-      "Best val Acc : 0.764706\n"
-     ]
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': 'densenet', 'num_epochs': 10, 'feature_extract': False, 'batch_size': 64, 'lr': 0.01, 'use_vggish': False, 'momentum': 0.95}\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Updating ALL params\n",
-      "Epoch 0/9\n",
-      "----------\n",
-      "train Loss: 0.6807 F1: 0.6667 Acc: 0.5696\n",
-      "val Loss: 0.6620 F1: 0.6729 Acc: 0.5882\n",
-      "\n",
-      "Epoch 1/9\n",
-      "----------\n",
-      "train Loss: 0.3776 F1: 0.9425 Acc: 0.9367\n",
-      "val Loss: 0.8433 F1: 0.3000 Acc: 0.3412\n",
-      "\n",
-      "Epoch 2/9\n",
-      "----------\n",
-      "train Loss: 0.1782 F1: 0.9750 Acc: 0.9747\n",
-      "val Loss: 0.7165 F1: 0.5376 Acc: 0.4941\n",
-      "\n",
-      "Epoch 3/9\n",
-      "----------\n",
-      "train Loss: 0.1016 F1: 0.9762 Acc: 0.9747\n",
-      "val Loss: 0.6558 F1: 0.6981 Acc: 0.6235\n",
-      "\n",
-      "Epoch 4/9\n",
-      "----------\n",
-      "train Loss: 0.0343 F1: 1.0000 Acc: 1.0000\n",
-      "val Loss: 0.8724 F1: 0.5111 Acc: 0.4824\n",
-      "\n",
-      "Epoch 5/9\n",
-      "----------\n",
-      "train Loss: 0.0107 F1: 1.0000 Acc: 1.0000\n",
-      "val Loss: 1.0340 F1: 0.4889 Acc: 0.4588\n",
-      "\n",
-      "Epoch 6/9\n",
-      "----------\n",
-      "train Loss: 0.0066 F1: 1.0000 Acc: 1.0000\n",
-      "val Loss: 1.3202 F1: 0.4186 Acc: 0.4118\n",
-      "\n",
-      "Epoch 7/9\n",
-      "----------\n",
-      "train Loss: 0.0046 F1: 1.0000 Acc: 1.0000\n",
-      "val Loss: 1.5957 F1: 0.3953 Acc: 0.3882\n",
-      "\n",
-      "Epoch 8/9\n",
-      "----------\n",
-      "train Loss: 0.0015 F1: 1.0000 Acc: 1.0000\n",
-      "val Loss: 1.9662 F1: 0.4186 Acc: 0.4118\n",
-      "\n",
-      "Epoch 9/9\n",
-      "----------\n",
-      "train Loss: 0.0007 F1: 1.0000 Acc: 1.0000\n",
-      "val Loss: 2.1900 F1: 0.4045 Acc: 0.3765\n",
-      "\n",
-      "Training complete in 6m 60s\n",
-      "Best val F1  : 0.698113\n",
-      "Best val Acc : 0.623529\n"
-     ]
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': 'densenet', 'num_epochs': 10, 'feature_extract': False, 'batch_size': 64, 'lr': 0.01, 'use_vggish': True, 'momentum': 0.9}\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Updating ALL params\n",
-      "Epoch 0/9\n",
-      "----------\n",
-      "train Loss: 0.6751 F1: 0.5352 Acc: 0.5823\n",
-      "val Loss: 0.5813 F1: 0.8742 Acc: 0.7765\n",
-      "\n",
-      "Epoch 1/9\n",
-      "----------\n",
-      "train Loss: 0.6840 F1: 0.7130 Acc: 0.5823\n",
-      "val Loss: 2.9018 F1: nan Acc: 0.2235\n",
-      "\n",
-      "Epoch 2/9\n",
-      "----------\n",
-      "train Loss: 1.1845 F1: nan Acc: 0.4810\n",
-      "val Loss: 0.6255 F1: 0.7581 Acc: 0.6471\n",
-      "\n",
-      "Epoch 3/9\n",
-      "----------\n",
-      "train Loss: 0.1377 F1: 0.9535 Acc: 0.9494\n",
-      "val Loss: 1.1300 F1: 0.8800 Acc: 0.7882\n",
-      "\n",
-      "Epoch 4/9\n",
-      "----------\n",
-      "train Loss: 1.2311 F1: 0.7321 Acc: 0.6203\n",
-      "val Loss: 0.9068 F1: 0.8321 Acc: 0.7294\n",
-      "\n",
-      "Epoch 5/9\n",
-      "----------\n",
-      "train Loss: 0.0355 F1: 0.9756 Acc: 0.9747\n",
-      "val Loss: 2.1568 F1: 0.3059 Acc: 0.3059\n",
-      "\n",
-      "Epoch 6/9\n",
-      "----------\n",
-      "train Loss: 0.0030 F1: 1.0000 Acc: 1.0000\n",
-      "val Loss: 6.2178 F1: nan Acc: 0.2235\n",
-      "\n",
-      "Epoch 7/9\n",
-      "----------\n",
-      "train Loss: 0.0112 F1: 1.0000 Acc: 1.0000\n",
-      "val Loss: 8.0201 F1: nan Acc: 0.2235\n",
-      "\n",
-      "Epoch 8/9\n",
-      "----------\n",
-      "train Loss: 0.0022 F1: 1.0000 Acc: 1.0000\n",
-      "val Loss: 9.1074 F1: nan Acc: 0.2235\n",
-      "\n",
-      "Epoch 9/9\n",
-      "----------\n",
-      "train Loss: 0.0037 F1: 1.0000 Acc: 1.0000\n",
-      "val Loss: 9.3362 F1: nan Acc: 0.2235\n",
-      "\n",
-      "Training complete in 7m 50s\n",
-      "Best val F1  : 0.880000\n",
-      "Best val Acc : 0.788235\n"
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': None, 'num_epochs': 10, 'feature_extract': False, 'batch_size': 64, 'lr': 0.01, 'use_vggish': False, 'momentum': 0.9}\n",
+      "Loading experiment results from cache\n",
+      "3d4d0ae6efebf77f6b7f5a4163558c892daae15e174c5b791e9e5db1\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': None, 'num_epochs': 10, 'feature_extract': False, 'batch_size': 64, 'lr': 0.01, 'use_vggish': False, 'momentum': 0.95}\n",
+      "Loading experiment results from cache\n",
+      "c7b84e8c420a3be3a333ba41c2618399c2baebb470fc6a23a0a433d2\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': None, 'num_epochs': 10, 'feature_extract': False, 'batch_size': 64, 'lr': 0.01, 'use_vggish': True, 'momentum': 0.9}\n",
+      "Loading experiment results from cache\n",
+      "0fd6f6ebe947c54bdcb0340e244f0a27254f0a358be9a038f8fc9825\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': None, 'num_epochs': 10, 'feature_extract': False, 'batch_size': 64, 'lr': 0.01, 'use_vggish': True, 'momentum': 0.95}\n",
+      "Loading experiment results from cache\n",
+      "bc7e48700485911fd1dcfa46fb408e93d744c2ec3afd26356af872f9\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': None, 'num_epochs': 10, 'feature_extract': False, 'batch_size': 64, 'lr': 0.001, 'use_vggish': False, 'momentum': 0.9}\n",
+      "Loading experiment results from cache\n",
+      "54aaa1ed52b83839a188927adb63f64a6e8cc8f9d5ce37a7be960ed4\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': None, 'num_epochs': 10, 'feature_extract': False, 'batch_size': 64, 'lr': 0.001, 'use_vggish': False, 'momentum': 0.95}\n",
+      "Loading experiment results from cache\n",
+      "9af39512bef330a9bbab90f91063597b4d3d798a0a7dec9054354d6e\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': None, 'num_epochs': 10, 'feature_extract': False, 'batch_size': 64, 'lr': 0.001, 'use_vggish': True, 'momentum': 0.9}\n",
+      "Loading experiment results from cache\n",
+      "57fac69720ed9689a5805e15ddbb89ee26d9b64fbd2cf2b65c432a52\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': None, 'num_epochs': 10, 'feature_extract': False, 'batch_size': 64, 'lr': 0.001, 'use_vggish': True, 'momentum': 0.95}\n",
+      "Loading experiment results from cache\n",
+      "8686806f634f6da4be6bf84cca80c305e8a7f751dd8ca88ee2112398\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': None, 'num_epochs': 10, 'feature_extract': True, 'batch_size': 64, 'lr': 0.01, 'use_vggish': False, 'momentum': 0.9}\n",
+      "Loading experiment results from cache\n",
+      "dddc0808ccee29d6d96d7922bdce1e66e319ad29cd72065ceb76086f\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': None, 'num_epochs': 10, 'feature_extract': True, 'batch_size': 64, 'lr': 0.01, 'use_vggish': False, 'momentum': 0.95}\n",
+      "Loading experiment results from cache\n",
+      "1613a061bc0ee4cb3e89ff00d7b183241ee4fccf84420cb0cfa06ea4\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': None, 'num_epochs': 10, 'feature_extract': True, 'batch_size': 64, 'lr': 0.01, 'use_vggish': True, 'momentum': 0.9}\n",
+      "Loading experiment results from cache\n",
+      "10ece59983d31536981268d1e8bbdf4460e65b24a7567b1027d92a7c\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': None, 'num_epochs': 10, 'feature_extract': True, 'batch_size': 64, 'lr': 0.01, 'use_vggish': True, 'momentum': 0.95}\n",
+      "Loading experiment results from cache\n",
+      "db6e5ff043c83f9643e5fa4aecb3c85a03e2be1c8de56993246b6f23\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': None, 'num_epochs': 10, 'feature_extract': True, 'batch_size': 64, 'lr': 0.001, 'use_vggish': False, 'momentum': 0.9}\n",
+      "Loading experiment results from cache\n",
+      "33ac0b6c3357fb2cfc22194be83a872435a7b3495506b83f8ee76fcf\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': None, 'num_epochs': 10, 'feature_extract': True, 'batch_size': 64, 'lr': 0.001, 'use_vggish': False, 'momentum': 0.95}\n",
+      "Loading experiment results from cache\n",
+      "41fd07df53b8400f9386cf6d6fd5e400290fc19cc96bd1c1c66991d2\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': None, 'num_epochs': 10, 'feature_extract': True, 'batch_size': 64, 'lr': 0.001, 'use_vggish': True, 'momentum': 0.9}\n",
+      "Loading experiment results from cache\n",
+      "c2c04441a29f2bcf11fade126cd304849d3d5e9a1c98fab8a4b9a2fc\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': None, 'num_epochs': 10, 'feature_extract': True, 'batch_size': 64, 'lr': 0.001, 'use_vggish': True, 'momentum': 0.95}\n",
+      "Loading experiment results from cache\n",
+      "4d227fadf14ff66ea2c7f253a96c2111a78234f95089d003a3d3f88f\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': 'resnet', 'num_epochs': 10, 'feature_extract': False, 'batch_size': 64, 'lr': 0.01, 'use_vggish': False, 'momentum': 0.9}\n",
+      "Loading experiment results from cache\n",
+      "a9d9eb5afabbea4dff607846fe1a0782c760f44cfbd0f982c2f8bbb4\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': 'resnet', 'num_epochs': 10, 'feature_extract': False, 'batch_size': 64, 'lr': 0.01, 'use_vggish': False, 'momentum': 0.95}\n",
+      "Loading experiment results from cache\n",
+      "b5d41c44c98e2fa432656659c4fbcb77ef2a66d30dbe5148940ec3bb\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': 'resnet', 'num_epochs': 10, 'feature_extract': False, 'batch_size': 64, 'lr': 0.01, 'use_vggish': True, 'momentum': 0.9}\n",
+      "Loading experiment results from cache\n",
+      "a60fa62389b6418273ca349479228d39e55b7b357ba2e7ec95423d41\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': 'resnet', 'num_epochs': 10, 'feature_extract': False, 'batch_size': 64, 'lr': 0.01, 'use_vggish': True, 'momentum': 0.95}\n",
+      "Loading experiment results from cache\n",
+      "dcb42ec2f883718f58a1612e4afff99249aa628afa844252c67ed670\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': 'resnet', 'num_epochs': 10, 'feature_extract': False, 'batch_size': 64, 'lr': 0.001, 'use_vggish': False, 'momentum': 0.9}\n",
+      "Loading experiment results from cache\n",
+      "8c4da8f6f9157db4dbd6f70b67c1db469ea034730c12ebfccf0c5329\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': 'resnet', 'num_epochs': 10, 'feature_extract': False, 'batch_size': 64, 'lr': 0.001, 'use_vggish': False, 'momentum': 0.95}\n",
+      "Loading experiment results from cache\n",
+      "091aaed6d121608dc449c2f33a7c74bbe835ccc5abadc937263b9f90\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': 'resnet', 'num_epochs': 10, 'feature_extract': False, 'batch_size': 64, 'lr': 0.001, 'use_vggish': True, 'momentum': 0.9}\n",
+      "Loading experiment results from cache\n",
+      "c6cd69a63c6b7670802afa33af35b13d6c687340f71bf7e299e5711c\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': 'resnet', 'num_epochs': 10, 'feature_extract': False, 'batch_size': 64, 'lr': 0.001, 'use_vggish': True, 'momentum': 0.95}\n",
+      "Loading experiment results from cache\n",
+      "af38e4c808ff3998977167f1e5138bdd0301f92a2dc206a6319d9823\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': 'resnet', 'num_epochs': 10, 'feature_extract': True, 'batch_size': 64, 'lr': 0.01, 'use_vggish': False, 'momentum': 0.9}\n",
+      "Loading experiment results from cache\n",
+      "9715fe9781571c61eb6bf38cbc5173df40e9e41ae3eaec6789cf771b\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': 'resnet', 'num_epochs': 10, 'feature_extract': True, 'batch_size': 64, 'lr': 0.01, 'use_vggish': False, 'momentum': 0.95}\n",
+      "Loading experiment results from cache\n",
+      "21f4a1394a48e89171341403ff9eccc6e080a9acb66005b9b14a035d\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': 'resnet', 'num_epochs': 10, 'feature_extract': True, 'batch_size': 64, 'lr': 0.01, 'use_vggish': True, 'momentum': 0.9}\n",
+      "Loading experiment results from cache\n",
+      "454e165857fea306cac78a54c635a15545fab0e7a5f6067e6509aeb3\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': 'resnet', 'num_epochs': 10, 'feature_extract': True, 'batch_size': 64, 'lr': 0.01, 'use_vggish': True, 'momentum': 0.95}\n",
+      "Loading experiment results from cache\n",
+      "4566c335e71f215a1110a36a7cfec1882c86f86dbb7e3e5787dfdc26\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': 'resnet', 'num_epochs': 10, 'feature_extract': True, 'batch_size': 64, 'lr': 0.001, 'use_vggish': False, 'momentum': 0.9}\n",
+      "Loading experiment results from cache\n",
+      "6fcff331f0df5ded20113ae3e7f2d1568e3f3fba9f2a922715326fcf\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': 'resnet', 'num_epochs': 10, 'feature_extract': True, 'batch_size': 64, 'lr': 0.001, 'use_vggish': False, 'momentum': 0.95}\n",
+      "Loading experiment results from cache\n",
+      "81bf2d2fb5c8083afc54b43f8c561e6c5c36304a2d002444a02bb0d9\n",
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': 'resnet', 'num_epochs': 10, 'feature_extract': True, 'batch_size': 64, 'lr': 0.001, 'use_vggish': True, 'momentum': 0.9}\n",
+      "Loading experiment results from cache\n",
+      "5f938c364c0de0bc3ab556f1b48971520006d0ed581d53c1d89787d7\n"
      ]
     },
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': 'densenet', 'num_epochs': 10, 'feature_extract': False, 'batch_size': 64, 'lr': 0.01, 'use_vggish': True, 'momentum': 0.95}\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Updating ALL params\n",
-      "Epoch 0/9\n",
-      "----------\n",
-      "train Loss: 0.6752 F1: 0.5526 Acc: 0.5696\n",
-      "val Loss: 0.5195 F1: 0.8742 Acc: 0.7765\n",
-      "\n",
-      "Epoch 1/9\n",
-      "----------\n",
-      "train Loss: 0.5782 F1: 0.7387 Acc: 0.6329\n",
-      "val Loss: 1.4869 F1: 0.0299 Acc: 0.2353\n",
-      "\n",
-      "Epoch 2/9\n",
-      "----------\n",
-      "train Loss: 0.3623 F1: 0.7941 Acc: 0.8228\n",
-      "val Loss: 0.5903 F1: 0.7692 Acc: 0.6824\n",
-      "\n",
-      "Epoch 3/9\n",
-      "----------\n",
-      "train Loss: 0.0622 F1: 0.9880 Acc: 0.9873\n",
-      "val Loss: 0.4745 F1: 0.8714 Acc: 0.7882\n",
-      "\n",
-      "Epoch 4/9\n",
-      "----------\n",
-      "train Loss: 0.0707 F1: 0.9762 Acc: 0.9747\n",
-      "val Loss: 0.5149 F1: 0.8759 Acc: 0.8000\n",
-      "\n",
-      "Epoch 5/9\n",
-      "----------\n",
-      "train Loss: 0.0126 F1: 1.0000 Acc: 1.0000\n",
-      "val Loss: 0.6557 F1: 0.8244 Acc: 0.7294\n",
-      "\n",
-      "Epoch 6/9\n",
-      "----------\n",
-      "train Loss: 0.0048 F1: 1.0000 Acc: 1.0000\n",
-      "val Loss: 1.0149 F1: 0.7069 Acc: 0.6000\n",
-      "\n",
-      "Epoch 7/9\n",
-      "----------\n",
-      "train Loss: 0.0042 F1: 1.0000 Acc: 1.0000\n",
-      "val Loss: 1.3682 F1: 0.7009 Acc: 0.5882\n",
-      "\n",
-      "Epoch 8/9\n",
-      "----------\n",
-      "train Loss: 0.0025 F1: 1.0000 Acc: 1.0000\n",
-      "val Loss: 1.7442 F1: 0.6607 Acc: 0.5529\n",
-      "\n",
-      "Epoch 9/9\n",
-      "----------\n"
+      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': 'resnet', 'num_epochs': 10, 'feature_extract': True, 'batch_size': 64, 'lr': 0.001, 'use_vggish': True, 'momentum': 0.95}\n",
+      "Loading experiment results from cache\n",
+      "b0ba9005b21c06336b8a58c2ed784dcf5fec6028dd7bca0eefbc3c7a\n"
      ]
     }
    ],
@@ -306,7 +152,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 68,
+   "execution_count": 88,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -315,7 +161,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 69,
+   "execution_count": 89,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -326,7 +172,28 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 70,
+   "execution_count": 90,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "-rw-r--r--  1 aziai  staff   4.6K Jun  2 23:21 results/results_20190602232142.csv\r\n",
+      "-rw-r--r--@ 1 aziai  staff   2.1K Jun  2 20:25 results/results_20190602195657.csv\r\n",
+      "-rw-r--r--@ 1 aziai  staff   1.4K Jun  2 19:48 results/results_20190602194822.csv\r\n",
+      "-rw-r--r--@ 1 aziai  staff   478B Jun  2 19:45 results/results_20190602194401.csv\r\n",
+      "-rw-r--r--@ 1 aziai  staff   846B Jun  2 19:28 results/results_20191402191416.csv\r\n"
+     ]
+    }
+   ],
+   "source": [
+    "!ls -lht results/*.csv"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 91,
    "metadata": {},
    "outputs": [
     {
@@ -366,181 +233,481 @@
        "  </thead>\n",
        "  <tbody>\n",
        "    <tr>\n",
-       "      <th>9</th>\n",
+       "      <th>17</th>\n",
        "      <td>64</td>\n",
        "      <td>resnet</td>\n",
-       "      <td>/Users/aziai/Downloads/vtest_new2</td>\n",
-       "      <td>20190602195657</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
+       "      <td>False</td>\n",
+       "      <td>0.010</td>\n",
+       "      <td>0.95</td>\n",
+       "      <td>10</td>\n",
+       "      <td>b5d41c44c98e2fa432656659c4fbcb77ef2a66d30dbe51...</td>\n",
+       "      <td>False</td>\n",
+       "      <td>0.917647</td>\n",
+       "      <td>0.946565</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>22</th>\n",
+       "      <td>64</td>\n",
+       "      <td>resnet</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
        "      <td>False</td>\n",
        "      <td>0.001</td>\n",
-       "      <td>0.9</td>\n",
+       "      <td>0.90</td>\n",
        "      <td>10</td>\n",
-       "      <td>b3290017112d2116e4bdbb9c8dbf15a8e75adacb942afb...</td>\n",
+       "      <td>c6cd69a63c6b7670802afa33af35b13d6c687340f71bf7...</td>\n",
        "      <td>True</td>\n",
        "      <td>0.894118</td>\n",
-       "      <td>0.935252</td>\n",
+       "      <td>0.936170</td>\n",
        "    </tr>\n",
        "    <tr>\n",
-       "      <th>5</th>\n",
+       "      <th>6</th>\n",
        "      <td>64</td>\n",
        "      <td>NaN</td>\n",
-       "      <td>/Users/aziai/Downloads/vtest_new2</td>\n",
-       "      <td>20190602195657</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
        "      <td>False</td>\n",
        "      <td>0.001</td>\n",
-       "      <td>0.9</td>\n",
+       "      <td>0.90</td>\n",
        "      <td>10</td>\n",
-       "      <td>4687cd536acf6c9b058a8d20719fa6910f50a6abee3ae2...</td>\n",
+       "      <td>57fac69720ed9689a5805e15ddbb89ee26d9b64fbd2cf2...</td>\n",
        "      <td>True</td>\n",
-       "      <td>0.847059</td>\n",
-       "      <td>0.897638</td>\n",
+       "      <td>0.882353</td>\n",
+       "      <td>0.929577</td>\n",
        "    </tr>\n",
        "    <tr>\n",
-       "      <th>1</th>\n",
+       "      <th>3</th>\n",
        "      <td>64</td>\n",
-       "      <td>densenet</td>\n",
-       "      <td>/Users/aziai/Downloads/vtest_new2</td>\n",
-       "      <td>20190602195657</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
+       "      <td>False</td>\n",
+       "      <td>0.010</td>\n",
+       "      <td>0.95</td>\n",
+       "      <td>10</td>\n",
+       "      <td>bc7e48700485911fd1dcfa46fb408e93d744c2ec3afd26...</td>\n",
+       "      <td>True</td>\n",
+       "      <td>0.858824</td>\n",
+       "      <td>0.916667</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>24</th>\n",
+       "      <td>64</td>\n",
+       "      <td>resnet</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
+       "      <td>True</td>\n",
+       "      <td>0.010</td>\n",
+       "      <td>0.90</td>\n",
+       "      <td>10</td>\n",
+       "      <td>9715fe9781571c61eb6bf38cbc5173df40e9e41ae3eaec...</td>\n",
+       "      <td>False</td>\n",
+       "      <td>0.870588</td>\n",
+       "      <td>0.916031</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>26</th>\n",
+       "      <td>64</td>\n",
+       "      <td>resnet</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
+       "      <td>True</td>\n",
+       "      <td>0.010</td>\n",
+       "      <td>0.90</td>\n",
+       "      <td>10</td>\n",
+       "      <td>454e165857fea306cac78a54c635a15545fab0e7a5f606...</td>\n",
+       "      <td>True</td>\n",
+       "      <td>0.858824</td>\n",
+       "      <td>0.907692</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>23</th>\n",
+       "      <td>64</td>\n",
+       "      <td>resnet</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
        "      <td>False</td>\n",
        "      <td>0.001</td>\n",
-       "      <td>0.9</td>\n",
+       "      <td>0.95</td>\n",
        "      <td>10</td>\n",
-       "      <td>697154af4871003b02cf60d53531fb21a80e853646b458...</td>\n",
+       "      <td>af38e4c808ff3998977167f1e5138bdd0301f92a2dc206...</td>\n",
        "      <td>True</td>\n",
-       "      <td>0.823529</td>\n",
-       "      <td>0.896552</td>\n",
+       "      <td>0.835294</td>\n",
+       "      <td>0.893939</td>\n",
        "    </tr>\n",
        "    <tr>\n",
-       "      <th>8</th>\n",
+       "      <th>27</th>\n",
        "      <td>64</td>\n",
        "      <td>resnet</td>\n",
-       "      <td>/Users/aziai/Downloads/vtest_new2</td>\n",
-       "      <td>20190602195657</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
+       "      <td>True</td>\n",
+       "      <td>0.010</td>\n",
+       "      <td>0.95</td>\n",
+       "      <td>10</td>\n",
+       "      <td>4566c335e71f215a1110a36a7cfec1882c86f86dbb7e3e...</td>\n",
+       "      <td>True</td>\n",
+       "      <td>0.823529</td>\n",
+       "      <td>0.893617</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>64</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
        "      <td>False</td>\n",
+       "      <td>0.010</td>\n",
+       "      <td>0.90</td>\n",
+       "      <td>10</td>\n",
+       "      <td>0fd6f6ebe947c54bdcb0340e244f0a27254f0a358be9a0...</td>\n",
+       "      <td>True</td>\n",
+       "      <td>0.811765</td>\n",
+       "      <td>0.891892</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>31</th>\n",
+       "      <td>64</td>\n",
+       "      <td>resnet</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
+       "      <td>True</td>\n",
        "      <td>0.001</td>\n",
-       "      <td>0.9</td>\n",
+       "      <td>0.95</td>\n",
        "      <td>10</td>\n",
-       "      <td>0e57debc92afdd0dc7a209584b4d97860c9dba98f3aed4...</td>\n",
+       "      <td>b0ba9005b21c06336b8a58c2ed784dcf5fec6028dd7bca...</td>\n",
+       "      <td>True</td>\n",
+       "      <td>0.811765</td>\n",
+       "      <td>0.887324</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>18</th>\n",
+       "      <td>64</td>\n",
+       "      <td>resnet</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
        "      <td>False</td>\n",
+       "      <td>0.010</td>\n",
+       "      <td>0.90</td>\n",
+       "      <td>10</td>\n",
+       "      <td>a60fa62389b6418273ca349479228d39e55b7b357ba2e7...</td>\n",
+       "      <td>True</td>\n",
        "      <td>0.823529</td>\n",
-       "      <td>0.878049</td>\n",
+       "      <td>0.887218</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>19</th>\n",
+       "      <td>64</td>\n",
+       "      <td>resnet</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
+       "      <td>False</td>\n",
+       "      <td>0.010</td>\n",
+       "      <td>0.95</td>\n",
+       "      <td>10</td>\n",
+       "      <td>dcb42ec2f883718f58a1612e4afff99249aa628afa8442...</td>\n",
+       "      <td>True</td>\n",
+       "      <td>0.800000</td>\n",
+       "      <td>0.880000</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>29</th>\n",
+       "      <td>64</td>\n",
+       "      <td>resnet</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
+       "      <td>True</td>\n",
+       "      <td>0.001</td>\n",
+       "      <td>0.95</td>\n",
+       "      <td>10</td>\n",
+       "      <td>81bf2d2fb5c8083afc54b43f8c561e6c5c36304a2d0024...</td>\n",
+       "      <td>False</td>\n",
+       "      <td>0.800000</td>\n",
+       "      <td>0.877698</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>11</th>\n",
+       "      <td>64</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
+       "      <td>True</td>\n",
+       "      <td>0.010</td>\n",
+       "      <td>0.95</td>\n",
+       "      <td>10</td>\n",
+       "      <td>db6e5ff043c83f9643e5fa4aecb3c85a03e2be1c8de569...</td>\n",
+       "      <td>True</td>\n",
+       "      <td>0.776471</td>\n",
+       "      <td>0.874172</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>7</th>\n",
        "      <td>64</td>\n",
        "      <td>NaN</td>\n",
-       "      <td>/Users/aziai/Downloads/vtest_new2</td>\n",
-       "      <td>20190602195657</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
+       "      <td>False</td>\n",
+       "      <td>0.001</td>\n",
+       "      <td>0.95</td>\n",
+       "      <td>10</td>\n",
+       "      <td>8686806f634f6da4be6bf84cca80c305e8a7f751dd8ca8...</td>\n",
+       "      <td>True</td>\n",
+       "      <td>0.776471</td>\n",
+       "      <td>0.874172</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>14</th>\n",
+       "      <td>64</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
        "      <td>True</td>\n",
        "      <td>0.001</td>\n",
-       "      <td>0.9</td>\n",
+       "      <td>0.90</td>\n",
        "      <td>10</td>\n",
-       "      <td>de7011637531360e5c76520f054d90327cc478dbabeab9...</td>\n",
+       "      <td>c2c04441a29f2bcf11fade126cd304849d3d5e9a1c98fa...</td>\n",
        "      <td>True</td>\n",
        "      <td>0.776471</td>\n",
        "      <td>0.874172</td>\n",
        "    </tr>\n",
        "    <tr>\n",
-       "      <th>2</th>\n",
+       "      <th>15</th>\n",
        "      <td>64</td>\n",
-       "      <td>densenet</td>\n",
-       "      <td>/Users/aziai/Downloads/vtest_new2</td>\n",
-       "      <td>20190602195657</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
        "      <td>True</td>\n",
        "      <td>0.001</td>\n",
-       "      <td>0.9</td>\n",
+       "      <td>0.95</td>\n",
        "      <td>10</td>\n",
-       "      <td>89b8a27230c7f2ee2dc0ece6fd1f9deccae873fce8288f...</td>\n",
-       "      <td>False</td>\n",
+       "      <td>4d227fadf14ff66ea2c7f253a96c2111a78234f95089d0...</td>\n",
+       "      <td>True</td>\n",
        "      <td>0.776471</td>\n",
-       "      <td>0.845528</td>\n",
+       "      <td>0.874172</td>\n",
        "    </tr>\n",
        "    <tr>\n",
-       "      <th>11</th>\n",
+       "      <th>10</th>\n",
+       "      <td>64</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
+       "      <td>True</td>\n",
+       "      <td>0.010</td>\n",
+       "      <td>0.90</td>\n",
+       "      <td>10</td>\n",
+       "      <td>10ece59983d31536981268d1e8bbdf4460e65b24a7567b...</td>\n",
+       "      <td>True</td>\n",
+       "      <td>0.776471</td>\n",
+       "      <td>0.874172</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>25</th>\n",
        "      <td>64</td>\n",
        "      <td>resnet</td>\n",
-       "      <td>/Users/aziai/Downloads/vtest_new2</td>\n",
-       "      <td>20190602195657</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
        "      <td>True</td>\n",
+       "      <td>0.010</td>\n",
+       "      <td>0.95</td>\n",
+       "      <td>10</td>\n",
+       "      <td>21f4a1394a48e89171341403ff9eccc6e080a9acb66005...</td>\n",
+       "      <td>False</td>\n",
+       "      <td>0.811765</td>\n",
+       "      <td>0.873016</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>21</th>\n",
+       "      <td>64</td>\n",
+       "      <td>resnet</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
+       "      <td>False</td>\n",
+       "      <td>0.001</td>\n",
+       "      <td>0.95</td>\n",
+       "      <td>10</td>\n",
+       "      <td>091aaed6d121608dc449c2f33a7c74bbe835ccc5abadc9...</td>\n",
+       "      <td>False</td>\n",
+       "      <td>0.800000</td>\n",
+       "      <td>0.872180</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>30</th>\n",
+       "      <td>64</td>\n",
+       "      <td>resnet</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
+       "      <td>True</td>\n",
+       "      <td>0.001</td>\n",
+       "      <td>0.90</td>\n",
+       "      <td>10</td>\n",
+       "      <td>5f938c364c0de0bc3ab556f1b48971520006d0ed581d53...</td>\n",
+       "      <td>True</td>\n",
+       "      <td>0.776471</td>\n",
+       "      <td>0.868966</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>16</th>\n",
+       "      <td>64</td>\n",
+       "      <td>resnet</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
+       "      <td>False</td>\n",
+       "      <td>0.010</td>\n",
+       "      <td>0.90</td>\n",
+       "      <td>10</td>\n",
+       "      <td>a9d9eb5afabbea4dff607846fe1a0782c760f44cfbd0f9...</td>\n",
+       "      <td>False</td>\n",
+       "      <td>0.788235</td>\n",
+       "      <td>0.867647</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>20</th>\n",
+       "      <td>64</td>\n",
+       "      <td>resnet</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
+       "      <td>False</td>\n",
        "      <td>0.001</td>\n",
-       "      <td>0.9</td>\n",
+       "      <td>0.90</td>\n",
        "      <td>10</td>\n",
-       "      <td>8e644bc291a463725bf0bcb11825a196383a4860eeecd7...</td>\n",
+       "      <td>8c4da8f6f9157db4dbd6f70b67c1db469ea034730c12eb...</td>\n",
+       "      <td>False</td>\n",
+       "      <td>0.788235</td>\n",
+       "      <td>0.854839</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>28</th>\n",
+       "      <td>64</td>\n",
+       "      <td>resnet</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
        "      <td>True</td>\n",
+       "      <td>0.001</td>\n",
+       "      <td>0.90</td>\n",
+       "      <td>10</td>\n",
+       "      <td>6fcff331f0df5ded20113ae3e7f2d1568e3f3fba9f2a92...</td>\n",
+       "      <td>False</td>\n",
        "      <td>0.729412</td>\n",
-       "      <td>0.824427</td>\n",
+       "      <td>0.818898</td>\n",
        "    </tr>\n",
        "    <tr>\n",
-       "      <th>0</th>\n",
+       "      <th>1</th>\n",
        "      <td>64</td>\n",
-       "      <td>densenet</td>\n",
-       "      <td>/Users/aziai/Downloads/vtest_new2</td>\n",
-       "      <td>20190602195657</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
        "      <td>False</td>\n",
+       "      <td>0.010</td>\n",
+       "      <td>0.95</td>\n",
+       "      <td>10</td>\n",
+       "      <td>c7b84e8c420a3be3a333ba41c2618399c2baebb470fc6a...</td>\n",
+       "      <td>False</td>\n",
+       "      <td>-1.000000</td>\n",
+       "      <td>-1.000000</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>13</th>\n",
+       "      <td>64</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
+       "      <td>True</td>\n",
        "      <td>0.001</td>\n",
-       "      <td>0.9</td>\n",
+       "      <td>0.95</td>\n",
        "      <td>10</td>\n",
-       "      <td>5c8b5c5ceb49ba7d53ccc921e116ae183fc5b44037c410...</td>\n",
+       "      <td>41fd07df53b8400f9386cf6d6fd5e400290fc19cc96bd1...</td>\n",
        "      <td>False</td>\n",
-       "      <td>0.694118</td>\n",
-       "      <td>0.796875</td>\n",
+       "      <td>-1.000000</td>\n",
+       "      <td>-1.000000</td>\n",
        "    </tr>\n",
        "    <tr>\n",
-       "      <th>3</th>\n",
+       "      <th>12</th>\n",
        "      <td>64</td>\n",
-       "      <td>densenet</td>\n",
-       "      <td>/Users/aziai/Downloads/vtest_new2</td>\n",
-       "      <td>20190602195657</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
        "      <td>True</td>\n",
        "      <td>0.001</td>\n",
-       "      <td>0.9</td>\n",
+       "      <td>0.90</td>\n",
        "      <td>10</td>\n",
-       "      <td>31c3e541f0e5d5b5c1823de1f23c1bc4c8c4be3c22eef1...</td>\n",
+       "      <td>33ac0b6c3357fb2cfc22194be83a872435a7b3495506b8...</td>\n",
+       "      <td>False</td>\n",
+       "      <td>-1.000000</td>\n",
+       "      <td>-1.000000</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9</th>\n",
+       "      <td>64</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
        "      <td>True</td>\n",
-       "      <td>0.529412</td>\n",
-       "      <td>0.629630</td>\n",
+       "      <td>0.010</td>\n",
+       "      <td>0.95</td>\n",
+       "      <td>10</td>\n",
+       "      <td>1613a061bc0ee4cb3e89ff00d7b183241ee4fccf84420c...</td>\n",
+       "      <td>False</td>\n",
+       "      <td>-1.000000</td>\n",
+       "      <td>-1.000000</td>\n",
        "    </tr>\n",
        "    <tr>\n",
-       "      <th>10</th>\n",
+       "      <th>8</th>\n",
        "      <td>64</td>\n",
-       "      <td>resnet</td>\n",
-       "      <td>/Users/aziai/Downloads/vtest_new2</td>\n",
-       "      <td>20190602195657</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
        "      <td>True</td>\n",
+       "      <td>0.010</td>\n",
+       "      <td>0.90</td>\n",
+       "      <td>10</td>\n",
+       "      <td>dddc0808ccee29d6d96d7922bdce1e66e319ad29cd7206...</td>\n",
+       "      <td>False</td>\n",
+       "      <td>-1.000000</td>\n",
+       "      <td>-1.000000</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>5</th>\n",
+       "      <td>64</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
+       "      <td>False</td>\n",
        "      <td>0.001</td>\n",
-       "      <td>0.9</td>\n",
+       "      <td>0.95</td>\n",
        "      <td>10</td>\n",
-       "      <td>e07e5119b07164f06098d1adba9e4c43ad0344716a0746...</td>\n",
+       "      <td>9af39512bef330a9bbab90f91063597b4d3d798a0a7dec...</td>\n",
        "      <td>False</td>\n",
-       "      <td>0.564706</td>\n",
-       "      <td>0.626263</td>\n",
+       "      <td>-1.000000</td>\n",
+       "      <td>-1.000000</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>4</th>\n",
        "      <td>64</td>\n",
        "      <td>NaN</td>\n",
-       "      <td>/Users/aziai/Downloads/vtest_new2</td>\n",
-       "      <td>20190602195657</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
        "      <td>False</td>\n",
        "      <td>0.001</td>\n",
-       "      <td>0.9</td>\n",
+       "      <td>0.90</td>\n",
        "      <td>10</td>\n",
-       "      <td>7324021853e0ca109ff0effaf0c9d06e68d72744305f34...</td>\n",
+       "      <td>54aaa1ed52b83839a188927adb63f64a6e8cc8f9d5ce37...</td>\n",
        "      <td>False</td>\n",
        "      <td>-1.000000</td>\n",
        "      <td>-1.000000</td>\n",
        "    </tr>\n",
        "    <tr>\n",
-       "      <th>6</th>\n",
+       "      <th>0</th>\n",
        "      <td>64</td>\n",
        "      <td>NaN</td>\n",
-       "      <td>/Users/aziai/Downloads/vtest_new2</td>\n",
-       "      <td>20190602195657</td>\n",
-       "      <td>True</td>\n",
-       "      <td>0.001</td>\n",
-       "      <td>0.9</td>\n",
+       "      <td>vtest_new2</td>\n",
+       "      <td>20190602232142</td>\n",
+       "      <td>False</td>\n",
+       "      <td>0.010</td>\n",
+       "      <td>0.90</td>\n",
        "      <td>10</td>\n",
-       "      <td>9ca13084e79d88cec23574c0c37fa9109fe87a7026f9bd...</td>\n",
+       "      <td>3d4d0ae6efebf77f6b7f5a4163558c892daae15e174c5b...</td>\n",
        "      <td>False</td>\n",
        "      <td>-1.000000</td>\n",
        "      <td>-1.000000</td>\n",
@@ -550,22 +717,42 @@
        "</div>"
       ],
       "text/plain": [
-       "    batch_size conv_model_name                     data_path_base  experiment_uuid  feature_extract     lr  momentum  num_epochs                                        runner_uuid  use_vggish   val_acc    val_f1\n",
-       "9           64          resnet  /Users/aziai/Downloads/vtest_new2   20190602195657            False  0.001       0.9          10  b3290017112d2116e4bdbb9c8dbf15a8e75adacb942afb...        True  0.894118  0.935252\n",
-       "5           64             NaN  /Users/aziai/Downloads/vtest_new2   20190602195657            False  0.001       0.9          10  4687cd536acf6c9b058a8d20719fa6910f50a6abee3ae2...        True  0.847059  0.897638\n",
-       "1           64        densenet  /Users/aziai/Downloads/vtest_new2   20190602195657            False  0.001       0.9          10  697154af4871003b02cf60d53531fb21a80e853646b458...        True  0.823529  0.896552\n",
-       "8           64          resnet  /Users/aziai/Downloads/vtest_new2   20190602195657            False  0.001       0.9          10  0e57debc92afdd0dc7a209584b4d97860c9dba98f3aed4...       False  0.823529  0.878049\n",
-       "7           64             NaN  /Users/aziai/Downloads/vtest_new2   20190602195657             True  0.001       0.9          10  de7011637531360e5c76520f054d90327cc478dbabeab9...        True  0.776471  0.874172\n",
-       "2           64        densenet  /Users/aziai/Downloads/vtest_new2   20190602195657             True  0.001       0.9          10  89b8a27230c7f2ee2dc0ece6fd1f9deccae873fce8288f...       False  0.776471  0.845528\n",
-       "11          64          resnet  /Users/aziai/Downloads/vtest_new2   20190602195657             True  0.001       0.9          10  8e644bc291a463725bf0bcb11825a196383a4860eeecd7...        True  0.729412  0.824427\n",
-       "0           64        densenet  /Users/aziai/Downloads/vtest_new2   20190602195657            False  0.001       0.9          10  5c8b5c5ceb49ba7d53ccc921e116ae183fc5b44037c410...       False  0.694118  0.796875\n",
-       "3           64        densenet  /Users/aziai/Downloads/vtest_new2   20190602195657             True  0.001       0.9          10  31c3e541f0e5d5b5c1823de1f23c1bc4c8c4be3c22eef1...        True  0.529412  0.629630\n",
-       "10          64          resnet  /Users/aziai/Downloads/vtest_new2   20190602195657             True  0.001       0.9          10  e07e5119b07164f06098d1adba9e4c43ad0344716a0746...       False  0.564706  0.626263\n",
-       "4           64             NaN  /Users/aziai/Downloads/vtest_new2   20190602195657            False  0.001       0.9          10  7324021853e0ca109ff0effaf0c9d06e68d72744305f34...       False -1.000000 -1.000000\n",
-       "6           64             NaN  /Users/aziai/Downloads/vtest_new2   20190602195657             True  0.001       0.9          10  9ca13084e79d88cec23574c0c37fa9109fe87a7026f9bd...       False -1.000000 -1.000000"
+       "    batch_size conv_model_name data_path_base  experiment_uuid  feature_extract     lr  momentum  num_epochs                                        runner_uuid  use_vggish   val_acc    val_f1\n",
+       "17          64          resnet     vtest_new2   20190602232142            False  0.010      0.95          10  b5d41c44c98e2fa432656659c4fbcb77ef2a66d30dbe51...       False  0.917647  0.946565\n",
+       "22          64          resnet     vtest_new2   20190602232142            False  0.001      0.90          10  c6cd69a63c6b7670802afa33af35b13d6c687340f71bf7...        True  0.894118  0.936170\n",
+       "6           64             NaN     vtest_new2   20190602232142            False  0.001      0.90          10  57fac69720ed9689a5805e15ddbb89ee26d9b64fbd2cf2...        True  0.882353  0.929577\n",
+       "3           64             NaN     vtest_new2   20190602232142            False  0.010      0.95          10  bc7e48700485911fd1dcfa46fb408e93d744c2ec3afd26...        True  0.858824  0.916667\n",
+       "24          64          resnet     vtest_new2   20190602232142             True  0.010      0.90          10  9715fe9781571c61eb6bf38cbc5173df40e9e41ae3eaec...       False  0.870588  0.916031\n",
+       "26          64          resnet     vtest_new2   20190602232142             True  0.010      0.90          10  454e165857fea306cac78a54c635a15545fab0e7a5f606...        True  0.858824  0.907692\n",
+       "23          64          resnet     vtest_new2   20190602232142            False  0.001      0.95          10  af38e4c808ff3998977167f1e5138bdd0301f92a2dc206...        True  0.835294  0.893939\n",
+       "27          64          resnet     vtest_new2   20190602232142             True  0.010      0.95          10  4566c335e71f215a1110a36a7cfec1882c86f86dbb7e3e...        True  0.823529  0.893617\n",
+       "2           64             NaN     vtest_new2   20190602232142            False  0.010      0.90          10  0fd6f6ebe947c54bdcb0340e244f0a27254f0a358be9a0...        True  0.811765  0.891892\n",
+       "31          64          resnet     vtest_new2   20190602232142             True  0.001      0.95          10  b0ba9005b21c06336b8a58c2ed784dcf5fec6028dd7bca...        True  0.811765  0.887324\n",
+       "18          64          resnet     vtest_new2   20190602232142            False  0.010      0.90          10  a60fa62389b6418273ca349479228d39e55b7b357ba2e7...        True  0.823529  0.887218\n",
+       "19          64          resnet     vtest_new2   20190602232142            False  0.010      0.95          10  dcb42ec2f883718f58a1612e4afff99249aa628afa8442...        True  0.800000  0.880000\n",
+       "29          64          resnet     vtest_new2   20190602232142             True  0.001      0.95          10  81bf2d2fb5c8083afc54b43f8c561e6c5c36304a2d0024...       False  0.800000  0.877698\n",
+       "11          64             NaN     vtest_new2   20190602232142             True  0.010      0.95          10  db6e5ff043c83f9643e5fa4aecb3c85a03e2be1c8de569...        True  0.776471  0.874172\n",
+       "7           64             NaN     vtest_new2   20190602232142            False  0.001      0.95          10  8686806f634f6da4be6bf84cca80c305e8a7f751dd8ca8...        True  0.776471  0.874172\n",
+       "14          64             NaN     vtest_new2   20190602232142             True  0.001      0.90          10  c2c04441a29f2bcf11fade126cd304849d3d5e9a1c98fa...        True  0.776471  0.874172\n",
+       "15          64             NaN     vtest_new2   20190602232142             True  0.001      0.95          10  4d227fadf14ff66ea2c7f253a96c2111a78234f95089d0...        True  0.776471  0.874172\n",
+       "10          64             NaN     vtest_new2   20190602232142             True  0.010      0.90          10  10ece59983d31536981268d1e8bbdf4460e65b24a7567b...        True  0.776471  0.874172\n",
+       "25          64          resnet     vtest_new2   20190602232142             True  0.010      0.95          10  21f4a1394a48e89171341403ff9eccc6e080a9acb66005...       False  0.811765  0.873016\n",
+       "21          64          resnet     vtest_new2   20190602232142            False  0.001      0.95          10  091aaed6d121608dc449c2f33a7c74bbe835ccc5abadc9...       False  0.800000  0.872180\n",
+       "30          64          resnet     vtest_new2   20190602232142             True  0.001      0.90          10  5f938c364c0de0bc3ab556f1b48971520006d0ed581d53...        True  0.776471  0.868966\n",
+       "16          64          resnet     vtest_new2   20190602232142            False  0.010      0.90          10  a9d9eb5afabbea4dff607846fe1a0782c760f44cfbd0f9...       False  0.788235  0.867647\n",
+       "20          64          resnet     vtest_new2   20190602232142            False  0.001      0.90          10  8c4da8f6f9157db4dbd6f70b67c1db469ea034730c12eb...       False  0.788235  0.854839\n",
+       "28          64          resnet     vtest_new2   20190602232142             True  0.001      0.90          10  6fcff331f0df5ded20113ae3e7f2d1568e3f3fba9f2a92...       False  0.729412  0.818898\n",
+       "1           64             NaN     vtest_new2   20190602232142            False  0.010      0.95          10  c7b84e8c420a3be3a333ba41c2618399c2baebb470fc6a...       False -1.000000 -1.000000\n",
+       "13          64             NaN     vtest_new2   20190602232142             True  0.001      0.95          10  41fd07df53b8400f9386cf6d6fd5e400290fc19cc96bd1...       False -1.000000 -1.000000\n",
+       "12          64             NaN     vtest_new2   20190602232142             True  0.001      0.90          10  33ac0b6c3357fb2cfc22194be83a872435a7b3495506b8...       False -1.000000 -1.000000\n",
+       "9           64             NaN     vtest_new2   20190602232142             True  0.010      0.95          10  1613a061bc0ee4cb3e89ff00d7b183241ee4fccf84420c...       False -1.000000 -1.000000\n",
+       "8           64             NaN     vtest_new2   20190602232142             True  0.010      0.90          10  dddc0808ccee29d6d96d7922bdce1e66e319ad29cd7206...       False -1.000000 -1.000000\n",
+       "5           64             NaN     vtest_new2   20190602232142            False  0.001      0.95          10  9af39512bef330a9bbab90f91063597b4d3d798a0a7dec...       False -1.000000 -1.000000\n",
+       "4           64             NaN     vtest_new2   20190602232142            False  0.001      0.90          10  54aaa1ed52b83839a188927adb63f64a6e8cc8f9d5ce37...       False -1.000000 -1.000000\n",
+       "0           64             NaN     vtest_new2   20190602232142            False  0.010      0.90          10  3d4d0ae6efebf77f6b7f5a4163558c892daae15e174c5b...       False -1.000000 -1.000000"
       ]
      },
-     "execution_count": 70,
+     "execution_count": 91,
      "metadata": {},
      "output_type": "execute_result"
     }

+ 1 - 0
experiments.py

@@ -41,6 +41,7 @@ class ExperimentRunner:
 
         if self._experiment_result_exists(uuid):
             log('Loading experiment results from cache')
+            log(uuid)
             experiment_results = unpickle(self._file_path_experiment_results(uuid))
         else:
             experiment_results = train_kd(**param_set)

+ 7 - 2
params.py

@@ -1,8 +1,13 @@
+import numpy as np
+
 seed = 0
 n_jobs = 1
 
 data_path_base = 'vtest_new2'
 
+mean = np.array([0.485, 0.456, 0.406])
+std = np.array([0.229, 0.224, 0.225])
+
 # test end-to-end
 experiment_test = {
     'data_path_base': {data_path_base},
@@ -17,11 +22,11 @@ experiment_test = {
 
 experiments = {
     'data_path_base': {data_path_base},
-    'conv_model_name': {'resnet', None, 'densenet', 'squeezenet'},  # vgg
+    'conv_model_name': {'resnet', None},  # vgg
     'num_epochs': {10},
     'feature_extract': {True, False},
     'batch_size': {64},
-    'lr': {1e-3, 1e-2, 5e-4},
+    'lr': {1e-3, 1e-2},
     'use_vggish': {False, True},
     'momentum': {0.9, 0.95}
 }

+ 24 - 14
pipeline.py

@@ -7,11 +7,12 @@ from typing import List, Tuple
 import cv2
 import numpy as np
 import torch
+from PIL import Image
 from moviepy.editor import VideoFileClip
 from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
 from torchvision import transforms
-from PIL import Image
 
+import params
 import vggish_input
 
 VGGISH_FRAME_RATE = 0.96
@@ -47,7 +48,7 @@ class BuildDataset:
                  base_path: str,
                  videos_and_labels: List[Tuple[str, str]],
                  output_path: str,
-                 n_augment: int=1,
+                 n_augment: int = 1,
                  test_size: float = 1 / 3):
         assert 0 < test_size < 1
         self.videos_and_labels = videos_and_labels
@@ -57,15 +58,6 @@ class BuildDataset:
         self.n_augment = n_augment
 
         self.sets = ['train', 'val']
-        self.img_size = 224
-
-        self.transformer = transforms.Compose([
-            transforms.RandomResizedCrop(self.img_size),
-            transforms.RandomHorizontalFlip(),
-            transforms.ToTensor(),
-            # TODO: wtf?
-            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
-        ])
 
     def _get_set(self):
         return np.random.choice(self.sets, p=[1 - self.test_size, self.test_size])
@@ -88,13 +80,29 @@ class BuildDataset:
             target = f"{self.output_path}/{set_}/{label}_{name}.pkl"
             pickle.dump((audio, images, label), open(target, 'wb'))
 
-    def one_video_extract_audio_and_stills(self, path_video: str) -> Tuple[List[torch.Tensor],
-                                                                     List[torch.Tensor]]:
+    @staticmethod
+    def transform_reverse(img: torch.Tensor) -> Image.Image:
+        """Undo the ``params.mean`` / ``params.std`` normalisation and return a PIL image.
+
+        ``transforms.Normalize`` computes ``(x - mean) / std``, so the forward
+        normalisation is inverted here in two steps: first multiply by ``std``,
+        then add ``mean`` back.
+
+        :param img: normalised image tensor (assumed to be channel-first with
+            three channels, matching the three-element mean/std — TODO confirm)
+        :return: the de-normalised image converted with ``transforms.ToPILImage``
+        """
+        return transforms.Compose([
+            # dividing by 1 / std is the same as multiplying by std
+            transforms.Normalize(mean=[0, 0, 0], std=(1.0 / params.std).tolist()),
+            # subtracting -mean is the same as adding mean
+            transforms.Normalize(mean=(-params.mean).tolist(), std=[1, 1, 1]),
+            transforms.ToPILImage()])(img)
+
+    @staticmethod
+    def one_video_extract_audio_and_stills(path_video: str,
+                                           img_size: int = 224) -> Tuple[List[torch.Tensor],
+                                                                         List[torch.Tensor]]:
         # return a list of image(s), audio tensors
         cap = cv2.VideoCapture(path_video)
         frame_rate = cap.get(5)
         images = []
 
+        transformer = transforms.Compose([
+            transforms.RandomResizedCrop(img_size),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize(params.mean, params.std)
+        ])
+
         # process the image
         while cap.isOpened():
             frame_id = cap.get(1)
@@ -106,7 +114,8 @@ class BuildDataset:
 
             if frame_id % math.floor(frame_rate * VGGISH_FRAME_RATE) == 0:
                 frame_pil = Image.fromarray(frame, mode='RGB')
-                images += [self.transformer(frame_pil) for _ in range(self.n_augment)]
+                images.append(transformer(frame_pil))
+                # images += [transformer(frame_pil) for _ in range(self.n_augment)]
 
         cap.release()
 
@@ -123,6 +132,7 @@ class BuildDataset:
 
         min_sizes = min(audio.shape[0], len(images))
         audio = [torch.from_numpy(audio[idx][None, :, :]).float() for idx in range(min_sizes)]
+        images = images[:min_sizes]
         # images = [torch.from_numpy(img).permute((2, 0, 1)) for img in images[:min_sizes]]
 
         return audio, images

+ 1 - 0
requirements.txt

@@ -8,3 +8,4 @@ moviepy
 opencv-python
 joblib
 pandas
+matplotlib

+ 121 - 2
segmentor.py

@@ -1,5 +1,124 @@
 from typing import List, Tuple
+import random
 
+import matplotlib.pyplot as plt
 
-def segmentor(scenes: List[bool], min_frames: int, threshold: float) -> List[Tuple[int, int]]:
-    return [(1, 5), (8, 30)]
+import numpy as np
+import torch
+from PIL.Image import Image
+from matplotlib.pyplot import figure, imshow, axis
+from torch import nn
+
+from pipeline import BuildDataset
+
+# the images constituting a segment and the segment's length in seconds
+Segment = Tuple[List[Image], int]
+
+
+class Segmentor:
+    """Find and visualise runs of positively-classified frames in a video.
+
+    Each sampled frame gets a binary prediction from ``model``; frames are then
+    grouped into segments that span at least ``min_frames`` frames and whose
+    fraction of positive predictions is at least ``threshold``.
+    """
+
+    def __init__(self,
+                 model: nn.Module,
+                 min_frames: int,
+                 threshold: float):
+        """Store the classifier and the segmentation hyper-parameters.
+
+        :param model: module called as ``model(audio_batch, image_batch)``
+        :param min_frames: minimum number of frames a segment must span
+        :param threshold: minimum mean prediction over the frames of a segment
+        """
+        self.model = model
+        self.min_frames = min_frames
+        self.threshold = threshold
+
+    @staticmethod
+    def _segmentor(preds: List[int],
+                   min_frames: int,
+                   threshold: float) -> List[List[int]]:
+        """Group per-frame predictions into non-overlapping segments.
+
+        A candidate is a window that starts and ends on a positive frame, spans
+        at least ``min_frames`` frames and has a mean prediction of at least
+        ``threshold``. For every start only the longest such window is kept.
+        Overlapping candidates are then resolved by keeping the longer one
+        (the earlier one on a tie).
+
+        :param preds: one 0/1 prediction per frame
+        :param min_frames: minimum window length in frames
+        :param threshold: minimum fraction of positive frames in a window
+        :return: for each segment, the list of frame indices it covers
+        """
+        n = len(preds)
+        # a window always holds at least one frame; smaller values behaved like 1
+        min_frames = max(min_frames, 1)
+        candidates = []
+
+        for idx_start in range(n):
+            if preds[idx_start] != 1 or n - idx_start < min_frames:
+                continue
+
+            best_end = None
+            # running count of positives in preds[idx_start:idx_end + 1], so the
+            # mean of each window does not have to be recomputed from scratch
+            positives = sum(preds[idx_start:idx_start + min_frames - 1])
+            for idx_end in range(idx_start + min_frames - 1, n):
+                positives += preds[idx_end]
+                frames = idx_end - idx_start + 1
+                if preds[idx_end] == 1 and positives / frames >= threshold:
+                    # idx_end only grows, so the last hit is the longest window
+                    best_end = idx_end
+
+            if best_end is not None:
+                candidates.append((idx_start, best_end))
+
+        # Drop one member of an overlapping pair at a time and rescan from the
+        # top, because a deletion shifts the positions of the remaining entries.
+        overlap = True
+        while overlap:
+            overlap = False
+            for i, (ref_idx_start, ref_idx_end) in enumerate(candidates):
+                for j in range(i + 1, len(candidates)):
+                    comp_idx_start, comp_idx_end = candidates[j]
+                    if comp_idx_start <= ref_idx_end and ref_idx_start <= comp_idx_end:
+                        # overlapping, take the longer one
+                        if comp_idx_end - comp_idx_start > ref_idx_end - ref_idx_start:
+                            del candidates[i]
+                        else:
+                            del candidates[j]
+                        overlap = True
+                        break
+
+                if overlap:
+                    break
+
+        return [list(range(idx_start, idx_end + 1)) for idx_start, idx_end in candidates]
+
+    @staticmethod
+    def _torch_img_to_pil(img: torch.Tensor) -> Image:
+        """Convert a normalised image tensor back into a PIL image."""
+        return BuildDataset.transform_reverse(img)
+
+    @staticmethod
+    def _get_segment_len(indices: List[int]) -> int:
+        """Return the number of frames spanned by ``indices`` (inclusive).
+
+        NOTE(review): callers report this value as seconds, which presumably
+        relies on roughly one sampled frame per second — confirm against the
+        frame sampling in the pipeline.
+        """
+        return max(indices) - min(indices) + 1
+
+    def segmentor(self, preds: List[int], images: List[torch.Tensor]) -> List[Segment]:
+        """Turn predictions into segments of PIL images.
+
+        :param preds: one 0/1 prediction per frame
+        :param images: the frame tensors, aligned with ``preds``
+        :return: for each segment, its images and its length
+        """
+        segment_list = self._segmentor(preds, self.min_frames, self.threshold)
+        return [
+            ([self._torch_img_to_pil(images[idx])
+              for idx in segment_idx], self._get_segment_len(segment_idx))
+            for segment_idx in segment_list]
+
+    def _predict(self, audio: torch.Tensor, image: torch.Tensor) -> int:
+        """Return the predicted class index for a single (audio, image) pair.
+
+        Both inputs get a leading batch dimension before the model is called;
+        the result is the index of the largest output along dimension 1 for
+        the first (only) batch element.
+        """
+        output = self.model(audio.unsqueeze(0), image.unsqueeze(0))
+        return int(torch.max(output, 1)[1][0])
+
+    def get_segments(self, path_video: str) -> List[Segment]:
+        """Extract audio and stills from ``path_video``, classify them and segment."""
+        audio, images = BuildDataset.one_video_extract_audio_and_stills(path_video)
+        # the extractor trims audio and images to the same length
+        preds = [self._predict(frame_audio, frame_image)
+                 for frame_audio, frame_image in zip(audio, images)]
+        return self.segmentor(preds, images)
+
+    @staticmethod
+    def show_images_horizontally(images: List[Image]) -> None:
+        """Display ``images`` side by side in a single row."""
+        # https://stackoverflow.com/questions/36006136/how-to-display-images-in-a-row-with-ipython-display
+        if not images:
+            # a subplot grid with zero columns cannot be created
+            return
+
+        fig = figure(figsize=(20, 20))
+        number_of_files = len(images)
+        for i, image in enumerate(images):
+            ax = fig.add_subplot(1, number_of_files, i + 1)
+            ax.imshow(image)
+            ax.axis('off')
+        plt.show()
+
+    def visualize_segments(self, path_video: str, n_to_show: int=10) -> None:
+        """Print a summary of every segment found in ``path_video`` and show its frames.
+
+        For each segment up to ``n_to_show`` images are displayed from its
+        start, from a random sample and from its end.
+
+        :param path_video: path of the video to analyse
+        :param n_to_show: maximum number of images per row
+        """
+        segments = self.get_segments(path_video)
+        print(f'Found {len(segments)} segments')
+
+        for i, (segment_images, segment_len) in enumerate(segments):
+            # a segment may hold fewer images than requested; random.sample
+            # raises ValueError when asked for more items than are available
+            n_shown = min(n_to_show, len(segment_images))
+
+            print(f'Segment {i + 1}, {segment_len} seconds')
+            print(f'First {n_to_show}')
+            self.show_images_horizontally(segment_images[:n_shown])
+
+            print(f'{n_to_show} random shots')
+            self.show_images_horizontally(random.sample(segment_images, n_shown))
+
+            print(f'Last {n_to_show}')
+            self.show_images_horizontally(segment_images[len(segment_images) - n_shown:])
+            print('=' * 10)

+ 21 - 0
test_segmentor.py

@@ -0,0 +1,21 @@
+import unittest
+
+from segmentor import Segmentor
+
+
+class TestSegmentor(unittest.TestCase):
+    """Unit tests for the pure segment-finding logic of ``Segmentor``."""
+
+    def test_segmentor(self):
+        """Check ``Segmentor._segmentor`` on hand-worked prediction sequences."""
+        min_frames = 4
+        threshold = 0.7
+
+        # (per-frame predictions, expected frame indices of every segment)
+        tests = [
+            # only frames 0-3 reach the threshold (3/4)
+            ([1, 1, 0, 1, 0, 0, 1, 0, 1], [[0, 1, 2, 3]]),
+            # frames 0-6 pass (5/7), the whole sequence does not (6/9)
+            ([1, 1, 0, 1, 1, 0, 1, 0, 1], [[0, 1, 2, 3, 4, 5, 6]]),
+            # the whole sequence passes (7/9)
+            ([1, 1, 0, 1, 1, 0, 1, 1, 1], [list(range(0, 8 + 1))]),
+            ([1, 1, 1, 0, 1], [[0, 1, 2, 3, 4]]),
+            # no positive frame, no segment
+            ([0, 0, 0, 0, 0], []),
+            # a segment has to end on a positive frame
+            ([1] * 7 + [0] * 3, [list(range(7))])
+        ]
+
+        for ex, exp in tests:
+            # subTest names the failing sequence and keeps checking the others
+            with self.subTest(preds=ex):
+                self.assertEqual(exp, Segmentor._segmentor(ex, min_frames, threshold))