vggish.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. from typing import Tuple
  2. import conv
  3. import torch.nn as nn
  4. from torch import hub
  5. VGGISH_WEIGHTS = (
  6. "https://users.cs.cf.ac.uk/taylorh23/pytorch/models/vggish-cbfe8f1c.pth"
  7. )
  8. PCA_PARAMS = (
  9. "https://users.cs.cf.ac.uk/taylorh23/pytorch/models/vggish_pca_params-4d878af3.npz"
  10. )
  11. class VGGishParams:
  12. """
  13. These should not be changed. They have been added into this file for convenience.
  14. """
  15. NUM_FRAMES = (96,) # Frames in input mel-spectrogram patch.
  16. NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch.
  17. EMBEDDING_SIZE = 128 # Size of embedding layer.
  18. # Hyperparameters used in feature and example generation.
  19. SAMPLE_RATE = 16000
  20. STFT_WINDOW_LENGTH_SECONDS = 0.025
  21. STFT_HOP_LENGTH_SECONDS = 0.010
  22. NUM_MEL_BINS = NUM_BANDS
  23. MEL_MIN_HZ = 125
  24. MEL_MAX_HZ = 7500
  25. LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram.
  26. EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames
  27. EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap.
  28. # Parameters used for embedding postprocessing.
  29. PCA_EIGEN_VECTORS_NAME = "pca_eigen_vectors"
  30. PCA_MEANS_NAME = "pca_means"
  31. QUANTIZE_MIN_VAL = -2.0
  32. QUANTIZE_MAX_VAL = +2.0
  33. """
  34. VGGish
  35. Input: 96x64 1-channel spectrogram
  36. Output: 128 Embedding
  37. """
  38. class VGGish(nn.Module):
  39. def __init__(self, feature_extract: bool):
  40. super(VGGish, self).__init__()
  41. self.features = nn.Sequential(
  42. nn.Conv2d(1, VGGishParams.NUM_BANDS, 3, 1, 1),
  43. nn.ReLU(inplace=True),
  44. nn.MaxPool2d(2, 2),
  45. nn.Conv2d(VGGishParams.NUM_BANDS, VGGishParams.EMBEDDING_SIZE, 3, 1, 1),
  46. nn.ReLU(inplace=True),
  47. nn.MaxPool2d(2, 2),
  48. nn.Conv2d(128, 256, 3, 1, 1),
  49. nn.ReLU(inplace=True),
  50. nn.Conv2d(256, 256, 3, 1, 1),
  51. nn.ReLU(inplace=True),
  52. nn.MaxPool2d(2, 2),
  53. nn.Conv2d(256, 512, 3, 1, 1),
  54. nn.ReLU(inplace=True),
  55. nn.Conv2d(512, 512, 3, 1, 1),
  56. nn.ReLU(inplace=True),
  57. nn.MaxPool2d(2, 2),
  58. )
  59. self.embeddings = nn.Sequential(
  60. nn.Linear(512 * 24, 4096),
  61. nn.ReLU(inplace=True),
  62. nn.Linear(4096, 4096),
  63. nn.ReLU(inplace=True),
  64. nn.Linear(4096, VGGishParams.EMBEDDING_SIZE),
  65. nn.ReLU(inplace=True),
  66. )
  67. conv.set_parameter_requires_grad(self.features, feature_extract)
  68. conv.set_parameter_requires_grad(self.embeddings, feature_extract)
  69. def forward(self, x):
  70. x = self.features(x)
  71. x = x.view(x.size(0), -1)
  72. x = self.embeddings(x)
  73. return x
  74. def vggish(feature_extract: bool) -> Tuple[VGGish, int]:
  75. """
  76. VGGish is a PyTorch implementation of Tensorflow's VGGish architecture used to create embeddings
  77. for Audioset. It produces a 128-d embedding of a 96ms slice of audio. Always comes pretrained.
  78. """
  79. model = VGGish(feature_extract)
  80. model.load_state_dict(hub.load_state_dict_from_url(VGGISH_WEIGHTS), strict=True)
  81. return model, VGGishParams.EMBEDDING_SIZE