@@ -1,16 +1,63 @@
 import copy
 import time
+from typing import List, Tuple
 
 import torch
 import torch.optim as optim
 from torch import nn
 
-# TODO: get these properly
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-feature_extract = True
-model_ft = None  # TODO
-dataloaders_dict = None  # TODO
-model_name = None  # TODO
+from data import AudioVideo
+from kissing_detector import KissingDetector
+
+# keep `device` at module scope: the unchanged train_model() below still reads it
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+
+def _get_params_to_update(model: nn.Module,
+                          feature_extract: bool) -> List[nn.parameter.Parameter]:
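+    # Feature extraction: only the newly initialized layers have
+    # requires_grad=True, so collect just those; otherwise fine-tune everything.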
+    params_to_update = model.parameters()
+    if feature_extract:
+        print('Params to update')
+        params_to_update = []
+        for name, param in model.named_parameters():
+            if param.requires_grad:
+                params_to_update.append(param)
+                print("*", name)
+    else:
+        print('Updating ALL params')
+    return params_to_update
+
+
+def train_kd(model_name: str,
+             num_epochs: int,
+             feature_extract: bool,
+             batch_size: int,
+             num_workers: int = 4,
+             shuffle: bool = True,
+             lr: float = 0.001,
+             momentum: float = 0.9) -> Tuple[nn.Module, List[torch.Tensor]]:
+    num_classes = 2  # binary task: kissing vs. not kissing
+    kd = KissingDetector(model_name, num_classes, feature_extract)
+    kd = kd.to(device)  # move the model to the GPU before building the optimizer
+    params_to_update = _get_params_to_update(kd, feature_extract)
+
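+    # one AudioVideo dataset per split; its loader yields (audio, video, label) batches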
+    datasets = {x: AudioVideo(x) for x in ['train', 'val']}
+    dataloaders_dict = {x: torch.utils.data.DataLoader(datasets[x],
+                                                       batch_size=batch_size,
+                                                       shuffle=shuffle, num_workers=num_workers)
+                        for x in ['train', 'val']}
+    optimizer_ft = optim.SGD(params_to_update, lr=lr, momentum=momentum)
+
+    # Set up the loss function
+    criterion = nn.CrossEntropyLoss()
+
+    model_ft, hist = train_model(kd, dataloaders_dict, criterion, optimizer_ft,
+                                 num_epochs=num_epochs,
+                                 is_inception=(model_name == "inception"))
+    return model_ft, hist
 
 
 def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False):
@@ -39,8 +86,10 @@ def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_ince
             running_corrects = 0
 
             # Iterate over data.
-            for inputs, labels in dataloaders[phase]:
-                inputs = inputs.to(device)
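+            # each batch now carries separate audio (a) and video (v) tensors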
+            for a, v, labels in dataloaders[phase]:
+                a = a.to(device)
+                v = v.to(device)
                 labels = labels.to(device)
 
                 # zero the parameter gradients
@@ -55,12 +104,13 @@ def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_ince
                     # but in testing we only consider the final output.
                     if is_inception and phase == 'train':
                         # https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
-                        outputs, aux_outputs = model(inputs)
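+                        # the detector consumes both modalities in a single forward pass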
+                        outputs, aux_outputs = model(a, v)
                         loss1 = criterion(outputs, labels)
                         loss2 = criterion(aux_outputs, labels)
                         loss = loss1 + 0.4 * loss2
                     else:
-                        outputs = model(inputs)
+                        outputs = model(a, v)
                         loss = criterion(outputs, labels)
 
                     _, preds = torch.max(outputs, 1)
@@ -71,7 +121,7 @@ def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_ince
                     optimizer.step()
 
                 # statistics
-                running_loss += loss.item() * inputs.size(0)
+                running_loss += loss.item() * a.size(0)
                 running_corrects += torch.sum(preds == labels.data)
 
             epoch_loss = running_loss / len(dataloaders[phase].dataset)
@@ -95,35 +145,3 @@ def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_ince
     # load best model weights
     model.load_state_dict(best_model_wts)
     return model, val_acc_history
-
-
-# Send the model to GPU
-model_ft = model_ft.to(device)
-
-# Gather the parameters to be optimized/updated in this run. If we are
-# finetuning we will be updating all parameters. However, if we are
-# doing feature extract method, we will only update the parameters
-# that we have just initialized, i.e. the parameters with requires_grad
-# is True.
-params_to_update = model_ft.parameters()
-print("Params to learn:")
-if feature_extract:
-    params_to_update = []
-    for name, param in model_ft.named_parameters():
-        if param.requires_grad is True:
-            params_to_update.append(param)
-            print("\t", name)
-else:
-    for name, param in model_ft.named_parameters():
-        if param.requires_grad is True:
-            print("\t", name)
-
-# Observe that all parameters are being optimized
-optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)
-
-# Setup the loss fxn
-criterion = nn.CrossEntropyLoss()
-
-# Train and evaluate
-model_ft, hist = train_model(model_ft, dataloaders_dict, criterion, optimizer_ft, num_epochs=num_epochs,
-                             is_inception=(model_name == "inception"))