vggish_input.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # Copyright 2017 The TensorFlow Authors All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Compute input examples for VGGish from audio waveform."""
  16. import mel_features
  17. import numpy as np
  18. import resampy
  19. import soundfile as sf
  20. import vggish_params
  21. def waveform_to_examples(data, sample_rate):
  22. """Converts audio waveform into an array of examples for VGGish.
  23. Args:
  24. data: np.array of either one dimension (mono) or two dimensions
  25. (multi-channel, with the outer dimension representing channels).
  26. Each sample is generally expected to lie in the range [-1.0, +1.0],
  27. although this is not required.
  28. sample_rate: Sample rate of data.
  29. Returns:
  30. 3-D np.array of shape [num_examples, num_frames, num_bands] which represents
  31. a sequence of examples, each of which contains a patch of log mel
  32. spectrogram, covering num_frames frames of audio and num_bands mel frequency
  33. bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS.
  34. """
  35. # Convert to mono.
  36. if len(data.shape) > 1:
  37. data = np.mean(data, axis=1)
  38. # Resample to the rate assumed by VGGish.
  39. if sample_rate != vggish_params.SAMPLE_RATE:
  40. data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)
  41. # Compute log mel spectrogram features.
  42. log_mel = mel_features.log_mel_spectrogram(
  43. data,
  44. audio_sample_rate=vggish_params.SAMPLE_RATE,
  45. log_offset=vggish_params.LOG_OFFSET,
  46. window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS,
  47. hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS,
  48. num_mel_bins=vggish_params.NUM_MEL_BINS,
  49. lower_edge_hertz=vggish_params.MEL_MIN_HZ,
  50. upper_edge_hertz=vggish_params.MEL_MAX_HZ)
  51. # Frame features into examples.
  52. features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS
  53. example_window_length = int(round(
  54. vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate))
  55. example_hop_length = int(round(
  56. vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate))
  57. log_mel_examples = mel_features.frame(
  58. log_mel,
  59. window_length=example_window_length,
  60. hop_length=example_hop_length)
  61. return log_mel_examples
  62. def wavfile_to_examples(wav_file):
  63. """Convenience wrapper around waveform_to_examples() for a common WAV format.
  64. Args:
  65. wav_file: String path to a file, or a file-like object. The file
  66. is assumed to contain WAV audio data with signed 16-bit PCM samples.
  67. Returns:
  68. See waveform_to_examples.
  69. """
  70. wav_data, sr = sf.read(wav_file, dtype='int16')
  71. assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
  72. samples = wav_data / 32768.0 # Convert to [-1.0, +1.0]
  73. return waveform_to_examples(samples, sr)