vggish.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. # adapted from https://github.com/harritaylor/torchvggish
  2. from typing import Tuple
  3. import torch.nn as nn
  4. from torch import hub
  5. import conv
  6. VGGISH_WEIGHTS = (
  7. # "https://users.cs.cf.ac.uk/taylorh23/pytorch/models/vggish-cbfe8f1c.pth"
  8. 'https://users.cs.cf.ac.uk/taylorh23/pytorch/models/vggish-918c2d05.pth'
  9. )
  10. PCA_PARAMS = (
  11. "https://users.cs.cf.ac.uk/taylorh23/pytorch/models/vggish_pca_params-4d878af3.npz"
  12. )
  13. class VGGishParams:
  14. """
  15. These should not be changed. They have been added into this file for convenience.
  16. """
  17. NUM_FRAMES = (96,) # Frames in input mel-spectrogram patch.
  18. NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch.
  19. EMBEDDING_SIZE = 128 # Size of embedding layer.
  20. # Hyperparameters used in feature and example generation.
  21. SAMPLE_RATE = 16000
  22. STFT_WINDOW_LENGTH_SECONDS = 0.025
  23. STFT_HOP_LENGTH_SECONDS = 0.010
  24. NUM_MEL_BINS = NUM_BANDS
  25. MEL_MIN_HZ = 125
  26. MEL_MAX_HZ = 7500
  27. LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram.
  28. EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames
  29. EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap.
  30. # Parameters used for embedding postprocessing.
  31. PCA_EIGEN_VECTORS_NAME = "pca_eigen_vectors"
  32. PCA_MEANS_NAME = "pca_means"
  33. QUANTIZE_MIN_VAL = -2.0
  34. QUANTIZE_MAX_VAL = +2.0
  35. """
  36. VGGish
  37. Input: 96x64 1-channel spectrogram
  38. Output: 128 Embedding
  39. """
  40. class VGGish(nn.Module):
  41. def __init__(self, feature_extract: bool):
  42. super(VGGish, self).__init__()
  43. self.features = nn.Sequential(
  44. nn.Conv2d(1, VGGishParams.NUM_BANDS, 3, 1, 1),
  45. nn.ReLU(inplace=True),
  46. nn.MaxPool2d(2, 2),
  47. nn.Conv2d(VGGishParams.NUM_BANDS, VGGishParams.EMBEDDING_SIZE, 3, 1, 1),
  48. nn.ReLU(inplace=True),
  49. nn.MaxPool2d(2, 2),
  50. nn.Conv2d(128, 256, 3, 1, 1),
  51. nn.ReLU(inplace=True),
  52. nn.Conv2d(256, 256, 3, 1, 1),
  53. nn.ReLU(inplace=True),
  54. nn.MaxPool2d(2, 2),
  55. nn.Conv2d(256, 512, 3, 1, 1),
  56. nn.ReLU(inplace=True),
  57. nn.Conv2d(512, 512, 3, 1, 1),
  58. nn.ReLU(inplace=True),
  59. nn.MaxPool2d(2, 2),
  60. )
  61. self.embeddings = nn.Sequential(
  62. nn.Linear(512 * 24, 4096),
  63. nn.ReLU(inplace=True),
  64. nn.Linear(4096, 4096),
  65. nn.ReLU(inplace=True),
  66. nn.Linear(4096, VGGishParams.EMBEDDING_SIZE),
  67. nn.ReLU(inplace=True),
  68. )
  69. conv.set_parameter_requires_grad(self.features, feature_extract)
  70. conv.set_parameter_requires_grad(self.embeddings, feature_extract)
  71. def forward(self, x):
  72. x = self.features(x)
  73. x = x.view(x.size(0), -1)
  74. x = self.embeddings(x)
  75. return x
  76. def vggish(feature_extract: bool) -> Tuple[VGGish, int]:
  77. """
  78. VGGish is a PyTorch implementation of Tensorflow's VGGish architecture used to create embeddings
  79. for Audioset. It produces a 128-d embedding of a 96ms slice of audio. Always comes pretrained.
  80. """
  81. model = VGGish(feature_extract)
  82. model.load_state_dict(hub.load_state_dict_from_url(VGGISH_WEIGHTS), strict=True)
  83. return model, VGGishParams.EMBEDDING_SIZE