vggish.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. from typing import Tuple
  2. import torch.nn as nn
  3. from torch import hub
  4. import conv
  5. VGGISH_WEIGHTS = (
  6. # "https://users.cs.cf.ac.uk/taylorh23/pytorch/models/vggish-cbfe8f1c.pth"
  7. 'https://users.cs.cf.ac.uk/taylorh23/pytorch/models/vggish-918c2d05.pth'
  8. )
  9. PCA_PARAMS = (
  10. "https://users.cs.cf.ac.uk/taylorh23/pytorch/models/vggish_pca_params-4d878af3.npz"
  11. )
  12. class VGGishParams:
  13. """
  14. These should not be changed. They have been added into this file for convenience.
  15. """
  16. NUM_FRAMES = (96,) # Frames in input mel-spectrogram patch.
  17. NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch.
  18. EMBEDDING_SIZE = 128 # Size of embedding layer.
  19. # Hyperparameters used in feature and example generation.
  20. SAMPLE_RATE = 16000
  21. STFT_WINDOW_LENGTH_SECONDS = 0.025
  22. STFT_HOP_LENGTH_SECONDS = 0.010
  23. NUM_MEL_BINS = NUM_BANDS
  24. MEL_MIN_HZ = 125
  25. MEL_MAX_HZ = 7500
  26. LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram.
  27. EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames
  28. EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap.
  29. # Parameters used for embedding postprocessing.
  30. PCA_EIGEN_VECTORS_NAME = "pca_eigen_vectors"
  31. PCA_MEANS_NAME = "pca_means"
  32. QUANTIZE_MIN_VAL = -2.0
  33. QUANTIZE_MAX_VAL = +2.0
  34. """
  35. VGGish
  36. Input: 96x64 1-channel spectrogram
  37. Output: 128 Embedding
  38. """
  39. class VGGish(nn.Module):
  40. def __init__(self, feature_extract: bool):
  41. super(VGGish, self).__init__()
  42. self.features = nn.Sequential(
  43. nn.Conv2d(1, VGGishParams.NUM_BANDS, 3, 1, 1),
  44. nn.ReLU(inplace=True),
  45. nn.MaxPool2d(2, 2),
  46. nn.Conv2d(VGGishParams.NUM_BANDS, VGGishParams.EMBEDDING_SIZE, 3, 1, 1),
  47. nn.ReLU(inplace=True),
  48. nn.MaxPool2d(2, 2),
  49. nn.Conv2d(128, 256, 3, 1, 1),
  50. nn.ReLU(inplace=True),
  51. nn.Conv2d(256, 256, 3, 1, 1),
  52. nn.ReLU(inplace=True),
  53. nn.MaxPool2d(2, 2),
  54. nn.Conv2d(256, 512, 3, 1, 1),
  55. nn.ReLU(inplace=True),
  56. nn.Conv2d(512, 512, 3, 1, 1),
  57. nn.ReLU(inplace=True),
  58. nn.MaxPool2d(2, 2),
  59. )
  60. self.embeddings = nn.Sequential(
  61. nn.Linear(512 * 24, 4096),
  62. nn.ReLU(inplace=True),
  63. nn.Linear(4096, 4096),
  64. nn.ReLU(inplace=True),
  65. nn.Linear(4096, VGGishParams.EMBEDDING_SIZE),
  66. nn.ReLU(inplace=True),
  67. )
  68. conv.set_parameter_requires_grad(self.features, feature_extract)
  69. conv.set_parameter_requires_grad(self.embeddings, feature_extract)
  70. def forward(self, x):
  71. x = self.features(x)
  72. x = x.view(x.size(0), -1)
  73. x = self.embeddings(x)
  74. return x
  75. def vggish(feature_extract: bool) -> Tuple[VGGish, int]:
  76. """
  77. VGGish is a PyTorch implementation of Tensorflow's VGGish architecture used to create embeddings
  78. for Audioset. It produces a 128-d embedding of a 96ms slice of audio. Always comes pretrained.
  79. """
  80. model = VGGish(feature_extract)
  81. model.load_state_dict(hub.load_state_dict_from_url(VGGISH_WEIGHTS), strict=True)
  82. return model, VGGishParams.EMBEDDING_SIZE