# kissing_detector.py
  1. import torch
  2. from torch import nn
  3. import vggish
  4. from conv import convnet_init
  5. from typing import Optional
  6. class KissingDetector(nn.Module):
  7. def __init__(self,
  8. conv_model_name: Optional[str],
  9. num_classes: int,
  10. feature_extract: bool,
  11. use_pretrained: bool = True,
  12. use_vggish: bool = True):
  13. super(KissingDetector, self).__init__()
  14. conv_output_size = 0
  15. vggish_output_size = 0
  16. conv_input_size = 0
  17. conv = None
  18. vggish_model = None
  19. if conv_model_name:
  20. conv, conv_input_size, conv_output_size = convnet_init(conv_model_name,
  21. num_classes,
  22. feature_extract,
  23. use_pretrained)
  24. if use_vggish:
  25. vggish_model, vggish_output_size = vggish.vggish(feature_extract)
  26. if not conv and not vggish_model:
  27. raise ValueError("Use VGGish, Conv, or both")
  28. self.conv_input_size = conv_input_size
  29. self.conv = conv
  30. self.vggish = vggish_model
  31. self.combined = nn.Linear(vggish_output_size + conv_output_size, num_classes)
  32. def forward(self, audio: torch.Tensor, image: torch.Tensor):
  33. a = self.vggish(audio) if self.vggish else None
  34. c = self.conv(image) if self.conv else None
  35. if a is not None and c is not None:
  36. combined = torch.cat((c.view(c.size(0), -1), a.view(a.size(0), -1)), dim=1)
  37. else:
  38. combined = a if a is not None else c
  39. return self.combined(combined)