
vggish and resnet combined, figuring out input

Amir Ziai committed 6 years ago
parent commit 0059ff6654
9 changed files with 550 additions and 0 deletions
  1. conv.py (+87 -0)
  2. kissing_detector.py (+24 -0)
  3. mel_features.py (+205 -0)
  4. pipeline.py (+0 -0)
  5. requirements.txt (+4 -0)
  6. segmentor.py (+5 -0)
  7. vggish.py (+92 -0)
  8. vggish_input.py (+81 -0)
  9. vggish_params.py (+52 -0)

conv.py (+87 -0)

@@ -0,0 +1,87 @@
+from torch import nn
+from torchvision import models
+import torch
+
+
+def set_parameter_requires_grad(model, feature_extracting):
+    if feature_extracting:
+        for param in model.parameters():
+            param.requires_grad = False
+
+
+def convnet_init(model_name, num_classes, feature_extract, use_pretrained=True):
+    # Initialize the model-specific variables set in the branches below. Note that
+    #   output_size is currently only computed for the "resnet" branch.
+    model_ft = None
+    input_size = 0
+    output_size = 0
+
+    if model_name == "resnet":
+        """ Resnet18
+        """
+        model_ft = models.resnet18(pretrained=use_pretrained)
+        set_parameter_requires_grad(model_ft, feature_extract)
+        num_ftrs = model_ft.fc.in_features
+        # model_ft.fc = nn.Linear(num_ftrs, num_classes)
+        model_ft.fc = nn.Identity()
+        input_size = 224
+        output_size = model_ft(torch.rand((1, 3, input_size, input_size))).shape[1]
+
+    elif model_name == "alexnet":
+        """ Alexnet
+        """
+        model_ft = models.alexnet(pretrained=use_pretrained)
+        set_parameter_requires_grad(model_ft, feature_extract)
+        num_ftrs = model_ft.classifier[6].in_features
+        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
+        input_size = 224
+
+    elif model_name == "vgg":
+        """ VGG11_bn
+        """
+        model_ft = models.vgg11_bn(pretrained=use_pretrained)
+        set_parameter_requires_grad(model_ft, feature_extract)
+        num_ftrs = model_ft.classifier[6].in_features
+        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
+        input_size = 224
+
+    elif model_name == "squeezenet":
+        """ Squeezenet
+        """
+        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
+        set_parameter_requires_grad(model_ft, feature_extract)
+        # TODO: this is my attempt to remove the last FC layer, doesn't seem to work for SqueezeNet
+        # model_ft.classifier = nn.Identity()
+        # model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
+        model_ft.classifier[1] = nn.Identity()
+        # model_ft.num_classes = num_classes
+        input_size = 224
+
+    elif model_name == "densenet":
+        """ Densenet
+        """
+        model_ft = models.densenet121(pretrained=use_pretrained)
+        set_parameter_requires_grad(model_ft, feature_extract)
+        num_ftrs = model_ft.classifier.in_features
+        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
+        input_size = 224
+
+    elif model_name == "inception":
+        """ Inception v3
+        Be careful, expects (299,299) sized images and has auxiliary output
+        """
+        model_ft = models.inception_v3(pretrained=use_pretrained)
+        set_parameter_requires_grad(model_ft, feature_extract)
+        # Handle the auxiliary net
+        num_ftrs = model_ft.AuxLogits.fc.in_features
+        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
+        # Handle the primary net
+        num_ftrs = model_ft.fc.in_features
+        model_ft.fc = nn.Linear(num_ftrs, num_classes)
+        input_size = 299
+
+    else:
+        # Unknown architecture: raise instead of killing the interpreter.
+        raise ValueError(f"Invalid model name: {model_name}")
+
+    return model_ft, input_size, output_size
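
A quick smoke test of the "resnet" path (a sketch, not part of this commit; assumes torchvision can download the pretrained ResNet-18 weights):

    import torch
    from conv import convnet_init

    model, input_size, output_size = convnet_init("resnet", num_classes=2, feature_extract=True)
    frame = torch.rand(1, 3, input_size, input_size)  # dummy 224x224 RGB frame
    print(model(frame).shape)  # torch.Size([1, 512]): ResNet-18 features with the fc head stripped
    print(output_size)         # 512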

kissing_detector.py (+24 -0)

@@ -0,0 +1,24 @@
+import torch
+from torch import nn
+
+import vggish
+from conv import convnet_init
+
+
+class KissingDetector(nn.Module):
+    def __init__(self, model_name: str, num_classes: int, feature_extract: bool, use_pretrained: bool = True):
+        super(KissingDetector, self).__init__()
+        conv, conv_input_size, conv_output_size = convnet_init(model_name, num_classes, feature_extract,
+                                                               use_pretrained=use_pretrained)
+        vggish_model, vggish_output_size = vggish.vggish()
+        self.conv_input_size = conv_input_size
+        self.conv = conv
+        self.vggish = vggish_model
+        self.combined = nn.Linear(vggish_output_size + conv_output_size, num_classes)
+
+    def forward(self, audio: torch.Tensor, image: torch.Tensor):
+        a = self.vggish(audio)
+        c = self.conv(image)
+        combined = torch.cat((c.view(c.size(0), -1), a.view(a.size(0), -1)), dim=1)  # fuse image + audio features
+        out = self.combined(combined)
+        return out
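
A minimal forward-pass check (a sketch; both pretrained downloads must succeed, and only the "resnet" branch currently reports a conv output size):

    import torch
    from kissing_detector import KissingDetector

    kd = KissingDetector("resnet", num_classes=2, feature_extract=True)
    audio = torch.rand(1, 1, 96, 64)    # one VGGish log-mel patch: 96 frames x 64 bands
    image = torch.rand(1, 3, 224, 224)  # one RGB frame
    print(kd(audio, image).shape)       # torch.Size([1, 2])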

mel_features.py (+205 -0)

@@ -0,0 +1,205 @@
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Defines routines to compute mel spectrogram features from audio waveform."""
+
+import numpy as np
+
+
+def frame(data, window_length, hop_length):
+    """Convert array into a sequence of successive possibly overlapping frames.
+    An n-dimensional array of shape (num_samples, ...) is converted into an
+    (n+1)-D array of shape (num_frames, window_length, ...), where each frame
+    starts hop_length points after the preceding one.
+    This is accomplished using stride_tricks, so the original data is not
+    copied.  However, there is no zero-padding, so any incomplete frames at the
+    end are not included.
+    Args:
+      data: np.array of dimension N >= 1.
+      window_length: Number of samples in each frame.
+      hop_length: Advance (in samples) between each window.
+    Returns:
+      (N+1)-D np.array with as many rows as there are complete frames that can be
+      extracted.
+    """
+    num_samples = data.shape[0]
+    num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length))
+    shape = (num_frames, window_length) + data.shape[1:]
+    strides = (data.strides[0] * hop_length,) + data.strides
+    return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides)
+
+
+def periodic_hann(window_length):
+    """Calculate a "periodic" Hann window.
+    The classic Hann window is defined as a raised cosine that starts and
+    ends on zero, and where every value appears twice, except the middle
+    point for an odd-length window.  Matlab calls this a "symmetric" window
+    and np.hanning() returns it.  However, for Fourier analysis, this
+    actually represents just over one cycle of a period N-1 cosine, and
+    thus is not compactly expressed on a length-N Fourier basis.  Instead,
+    it's better to use a raised cosine that ends just before the final
+    zero value - i.e. a complete cycle of a period-N cosine.  Matlab
+    calls this a "periodic" window. This routine calculates it.
+    Args:
+      window_length: The number of points in the returned window.
+    Returns:
+      A 1D np.array containing the periodic hann window.
+    """
+    return 0.5 - (0.5 * np.cos(2 * np.pi / window_length *
+                               np.arange(window_length)))
+
+
+def stft_magnitude(signal, fft_length,
+                   hop_length=None,
+                   window_length=None):
+    """Calculate the short-time Fourier transform magnitude.
+    Args:
+      signal: 1D np.array of the input time-domain signal.
+      fft_length: Size of the FFT to apply.
+      hop_length: Advance (in samples) between each frame passed to FFT.
+      window_length: Length of each block of samples to pass to FFT.
+    Returns:
+      2D np.array where each row contains the magnitudes of the fft_length/2+1
+      unique values of the FFT for the corresponding frame of input samples.
+    """
+    frames = frame(signal, window_length, hop_length)
+    # Apply frame window to each frame. We use a periodic Hann (cosine of period
+    # window_length) instead of the symmetric Hann of np.hanning (period
+    # window_length-1).
+    window = periodic_hann(window_length)
+    windowed_frames = frames * window
+    return np.abs(np.fft.rfft(windowed_frames, int(fft_length)))
+
+
+# Mel spectrum constants and functions.
+_MEL_BREAK_FREQUENCY_HERTZ = 700.0
+_MEL_HIGH_FREQUENCY_Q = 1127.0
+
+
+def hertz_to_mel(frequencies_hertz):
+    """Convert frequencies to mel scale using HTK formula.
+    Args:
+      frequencies_hertz: Scalar or np.array of frequencies in hertz.
+    Returns:
+      Object of same size as frequencies_hertz containing corresponding values
+      on the mel scale.
+    """
+    return _MEL_HIGH_FREQUENCY_Q * np.log(
+        1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))
+
+
+def spectrogram_to_mel_matrix(num_mel_bins=20,
+                              num_spectrogram_bins=129,
+                              audio_sample_rate=8000,
+                              lower_edge_hertz=125.0,
+                              upper_edge_hertz=3800.0):
+    """Return a matrix that can post-multiply spectrogram rows to make mel.
+    Returns a np.array matrix A that can be used to post-multiply a matrix S of
+    spectrogram values (STFT magnitudes) arranged as frames x bins to generate a
+    "mel spectrogram" M of frames x num_mel_bins.  M = S A.
+    The classic HTK algorithm exploits the complementarity of adjacent mel bands
+    to multiply each FFT bin by only one mel weight, then add it, with positive
+    and negative signs, to the two adjacent mel bands to which that bin
+    contributes.  Here, by expressing this operation as a matrix multiply, we go
+    from num_fft multiplies per frame (plus around 2*num_fft adds) to around
+    num_fft^2 multiplies and adds.  However, because these are all presumably
+    accomplished in a single call to np.dot(), it's not clear which approach is
+    faster in Python.  The matrix multiplication has the attraction of being more
+    general and flexible, and much easier to read.
+    Args:
+      num_mel_bins: How many bands in the resulting mel spectrum.  This is
+        the number of columns in the output matrix.
+      num_spectrogram_bins: How many bins there are in the source spectrogram
+        data, which is understood to be fft_size/2 + 1, i.e. the spectrogram
+        only contains the nonredundant FFT bins.
+      audio_sample_rate: Samples per second of the audio at the input to the
+        spectrogram. We need this to figure out the actual frequencies for
+        each spectrogram bin, which dictates how they are mapped into mel.
+      lower_edge_hertz: Lower bound on the frequencies to be included in the mel
+        spectrum.  This corresponds to the lower edge of the lowest triangular
+        band.
+      upper_edge_hertz: The desired top edge of the highest frequency band.
+    Returns:
+      An np.array with shape (num_spectrogram_bins, num_mel_bins).
+    Raises:
+      ValueError: if frequency edges are incorrectly ordered or out of range.
+    """
+    nyquist_hertz = audio_sample_rate / 2.
+    if lower_edge_hertz < 0.0:
+        raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz)
+    if lower_edge_hertz >= upper_edge_hertz:
+        raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
+                         (lower_edge_hertz, upper_edge_hertz))
+    if upper_edge_hertz > nyquist_hertz:
+        raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" %
+                         (upper_edge_hertz, nyquist_hertz))
+    spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins)
+    spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz)
+    # The i'th mel band (starting from i=1) has center frequency
+    # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge
+    # band_edges_mel[i+1].  Thus, we need num_mel_bins + 2 values in
+    # the band_edges_mel arrays.
+    band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz),
+                                 hertz_to_mel(upper_edge_hertz), num_mel_bins + 2)
+    # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins
+    # of spectrogram values.
+    mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins))
+    for i in range(num_mel_bins):
+        lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3]
+        # Calculate lower and upper slopes for every spectrogram bin.
+        # Line segments are linear in the *mel* domain, not hertz.
+        lower_slope = ((spectrogram_bins_mel - lower_edge_mel) /
+                       (center_mel - lower_edge_mel))
+        upper_slope = ((upper_edge_mel - spectrogram_bins_mel) /
+                       (upper_edge_mel - center_mel))
+        # .. then intersect them with each other and zero.
+        mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope,
+                                                              upper_slope))
+    # HTK excludes the spectrogram DC bin; make sure it always gets a zero
+    # coefficient.
+    mel_weights_matrix[0, :] = 0.0
+    return mel_weights_matrix
+
+
+def log_mel_spectrogram(data,
+                        audio_sample_rate=8000,
+                        log_offset=0.0,
+                        window_length_secs=0.025,
+                        hop_length_secs=0.010,
+                        **kwargs):
+    """Convert waveform to a log magnitude mel-frequency spectrogram.
+    Args:
+      data: 1D np.array of waveform data.
+      audio_sample_rate: The sampling rate of data.
+      log_offset: Add this to values when taking log to avoid -Infs.
+      window_length_secs: Duration of each window to analyze.
+      hop_length_secs: Advance between successive analysis windows.
+      **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix.
+    Returns:
+      2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank
+      magnitudes for successive frames.
+    """
+    window_length_samples = int(round(audio_sample_rate * window_length_secs))
+    hop_length_samples = int(round(audio_sample_rate * hop_length_secs))
+    fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0)))
+    spectrogram = stft_magnitude(
+        data,
+        fft_length=fft_length,
+        hop_length=hop_length_samples,
+        window_length=window_length_samples)
+    mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix(
+        num_spectrogram_bins=spectrogram.shape[1],
+        audio_sample_rate=audio_sample_rate, **kwargs))
+    return np.log(mel_spectrogram + log_offset)
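
For intuition, one second of a pure tone pushed through the pipeline (a sketch; the 16 kHz rate and 64 bins mirror the VGGish settings):

    import numpy as np
    import mel_features

    sr = 16000
    tone = np.sin(2 * np.pi * 440 * np.arange(sr) / sr)  # 1 s of 440 Hz
    log_mel = mel_features.log_mel_spectrogram(
        tone, audio_sample_rate=sr, log_offset=0.01, num_mel_bins=64)
    print(log_mel.shape)  # (98, 64): 25 ms windows, 10 ms hop -> 1 + (16000 - 400) // 160 frames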

pipeline.py (+0 -0)


requirements.txt (+4 -0)

@@ -0,0 +1,4 @@
+torch
+torchvision
+resampy
+soundfile

segmentor.py (+5 -0)

@@ -0,0 +1,5 @@
+from typing import List, Tuple
+
+
+def segmentor(scenes: List[bool], min_frames: int, threshold: float) -> List[Tuple[int, int]]:
+    return [(1, 5), (8, 30)]  # TODO: stub; replace with real segment extraction
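
segmentor is clearly a stub at this point. One plausible shape for the real logic, purely a sketch (the threshold parameter is ignored here because its intended semantics aren't visible in this commit):

    from typing import List, Tuple

    def segmentor_sketch(scenes: List[bool], min_frames: int) -> List[Tuple[int, int]]:
        """Collapse per-frame booleans into (start, end) runs of at least min_frames frames."""
        segments, start = [], None
        for i, positive in enumerate(scenes):
            if positive and start is None:
                start = i
            elif not positive and start is not None:
                if i - start >= min_frames:
                    segments.append((start, i))
                start = None
        if start is not None and len(scenes) - start >= min_frames:
            segments.append((start, len(scenes)))
        return segments

    # segmentor_sketch([False, True, True, True, False], min_frames=2) -> [(1, 4)]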

vggish.py (+92 -0)

@@ -0,0 +1,92 @@
+from typing import Tuple
+
+import torch.nn as nn
+from torch import hub
+
+VGGISH_WEIGHTS = (
+    "https://users.cs.cf.ac.uk/taylorh23/pytorch/models/vggish-cbfe8f1c.pth"
+)
+PCA_PARAMS = (
+    "https://users.cs.cf.ac.uk/taylorh23/pytorch/models/vggish_pca_params-4d878af3.npz"
+)
+
+
+class VGGishParams:
+    """
+    These should not be changed. They have been added into this file for convenience.
+    """
+
+    NUM_FRAMES = 96  # Frames in input mel-spectrogram patch.
+    NUM_BANDS = 64  # Frequency bands in input mel-spectrogram patch.
+    EMBEDDING_SIZE = 128  # Size of embedding layer.
+
+    # Hyperparameters used in feature and example generation.
+    SAMPLE_RATE = 16000
+    STFT_WINDOW_LENGTH_SECONDS = 0.025
+    STFT_HOP_LENGTH_SECONDS = 0.010
+    NUM_MEL_BINS = NUM_BANDS
+    MEL_MIN_HZ = 125
+    MEL_MAX_HZ = 7500
+    LOG_OFFSET = 0.01  # Offset used for stabilized log of input mel-spectrogram.
+    EXAMPLE_WINDOW_SECONDS = 0.96  # Each example contains 96 10ms frames
+    EXAMPLE_HOP_SECONDS = 0.96  # with zero overlap.
+
+    # Parameters used for embedding postprocessing.
+    PCA_EIGEN_VECTORS_NAME = "pca_eigen_vectors"
+    PCA_MEANS_NAME = "pca_means"
+    QUANTIZE_MIN_VAL = -2.0
+    QUANTIZE_MAX_VAL = +2.0
+
+
+"""
+VGGish
+Input: 1-channel 96x64 log-mel spectrogram patch
+Output: 128-d embedding
+"""
+
+
+class VGGish(nn.Module):
+    def __init__(self):
+        super(VGGish, self).__init__()
+        self.features = nn.Sequential(
+            nn.Conv2d(1, 64, 3, 1, 1),  # conv widths (64, 128, 256, 512) are fixed by VGGish
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(2, 2),
+            nn.Conv2d(64, 128, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(2, 2),
+            nn.Conv2d(128, 256, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(256, 256, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(2, 2),
+            nn.Conv2d(256, 512, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(512, 512, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(2, 2),
+        )
+        self.embeddings = nn.Sequential(
+            nn.Linear(512 * 24, 4096),  # 512 channels * 6 * 4 spatial (96x64 input after four 2x2 pools)
+            nn.ReLU(inplace=True),
+            nn.Linear(4096, 4096),
+            nn.ReLU(inplace=True),
+            nn.Linear(4096, VGGishParams.EMBEDDING_SIZE),
+            nn.ReLU(inplace=True),
+        )
+
+    def forward(self, x):
+        x = self.features(x)
+        x = x.view(x.size(0), -1)
+        x = self.embeddings(x)
+        return x
+
+
+def vggish() -> Tuple[VGGish, int]:
+    """
+    VGGish is a PyTorch port of TensorFlow's VGGish architecture, used to create embeddings for
+    AudioSet. Produces a 128-d embedding per 0.96 s audio patch (96 frames of 10 ms). Always pretrained.
+    """
+    model = VGGish()
+    model.load_state_dict(hub.load_state_dict_from_url(VGGISH_WEIGHTS), strict=True)
+    return model, VGGishParams.EMBEDDING_SIZE
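
A sanity check of the embedding path (a sketch; the first call downloads the checkpoint from VGGISH_WEIGHTS):

    import torch
    from vggish import vggish

    model, embedding_size = vggish()
    model.eval()
    patch = torch.rand(1, 1, 96, 64)  # [batch, channel, NUM_FRAMES, NUM_BANDS]
    with torch.no_grad():
        print(model(patch).shape)  # torch.Size([1, 128]); embedding_size == 128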

vggish_input.py (+81 -0)

@@ -0,0 +1,81 @@
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Compute input examples for VGGish from audio waveform."""
+
+import mel_features
+import numpy as np
+import resampy
+import soundfile as sf
+import vggish_params
+
+
+def waveform_to_examples(data, sample_rate):
+    """Converts audio waveform into an array of examples for VGGish.
+    Args:
+      data: np.array of either one dimension (mono) or two dimensions
+        (multi-channel, with the outer dimension representing channels).
+        Each sample is generally expected to lie in the range [-1.0, +1.0],
+        although this is not required.
+      sample_rate: Sample rate of data.
+    Returns:
+      3-D np.array of shape [num_examples, num_frames, num_bands] which represents
+      a sequence of examples, each of which contains a patch of log mel
+      spectrogram, covering num_frames frames of audio and num_bands mel frequency
+      bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS.
+    """
+    # Convert to mono.
+    if len(data.shape) > 1:
+        data = np.mean(data, axis=1)
+    # Resample to the rate assumed by VGGish.
+    if sample_rate != vggish_params.SAMPLE_RATE:
+        data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)
+
+    # Compute log mel spectrogram features.
+    log_mel = mel_features.log_mel_spectrogram(
+        data,
+        audio_sample_rate=vggish_params.SAMPLE_RATE,
+        log_offset=vggish_params.LOG_OFFSET,
+        window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS,
+        hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS,
+        num_mel_bins=vggish_params.NUM_MEL_BINS,
+        lower_edge_hertz=vggish_params.MEL_MIN_HZ,
+        upper_edge_hertz=vggish_params.MEL_MAX_HZ)
+
+    # Frame features into examples.
+    features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS
+    example_window_length = int(round(
+        vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate))
+    example_hop_length = int(round(
+        vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate))
+    log_mel_examples = mel_features.frame(
+        log_mel,
+        window_length=example_window_length,
+        hop_length=example_hop_length)
+    return log_mel_examples
+
+
+def wavfile_to_examples(wav_file):
+    """Convenience wrapper around waveform_to_examples() for a common WAV format.
+    Args:
+      wav_file: String path to a file, or a file-like object. The file
+      is assumed to contain WAV audio data with signed 16-bit PCM samples.
+    Returns:
+      See waveform_to_examples.
+    """
+    wav_data, sr = sf.read(wav_file, dtype='int16')
+    assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
+    samples = wav_data / 32768.0  # Convert to [-1.0, +1.0]
+    return waveform_to_examples(samples, sr)
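
Tying input generation to the model (a sketch; 'clip.wav' is a hypothetical mono 16-bit PCM file):

    import torch
    import vggish_input

    examples = vggish_input.wavfile_to_examples('clip.wav')  # (num_examples, 96, 64)
    batch = torch.from_numpy(examples).float().unsqueeze(1)  # add channel dim: (N, 1, 96, 64)
    # batch is now shaped for VGGish.forward / the audio input of KissingDetector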

vggish_params.py (+52 -0)

@@ -0,0 +1,52 @@
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Global parameters for the VGGish model.
+See vggish_slim.py for more information.
+"""
+
+# Architectural constants.
+NUM_FRAMES = 96  # Frames in input mel-spectrogram patch.
+NUM_BANDS = 64  # Frequency bands in input mel-spectrogram patch.
+EMBEDDING_SIZE = 128  # Size of embedding layer.
+
+# Hyperparameters used in feature and example generation.
+SAMPLE_RATE = 16000
+STFT_WINDOW_LENGTH_SECONDS = 0.025
+STFT_HOP_LENGTH_SECONDS = 0.010
+NUM_MEL_BINS = NUM_BANDS
+MEL_MIN_HZ = 125
+MEL_MAX_HZ = 7500
+LOG_OFFSET = 0.01  # Offset used for stabilized log of input mel-spectrogram.
+EXAMPLE_WINDOW_SECONDS = 0.96  # Each example contains 96 10ms frames
+EXAMPLE_HOP_SECONDS = 0.96     # with zero overlap.
+
+# Parameters used for embedding postprocessing.
+PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors'
+PCA_MEANS_NAME = 'pca_means'
+QUANTIZE_MIN_VAL = -2.0
+QUANTIZE_MAX_VAL = +2.0
+
+# Hyperparameters used in training.
+INIT_STDDEV = 0.01  # Standard deviation used to initialize weights.
+LEARNING_RATE = 1e-4  # Learning rate for the Adam optimizer.
+ADAM_EPSILON = 1e-8  # Epsilon for the Adam optimizer.
+
+# Names of ops, tensors, and features.
+INPUT_OP_NAME = 'vggish/input_features'
+INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0'
+OUTPUT_OP_NAME = 'vggish/embedding'
+OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0'
+AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding'
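
The window and hop constants are what make each example exactly NUM_FRAMES frames long; a quick check (a sketch) confirms the arithmetic:

    import vggish_params

    frames = round(vggish_params.EXAMPLE_WINDOW_SECONDS / vggish_params.STFT_HOP_LENGTH_SECONDS)
    assert frames == vggish_params.NUM_FRAMES  # 0.96 s / 0.010 s = 96 frames per example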