vggish.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. from typing import Tuple
  2. import torch.nn as nn
  3. from torch import hub
  4. VGGISH_WEIGHTS = (
  5. "https://users.cs.cf.ac.uk/taylorh23/pytorch/models/vggish-cbfe8f1c.pth"
  6. )
  7. PCA_PARAMS = (
  8. "https://users.cs.cf.ac.uk/taylorh23/pytorch/models/vggish_pca_params-4d878af3.npz"
  9. )
  10. class VGGishParams:
  11. """
  12. These should not be changed. They have been added into this file for convenience.
  13. """
  14. NUM_FRAMES = (96,) # Frames in input mel-spectrogram patch.
  15. NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch.
  16. EMBEDDING_SIZE = 128 # Size of embedding layer.
  17. # Hyperparameters used in feature and example generation.
  18. SAMPLE_RATE = 16000
  19. STFT_WINDOW_LENGTH_SECONDS = 0.025
  20. STFT_HOP_LENGTH_SECONDS = 0.010
  21. NUM_MEL_BINS = NUM_BANDS
  22. MEL_MIN_HZ = 125
  23. MEL_MAX_HZ = 7500
  24. LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram.
  25. EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames
  26. EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap.
  27. # Parameters used for embedding postprocessing.
  28. PCA_EIGEN_VECTORS_NAME = "pca_eigen_vectors"
  29. PCA_MEANS_NAME = "pca_means"
  30. QUANTIZE_MIN_VAL = -2.0
  31. QUANTIZE_MAX_VAL = +2.0
  32. """
  33. VGGish
  34. Input: 96x64 1-channel spectrogram
  35. Output: 128 Embedding
  36. """
  37. class VGGish(nn.Module):
  38. def __init__(self):
  39. super(VGGish, self).__init__()
  40. self.features = nn.Sequential(
  41. nn.Conv2d(1, VGGishParams.NUM_BANDS, 3, 1, 1),
  42. nn.ReLU(inplace=True),
  43. nn.MaxPool2d(2, 2),
  44. nn.Conv2d(VGGishParams.NUM_BANDS, VGGishParams.EMBEDDING_SIZE, 3, 1, 1),
  45. nn.ReLU(inplace=True),
  46. nn.MaxPool2d(2, 2),
  47. nn.Conv2d(128, 256, 3, 1, 1),
  48. nn.ReLU(inplace=True),
  49. nn.Conv2d(256, 256, 3, 1, 1),
  50. nn.ReLU(inplace=True),
  51. nn.MaxPool2d(2, 2),
  52. nn.Conv2d(256, 512, 3, 1, 1),
  53. nn.ReLU(inplace=True),
  54. nn.Conv2d(512, 512, 3, 1, 1),
  55. nn.ReLU(inplace=True),
  56. nn.MaxPool2d(2, 2),
  57. )
  58. self.embeddings = nn.Sequential(
  59. nn.Linear(512 * 24, 4096),
  60. nn.ReLU(inplace=True),
  61. nn.Linear(4096, 4096),
  62. nn.ReLU(inplace=True),
  63. nn.Linear(4096, VGGishParams.EMBEDDING_SIZE),
  64. nn.ReLU(inplace=True),
  65. )
  66. def forward(self, x):
  67. x = self.features(x)
  68. x = x.view(x.size(0), -1)
  69. x = self.embeddings(x)
  70. return x
  71. def vggish() -> Tuple[VGGish, int]:
  72. """
  73. VGGish is a PyTorch implementation of Tensorflow's VGGish architecture used to create embeddings
  74. for Audioset. It produces a 128-d embedding of a 96ms slice of audio. Always comes pretrained.
  75. """
  76. model = VGGish()
  77. model.load_state_dict(hub.load_state_dict_from_url(VGGISH_WEIGHTS), strict=True)
  78. return model, VGGishParams.EMBEDDING_SIZE