kissing_detector.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. from typing import Optional
  2. import torch
  3. from torch import nn
  4. import vggish
  5. from conv import convnet_init
  6. import conv3d
  7. class KissingDetector(nn.Module):
  8. def __init__(self,
  9. conv_model_name: Optional[str],
  10. num_classes: int,
  11. feature_extract: bool,
  12. use_pretrained: bool = True,
  13. use_vggish: bool = True):
  14. super(KissingDetector, self).__init__()
  15. conv_output_size = 0
  16. vggish_output_size = 0
  17. conv_input_size = 0
  18. conv = None
  19. vggish_model = None
  20. if conv_model_name:
  21. conv, conv_input_size, conv_output_size = convnet_init(conv_model_name,
  22. num_classes,
  23. feature_extract,
  24. use_pretrained)
  25. if use_vggish:
  26. vggish_model, vggish_output_size = vggish.vggish(feature_extract)
  27. if not conv and not vggish_model:
  28. raise ValueError("Use VGGish, Conv, or both")
  29. self.conv_input_size = conv_input_size
  30. self.conv = conv
  31. self.vggish = vggish_model
  32. self.combined = nn.Linear(vggish_output_size + conv_output_size, num_classes)
  33. def forward(self, audio: torch.Tensor, image: torch.Tensor):
  34. a = self.vggish(audio) if self.vggish else None
  35. c = self.conv(image) if self.conv else None
  36. if a is not None and c is not None:
  37. combined = torch.cat((c.view(c.size(0), -1), a.view(a.size(0), -1)), dim=1)
  38. else:
  39. combined = a if a is not None else c
  40. return self.combined(combined)
  41. class KissingDetector3DConv(nn.Module):
  42. def __init__(self,
  43. num_classes: int,
  44. feature_extract: bool,
  45. use_vggish: bool = True):
  46. super(KissingDetector3DConv, self).__init__()
  47. conv_output_size = 512
  48. vggish_output_size = 0
  49. conv_input_size = 0
  50. vggish_model = None
  51. conv = conv3d.resnet34(
  52. num_classes=num_classes,
  53. shortcut_type='B',
  54. sample_size=224,
  55. sample_duration=10
  56. )
  57. conv.fc = nn.Identity()
  58. if use_vggish:
  59. vggish_model, vggish_output_size = vggish.vggish(feature_extract)
  60. if not conv and not vggish_model:
  61. raise ValueError("Use VGGish, Conv, or both")
  62. self.conv_input_size = conv_input_size
  63. self.conv = conv
  64. self.vggish = vggish_model
  65. self.combined = nn.Linear(vggish_output_size + conv_output_size, num_classes)
  66. def forward(self, audio: torch.Tensor, image: torch.Tensor):
  67. a = self.vggish(audio) if self.vggish else None
  68. c = self.conv(image) if self.conv else None
  69. if a is not None and c is not None:
  70. combined = torch.cat((c.view(c.size(0), -1), a.view(a.size(0), -1)), dim=1)
  71. else:
  72. combined = a if a is not None else c
  73. return self.combined(combined)