train.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. # adapted from PyTorch tutorials
  2. import copy
  3. import time
  4. from typing import List, Tuple, Optional
  5. import torch
  6. import torch.optim as optim
  7. from torch import nn
  8. from data import AudioVideo, AudioVideo3D
  9. from kissing_detector import KissingDetector, KissingDetector3DConv
  10. ExperimentResults = Tuple[Optional[nn.Module], List[float], List[float]]
  11. def _get_params_to_update(model: nn.Module,
  12. feature_extract: bool) -> List[nn.parameter.Parameter]:
  13. params_to_update = model.parameters()
  14. if feature_extract:
  15. print('Params to update')
  16. params_to_update = []
  17. for name, param in model.named_parameters():
  18. if param.requires_grad is True:
  19. params_to_update.append(param)
  20. print("*", name)
  21. else:
  22. print('Updating ALL params')
  23. return params_to_update
  24. def train_kd(data_path_base: str,
  25. conv_model_name: Optional[str],
  26. num_epochs: int,
  27. feature_extract: bool,
  28. batch_size: int,
  29. use_vggish: bool = True,
  30. num_workers: int = 4,
  31. shuffle: bool = True,
  32. lr: float = 0.001,
  33. momentum: float = 0.9,
  34. use_3d: bool = False) -> ExperimentResults:
  35. num_classes = 2
  36. try:
  37. if use_3d:
  38. kd = KissingDetector3DConv(num_classes, feature_extract, use_vggish)
  39. else:
  40. kd = KissingDetector(conv_model_name, num_classes, feature_extract, use_vggish=use_vggish)
  41. except ValueError:
  42. # if the combination is not valid
  43. return None, [-1.0], [-1.0]
  44. params_to_update = _get_params_to_update(kd, feature_extract)
  45. av = AudioVideo3D if use_3d else AudioVideo
  46. datasets = {set_: av(f'{data_path_base}/{set_}') for set_ in ['train', 'val']}
  47. dataloaders_dict = {x: torch.utils.data.DataLoader(datasets[x],
  48. batch_size=batch_size,
  49. shuffle=shuffle, num_workers=num_workers)
  50. for x in ['train', 'val']}
  51. # optimizer_ft = optim.SGD(params_to_update, lr=lr, momentum=momentum)
  52. optimizer_ft = optim.Adam(params_to_update, lr=lr)
  53. # Setup the loss fxn
  54. criterion = nn.CrossEntropyLoss()
  55. return train_model(kd,
  56. dataloaders_dict, criterion, optimizer_ft, num_epochs=num_epochs,
  57. is_inception=(conv_model_name == "inception"))
  58. def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False):
  59. since = time.time()
  60. val_acc_history = []
  61. val_f1_history = []
  62. best_model_wts = copy.deepcopy(model.state_dict())
  63. best_acc = 0.0
  64. best_f1 = 0.0
  65. # Detect if we have a GPU available
  66. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  67. for epoch in range(num_epochs):
  68. print('Epoch {}/{}'.format(epoch, num_epochs - 1))
  69. print('-' * 10)
  70. # Each epoch has a training and validation phase
  71. for phase in ['train', 'val']:
  72. if phase == 'train':
  73. model.train() # Set model to training mode
  74. else:
  75. model.eval() # Set model to evaluate mode
  76. running_loss = 0.0
  77. running_corrects = 0
  78. running_tp = 0
  79. running_fp = 0
  80. running_fn = 0
  81. # Iterate over data.
  82. for a, v, labels in dataloaders[phase]:
  83. a = a.to(device)
  84. v = v.to(device)
  85. labels = labels.to(device)
  86. # zero the parameter gradients
  87. optimizer.zero_grad()
  88. # forward
  89. # track history if only in train
  90. with torch.set_grad_enabled(phase == 'train'):
  91. # Get model outputs and calculate loss
  92. # Special case for inception because in training it has an auxiliary output. In train
  93. # mode we calculate the loss by summing the final output and the auxiliary output
  94. # but in testing we only consider the final output.
  95. if is_inception and phase == 'train':
  96. # https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
  97. outputs, aux_outputs = model(a, v)
  98. loss1 = criterion(outputs, labels)
  99. loss2 = criterion(aux_outputs, labels)
  100. loss = loss1 + 0.4 * loss2
  101. else:
  102. outputs = model(a, v)
  103. loss = criterion(outputs, labels)
  104. _, preds = torch.max(outputs, 1)
  105. # backward + optimize only if in training phase
  106. if phase == 'train':
  107. loss.backward()
  108. optimizer.step()
  109. # statistics
  110. running_loss += loss.item() * a.size(0)
  111. running_corrects += torch.sum(preds == labels.data)
  112. running_tp += torch.sum((preds == labels.data)[labels.data == 1])
  113. running_fp += torch.sum((preds != labels.data)[labels.data == 1])
  114. running_fn += torch.sum((preds != labels.data)[labels.data == 0])
  115. epoch_loss = running_loss / len(dataloaders[phase].dataset)
  116. n = len(dataloaders[phase].dataset)
  117. epoch_acc = running_corrects.double() / n
  118. tp = running_tp.double()
  119. fp = running_fp.double()
  120. fn = running_fn.double()
  121. p = tp / (tp + fp)
  122. r = tp / (tp + fn)
  123. epoch_f1 = 2 * p * r / (p + r)
  124. print('{} Loss: {:.4f} F1: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_f1, epoch_acc))
  125. # deep copy the model
  126. if phase == 'val' and epoch_acc > best_acc:
  127. best_acc = epoch_acc
  128. if phase == 'val' and epoch_f1 > best_f1:
  129. best_f1 = epoch_f1
  130. best_model_wts = copy.deepcopy(model.state_dict())
  131. if phase == 'val':
  132. val_acc_history.append(float(epoch_acc))
  133. val_f1_history.append(float(epoch_f1))
  134. print()
  135. time_elapsed = time.time() - since
  136. print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
  137. print('Best val F1 : {:4f}'.format(best_f1))
  138. print('Best val Acc : {:4f}'.format(best_acc))
  139. # load best model weights
  140. model.load_state_dict(best_model_wts)
  141. return model, val_acc_history, val_f1_history