kissing_detector.py 1006 B

1234567891011121314151617181920212223
  1. import torch
  2. from torch import nn
  3. import vggish
  4. from conv import convnet_init
  5. class KissingDetector(nn.Module):
  6. def __init__(self, model_name: str, num_classes: int, feature_extract: bool, use_pretrained: bool = True):
  7. super(KissingDetector, self).__init__()
  8. conv, conv_input_size, conv_output_size = convnet_init(model_name, num_classes, feature_extract,
  9. use_pretrained=use_pretrained)
  10. vggish_model, vggish_output_size = vggish.vggish(feature_extract)
  11. self.conv_input_size = conv_input_size
  12. self.conv = conv
  13. self.vggish = vggish_model
  14. self.combined = nn.Linear(vggish_output_size + conv_output_size, num_classes)
  15. def forward(self, audio: torch.Tensor, image: torch.Tensor):
  16. a = self.vggish(audio)
  17. c = self.conv(image)
  18. combined = torch.cat((c.view(c.size(0), -1), a.view(a.size(0), -1)), dim=1)
  19. out = self.combined(combined)
  20. return out