123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
- # Copyright 2017 The TensorFlow Authors All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Compute input examples for VGGish from audio waveform."""
- import mel_features
- import numpy as np
- import resampy
- import soundfile as sf
- import vggish_params
- def waveform_to_examples(data, sample_rate):
- """Converts audio waveform into an array of examples for VGGish.
- Args:
- data: np.array of either one dimension (mono) or two dimensions
- (multi-channel, with the outer dimension representing channels).
- Each sample is generally expected to lie in the range [-1.0, +1.0],
- although this is not required.
- sample_rate: Sample rate of data.
- Returns:
- 3-D np.array of shape [num_examples, num_frames, num_bands] which represents
- a sequence of examples, each of which contains a patch of log mel
- spectrogram, covering num_frames frames of audio and num_bands mel frequency
- bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS.
- """
- # Convert to mono.
- if len(data.shape) > 1:
- data = np.mean(data, axis=1)
- # Resample to the rate assumed by VGGish.
- if sample_rate != vggish_params.SAMPLE_RATE:
- data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)
- # Compute log mel spectrogram features.
- log_mel = mel_features.log_mel_spectrogram(
- data,
- audio_sample_rate=vggish_params.SAMPLE_RATE,
- log_offset=vggish_params.LOG_OFFSET,
- window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS,
- hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS,
- num_mel_bins=vggish_params.NUM_MEL_BINS,
- lower_edge_hertz=vggish_params.MEL_MIN_HZ,
- upper_edge_hertz=vggish_params.MEL_MAX_HZ)
- # Frame features into examples.
- features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS
- example_window_length = int(round(
- vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate))
- example_hop_length = int(round(
- vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate))
- log_mel_examples = mel_features.frame(
- log_mel,
- window_length=example_window_length,
- hop_length=example_hop_length)
- return log_mel_examples
- def wavfile_to_examples(wav_file):
- """Convenience wrapper around waveform_to_examples() for a common WAV format.
- Args:
- wav_file: String path to a file, or a file-like object. The file
- is assumed to contain WAV audio data with signed 16-bit PCM samples.
- Returns:
- See waveform_to_examples.
- """
- wav_data, sr = sf.read(wav_file, dtype='int16')
- assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
- samples = wav_data / 32768.0 # Convert to [-1.0, +1.0]
- return waveform_to_examples(samples, sr)
|