# kissing_detector.py
from typing import Optional

import torch
from torch import nn

import conv3d
import vggish
from conv import convnet_init, set_parameter_requires_grad
  7. class KissingDetector(nn.Module):
  8. def __init__(self,
  9. conv_model_name: Optional[str],
  10. num_classes: int,
  11. feature_extract: bool,
  12. use_pretrained: bool = True,
  13. use_vggish: bool = True):
  14. super(KissingDetector, self).__init__()
  15. conv_output_size = 0
  16. vggish_output_size = 0
  17. conv_input_size = 0
  18. conv = None
  19. vggish_model = None
  20. if conv_model_name:
  21. conv, conv_input_size, conv_output_size = convnet_init(conv_model_name,
  22. num_classes,
  23. feature_extract,
  24. use_pretrained)
  25. if use_vggish:
  26. vggish_model, vggish_output_size = vggish.vggish(feature_extract)
  27. if not conv and not vggish_model:
  28. raise ValueError("Use VGGish, Conv, or both")
  29. self.conv_input_size = conv_input_size
  30. self.conv = conv
  31. self.vggish = vggish_model
  32. self.combined = nn.Linear(vggish_output_size + conv_output_size, num_classes)
  33. def forward(self, audio: torch.Tensor, image: torch.Tensor):
  34. a = self.vggish(audio) if self.vggish else None
  35. c = self.conv(image) if self.conv else None
  36. if a is not None and c is not None:
  37. combined = torch.cat((c.view(c.size(0), -1), a.view(a.size(0), -1)), dim=1)
  38. else:
  39. combined = a if a is not None else c
  40. return self.combined(combined)
  41. class KissingDetector3DConv(nn.Module):
  42. def __init__(self,
  43. num_classes: int,
  44. feature_extract: bool,
  45. use_vggish: bool = True):
  46. super(KissingDetector3DConv, self).__init__()
  47. conv_output_size = 512
  48. vggish_output_size = 0
  49. conv_input_size = 0
  50. vggish_model = None
  51. conv = conv3d.resnet34(
  52. num_classes=num_classes,
  53. shortcut_type='B',
  54. sample_size=224,
  55. sample_duration=16
  56. )
  57. set_parameter_requires_grad(conv, feature_extract)
  58. conv.fc = nn.Identity()
  59. if use_vggish:
  60. vggish_model, vggish_output_size = vggish.vggish(feature_extract)
  61. if not conv and not vggish_model:
  62. raise ValueError("Use VGGish, Conv, or both")
  63. self.conv_input_size = conv_input_size
  64. self.conv = conv
  65. self.vggish = vggish_model
  66. self.combined = nn.Linear(vggish_output_size + conv_output_size, num_classes)
  67. def forward(self, audio: torch.Tensor, image: torch.Tensor):
  68. a = self.vggish(audio) if self.vggish else None
  69. c = self.conv(image) if self.conv else None
  70. if a is not None and c is not None:
  71. combined = torch.cat((c.view(c.size(0), -1), a.view(a.size(0), -1)), dim=1)
  72. else:
  73. combined = a if a is not None else c
  74. return self.combined(combined)