|
@@ -25,19 +25,20 @@ def _get_params_to_update(model: nn.Module,
|
|
|
return params_to_update
|
|
|
|
|
|
|
|
|
-def train_kd(model_name: str,
|
|
|
+def train_kd(data_path_base: str,
|
|
|
+ model_name: str,
|
|
|
num_epochs: int,
|
|
|
feature_extract: bool,
|
|
|
batch_size: int,
|
|
|
- num_workers: int=4,
|
|
|
- shuffle: bool=True,
|
|
|
- lr: float=0.001,
|
|
|
- momentum: float=0.9) -> Tuple[nn.Module, List[torch.Tensor]]:
|
|
|
+ num_workers: int = 4,
|
|
|
+ shuffle: bool = True,
|
|
|
+ lr: float = 0.001,
|
|
|
+ momentum: float = 0.9) -> Tuple[nn.Module, List[torch.Tensor]]:
|
|
|
num_classes = 2
|
|
|
kd = KissingDetector(model_name, num_classes, feature_extract)
|
|
|
params_to_update = _get_params_to_update(kd, feature_extract)
|
|
|
|
|
|
- datasets = {x: AudioVideo(x) for x in ['train', 'val']}
|
|
|
+ datasets = {set_: AudioVideo(f'{data_path_base}/{set_}') for set_ in ['train', 'val']}
|
|
|
dataloaders_dict = {x: torch.utils.data.DataLoader(datasets[x],
|
|
|
batch_size=batch_size,
|
|
|
shuffle=shuffle, num_workers=num_workers)
|
|
@@ -60,6 +61,7 @@ def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_ince
|
|
|
|
|
|
best_model_wts = copy.deepcopy(model.state_dict())
|
|
|
best_acc = 0.0
|
|
|
+ best_f1 = 0.0
|
|
|
|
|
|
# Detect if we have a GPU available
|
|
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
@@ -77,6 +79,9 @@ def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_ince
|
|
|
|
|
|
running_loss = 0.0
|
|
|
running_corrects = 0
|
|
|
+ running_tp = 0
|
|
|
+ running_fp = 0
|
|
|
+ running_fn = 0
|
|
|
|
|
|
# Iterate over data.
|
|
|
for a, v, labels in dataloaders[phase]:
|
|
@@ -114,15 +119,27 @@ def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_ince
|
|
|
# statistics
|
|
|
running_loss += loss.item() * a.size(0)
|
|
|
running_corrects += torch.sum(preds == labels.data)
|
|
|
+ running_tp += torch.sum((preds == labels.data)[labels.data == 1])
|
|
|
+ running_fp += torch.sum((preds != labels.data)[labels.data == 1])
|
|
|
+ running_fn += torch.sum((preds != labels.data)[labels.data == 0])
|
|
|
|
|
|
epoch_loss = running_loss / len(dataloaders[phase].dataset)
|
|
|
- epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
|
|
|
+ n = len(dataloaders[phase].dataset)
|
|
|
+ epoch_acc = running_corrects.double() / n
|
|
|
+ tp = running_tp.double()
|
|
|
+ fp = running_fp.double()
|
|
|
+ fn = running_fn.double()
|
|
|
+ p = tp / (tp + fp)
|
|
|
+ r = tp / (tp + fn)
|
|
|
+ epoch_f1 = 2 * p * r / (p + r)
|
|
|
|
|
|
- print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
|
|
|
+ print('{} Loss: {:.4f} F1: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_f1, epoch_acc))
|
|
|
|
|
|
# deep copy the model
|
|
|
if phase == 'val' and epoch_acc > best_acc:
|
|
|
best_acc = epoch_acc
|
|
|
+ if phase == 'val' and epoch_f1 > best_f1:
|
|
|
+ best_f1 = epoch_f1
|
|
|
best_model_wts = copy.deepcopy(model.state_dict())
|
|
|
if phase == 'val':
|
|
|
val_acc_history.append(epoch_acc)
|
|
@@ -131,6 +148,7 @@ def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_ince
|
|
|
|
|
|
time_elapsed = time.time() - since
|
|
|
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
|
|
|
+ print('Best val F1 : {:4f}'.format(best_f1))
|
|
|
print('Best val Acc: {:4f}'.format(best_acc))
|
|
|
|
|
|
# load best model weights
|