train.py
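"""Training entry points for the audio-visual KissingDetector.

`train_kd` wires up the model, data loaders, optimizer, and loss;
`train_model` runs a standard train/validation loop and keeps the
weights with the best validation accuracy.
"""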

import copy
import time
from typing import List, Tuple

import torch
import torch.optim as optim
from torch import nn

from data import AudioVideo
from kissing_detector import KissingDetector


def _get_params_to_update(model: nn.Module,
                          feature_extract: bool) -> List[nn.parameter.Parameter]:
    """Collect the parameters the optimizer should update.

    When feature extracting, only parameters with requires_grad=True
    (the newly added head) are returned; otherwise all parameters are.
    """
    params_to_update = model.parameters()
    if feature_extract:
        print('Params to update')
        params_to_update = []
        for name, param in model.named_parameters():
            if param.requires_grad:
                params_to_update.append(param)
                print("*", name)
    else:
        print('Updating ALL params')
    return params_to_update


def train_kd(model_name: str,
             num_epochs: int,
             feature_extract: bool,
             batch_size: int,
             num_workers: int = 4,
             shuffle: bool = True,
             lr: float = 0.001,
             momentum: float = 0.9) -> Tuple[nn.Module, List[torch.Tensor]]:
    """Build the KissingDetector, data loaders, optimizer, and loss, then train."""
    num_classes = 2
    kd = KissingDetector(model_name, num_classes, feature_extract)
    params_to_update = _get_params_to_update(kd, feature_extract)

    datasets = {x: AudioVideo(x) for x in ['train', 'val']}
    dataloaders_dict = {x: torch.utils.data.DataLoader(datasets[x],
                                                       batch_size=batch_size,
                                                       shuffle=shuffle,
                                                       num_workers=num_workers)
                        for x in ['train', 'val']}

    optimizer_ft = optim.SGD(params_to_update, lr=lr, momentum=momentum)

    # Set up the loss function
    criterion = nn.CrossEntropyLoss()

    model_ft, hist = train_model(kd,
                                 dataloaders_dict,
                                 criterion,
                                 optimizer_ft,
                                 num_epochs=num_epochs,
                                 is_inception=(model_name == "inception"))
    return model_ft, hist


def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False):
    """Run a train/validation loop and return the best model plus the
    per-epoch validation accuracy history."""
    since = time.time()

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    # Detect if we have a GPU available and move the model onto it
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and a validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for a, v, labels in dataloaders[phase]:
                a = a.to(device)
                v = v.to(device)
                labels = labels.to(device)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward pass; track history only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss.
                    # Special case for inception: in training it has an auxiliary
                    # output, so the loss is the final output's loss plus a weighted
                    # auxiliary loss; in eval only the final output is considered.
                    if is_inception and phase == 'train':
                        # https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                        outputs, aux_outputs = model(a, v)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:
                        outputs = model(a, v)
                        loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # Backward + optimize only in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Statistics
                running_loss += loss.item() * a.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Deep-copy the weights whenever validation accuracy improves
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'val':
                val_acc_history.append(epoch_acc)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # Load the best model weights before returning
    model.load_state_dict(best_model_wts)
    return model, val_acc_history
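

# A minimal usage sketch. Assumptions: the hyperparameter values below are
# illustrative rather than project defaults, and 'inception' is one of the
# model names the KissingDetector constructor accepts (the loop above
# special-cases it for the auxiliary loss).
if __name__ == '__main__':
    trained_model, val_history = train_kd(model_name='inception',
                                          num_epochs=25,
                                          feature_extract=True,
                                          batch_size=32)
    # val_history holds one 0-dim accuracy tensor per epoch
    print('Per-epoch validation accuracy:',
          [float(acc) for acc in val_history])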