
audio + img iterator

Amir Ziai committed 5 years ago
commit c70e64d6cd
9 changed files with 533 additions and 262 deletions:

  1. data.py (+25 -2)
  2. dev.ipynb (+137 -1)
  3. dev2.ipynb (+0 -211)
  4. kissing_detector.py (+1 -1)
  5. pipeline.py (+87 -1)
  6. spatial_transforms.py (+175 -0)
  7. temporal_transforms.py (+50 -0)
  8. train.py (+52 -43)
  9. vggish.py (+6 -3)

data.py (+25 -2)

@@ -1,12 +1,14 @@
 import copy
 import functools
+import json
 import os
+import pickle
+from glob import glob
+from typing import Tuple
 
 import torch
 import torch.utils.data as data
 from PIL import Image
-# import accimage
-import json
 
 
 def pil_loader(path):
@@ -137,3 +139,24 @@ class Video(data.Dataset):
 
     def __len__(self):
         return len(self.data)
+
+
+class AudioVideo(data.Dataset):
+    def __init__(self, path: str):
+        self.path = path
+        self.data = []
+
+        for file_path in glob(f'{path}/*.pkl'):
+            audios, images, label = pickle.load(open(file_path, 'rb'))
+            self.data += [(audios[i], images[i], label) for i in range(len(audios))]
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
+        # return (
+        #     torch.rand((1, 96, 64)),
+        #     torch.rand((3, 224, 224)),
+        #     np.random.choice([0, 1])
+        # )
+        return self.data[idx]

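For reference, a hypothetical smoke test for the new AudioVideo dataset (the directory name, tensor shapes, and integer label below are assumptions; the shapes follow the commented-out placeholder in __getitem__, and the pickle layout matches what BuildDataset in pipeline.py writes):

import os
import pickle

import torch

from data import AudioVideo

# Write one fake shard in the (audio_list, image_list, label) layout.
os.makedirs('train', exist_ok=True)
audios = [torch.rand(1, 96, 64) for _ in range(3)]    # VGGish log-mel patches
images = [torch.rand(3, 224, 224) for _ in range(3)]  # RGB stills
with open('train/1_demo.pkl', 'wb') as f:
    pickle.dump((audios, images, 1), f)

ds = AudioVideo('train')
print(len(ds))                  # 3: one item per (audio, image) pair
a, v, label = ds[0]
print(a.shape, v.shape, label)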
dev.ipynb (+137 -1)

@@ -944,6 +944,33 @@
     "kd"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 114,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from data import AudioVideo"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 115,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "av = AudioVideo('x')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 118,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "a, v, c = av[0]"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 7,
@@ -961,7 +988,116 @@
     }
    ],
    "source": [
-    "kd(torch.rand((1, 1, 96, 64)), torch.rand((1, 3, kd.conv_input_size, kd.conv_input_size)))"
+    "kd(torch.rand((1, 1, 96, 64)),\n",
+    "   torch.rand((1, 3, kd.conv_input_size, kd.conv_input_size)))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 120,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "datasets = {x: AudioVideo(x) for x in ['train', 'val']}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 121,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataloaders_dict = {x: torch.utils.data.DataLoader(datasets[x],\n",
+    "                                                   batch_size=batch_size,\n",
+    "                                                   shuffle=True, num_workers=4)\n",
+    "                    for x in ['train', 'val']}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 122,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 0/14\n",
+      "----------\n"
+     ]
+    },
+    {
+     "ename": "ValueError",
+     "evalue": "too many values to unpack (expected 2)",
+     "output_type": "error",
+     "traceback": [
+      "---------------------------------------------------------------------------",
+      "ValueError                                Traceback (most recent call last)",
+      "<ipython-input-122-89a52f4174d6> in <module>\n      1 model_ft, hist = train_model(kd,\n      2                              dataloaders_dict, criterion, optimizer_ft, num_epochs=num_epochs,\n----> 3                              is_inception=(model_name == \"inception\"))\n",
+      "<ipython-input-68-15dc3f5706f2> in train_model(model, dataloaders, criterion, optimizer, num_epochs, is_inception)\n     48 \n     49             # Iterate over data.\n---> 50             for inputs, labels in dataloaders[phase]:\n     51                 inputs = inputs.to(device)\n     52                 labels = labels.to(device)\n",
+      "ValueError: too many values to unpack (expected 2)"
+     ]
+    }
+   ],
+   "source": [
+    "model_ft, hist = train_model(kd,\n",
+    "                             dataloaders_dict, criterion, optimizer_ft, num_epochs=num_epochs,\n",
+    "                             is_inception=(model_name == \"inception\"))\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 119,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[-0.7130,  0.5764]], grad_fn=<AddmmBackward>)"
+      ]
+     },
+     "execution_count": 119,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "kd(a, v)"
    ]
   },
   {

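The ValueError in cell 122 above is the motivation for the training-loop change in train.py further down: AudioVideo yields (audio, image, label) 3-tuples, so each DataLoader batch collates into three tensors and the old 2-way unpacking fails. A minimal standalone illustration, using TensorDataset as a stand-in for AudioVideo:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Three aligned tensors play the role of audio patches, image stills, and labels.
ds = TensorDataset(torch.rand(8, 1, 96, 64),
                   torch.rand(8, 3, 224, 224),
                   torch.zeros(8, dtype=torch.long))
loader = DataLoader(ds, batch_size=4)

try:
    for inputs, labels in loader:        # old train_model loop
        pass
except ValueError as e:
    print(e)                             # too many values to unpack (expected 2)

for a, v, labels in loader:              # fixed loop, cf. the train.py diff below
    print(a.shape, v.shape, labels.shape)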
dev2.ipynb (+0 -211)

(File diff suppressed because it is too large.)


kissing_detector.py (+1 -1)

@@ -9,7 +9,7 @@ class KissingDetector(nn.Module):
         super(KissingDetector, self).__init__()
         conv, conv_input_size, conv_output_size = convnet_init(model_name, num_classes, feature_extract,
                                                                use_pretrained=use_pretrained)
-        vggish_model, vggish_output_size = vggish.vggish()
+        vggish_model, vggish_output_size = vggish.vggish(feature_extract)
         self.conv_input_size = conv_input_size
         self.conv = conv
         self.vggish = vggish_model

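A smoke test mirroring the kd(a, v) notebook cell above; the positional constructor arguments come from the train.py diff below, while the 'vgg' model name is an assumption since convnet_init's supported names are not part of this diff:

import torch

from kissing_detector import KissingDetector

kd = KissingDetector(model_name='vgg', num_classes=2, feature_extract=True)

# One VGGish audio patch plus one RGB still, batch dimension first.
out = kd(torch.rand(1, 1, 96, 64),
         torch.rand(1, 3, kd.conv_input_size, kd.conv_input_size))
print(out.shape)  # (1, 2): one logit per class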
pipeline.py (+87 -1)

@@ -1,6 +1,18 @@
+import math
+import os
+import pickle
+import shutil
+from typing import List, Tuple
+
 import cv2
-from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
 import numpy as np
+import torch
+from moviepy.editor import VideoFileClip
+from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
+
+import vggish_input
+
+VGGISH_FRAME_RATE = 0.96
 
 
 def slice_clips(segments, root, fps=2):
@@ -26,3 +38,77 @@ def slice_clips(segments, root, fps=2):
                     success, image = vidcap.read()
                     # print('Read a new frame: ', success)
                     count += 1
+
+
+class BuildDataset:
+    def __init__(self,
+                 base_path: str,
+                 videos_and_labels: List[Tuple[str, str]],
+                 output_path: str,
+                 test_size: float = 1 / 3):
+        assert 0 < test_size < 1
+        self.videos_and_labels = videos_and_labels
+        self.test_size = test_size
+        self.output_path = output_path
+        self.base_path = base_path
+
+        self.sets = ['train', 'val']
+
+    def _get_set(self):
+        return np.random.choice(self.sets, p=[1 - self.test_size, self.test_size])
+
+    def build_dataset(self):
+        # wipe
+        for set_ in self.sets:
+            path = f'{self.output_path}/{set_}'
+            try:
+                shutil.rmtree(path)
+            except FileNotFoundError:
+                pass
+            os.makedirs(path)
+
+        for file_name, label in self.videos_and_labels:
+            name, _ = file_name.split('.')
+            path = f'{self.base_path}/{file_name}'
+            audio, images = self.one_video_extract_audio_and_stills(path)
+            set_ = self._get_set()
+            target = f"{self.output_path}/{set_}/{label}_{name}.pkl"
+            pickle.dump((audio, images, label), open(target, 'wb'))
+
+    @staticmethod
+    def one_video_extract_audio_and_stills(path_video: str) -> Tuple[List[torch.Tensor],
+                                                                     List[torch.Tensor]]:
+        # return a list of image(s), audio tensors
+        cap = cv2.VideoCapture(path_video)
+        frame_rate = cap.get(5)
+        images = []
+
+        # process the image
+        while cap.isOpened():
+            frame_id = cap.get(1)
+            success, frame = cap.read()
+
+            if not success:
+                print('Something went wrong!')
+                break
+
+            if frame_id % math.floor(frame_rate * VGGISH_FRAME_RATE) == 0:
+                images.append(frame)
+
+        cap.release()
+
+        # process the audio
+        # TODO: hack to get around OpenMP error
+        os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
+
+        tmp_audio_file = 'tmp.wav'
+        VideoFileClip(path_video).audio.write_audiofile(tmp_audio_file)
+        audio = vggish_input.wavfile_to_examples(tmp_audio_file)
+        # audio = audio[:, None, :, :]  # add dummy dimension for "channel"
+        # audio = torch.from_numpy(audio).float()  # Convert input example to float
+
+        min_sizes = min(audio.shape[0], len(images))
+        audio = [torch.from_numpy(audio[idx][None, :, :]).float() for idx in range(min_sizes)]
+        images = [torch.from_numpy(img).permute((2, 1, 0)) for img in images[:min_sizes]]
+
+        return audio, images

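A hypothetical end-to-end use of the new BuildDataset (video file names and labels are made up). build_dataset() wipes <output_path>/train and <output_path>/val, then writes one {label}_{name}.pkl shard per input video into a randomly chosen split, which is exactly the layout AudioVideo reads:

from pipeline import BuildDataset

builder = BuildDataset(
    base_path='raw_videos',                       # assumed directory of input videos
    videos_and_labels=[('kissing_01.mp4', '1'),
                       ('not_kissing_01.mp4', '0')],
    output_path='.',                              # so AudioVideo('train') / AudioVideo('val') find the shards
    test_size=1 / 3,
)
builder.build_dataset()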
spatial_transforms.py (+175 -0)

@@ -0,0 +1,175 @@
+import random
+import math
+import numbers
+import collections
+import numpy as np
+import torch
+from PIL import Image, ImageOps
+try:
+    import accimage
+except ImportError:
+    accimage = None
+
+
+class Compose(object):
+    """Composes several transforms together.
+    Args:
+        transforms (list of ``Transform`` objects): list of transforms to compose.
+    Example:
+        >>> transforms.Compose([
+        >>>     transforms.CenterCrop(10),
+        >>>     transforms.ToTensor(),
+        >>> ])
+    """
+
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, img):
+        for t in self.transforms:
+            img = t(img)
+        return img
+
+
+class ToTensor(object):
+    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.
+    Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
+    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
+    """
+
+    def __call__(self, pic):
+        """
+        Args:
+            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.
+        Returns:
+            Tensor: Converted image.
+        """
+        if isinstance(pic, np.ndarray):
+            # handle numpy array
+            img = torch.from_numpy(pic.transpose((2, 0, 1)))
+            # backward compatibility
+            return img.float()
+
+        if accimage is not None and isinstance(pic, accimage.Image):
+            nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
+            pic.copyto(nppic)
+            return torch.from_numpy(nppic)
+
+        # handle PIL Image
+        if pic.mode == 'I':
+            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
+        elif pic.mode == 'I;16':
+            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
+        else:
+            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
+        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
+        if pic.mode == 'YCbCr':
+            nchannel = 3
+        elif pic.mode == 'I;16':
+            nchannel = 1
+        else:
+            nchannel = len(pic.mode)
+        img = img.view(pic.size[1], pic.size[0], nchannel)
+        # put it from HWC to CHW format
+        # yikes, this transpose takes 80% of the loading time/CPU
+        img = img.transpose(0, 1).transpose(0, 2).contiguous()
+        if isinstance(img, torch.ByteTensor):
+            return img.float()
+        else:
+            return img
+
+
+class Normalize(object):
+    """Normalize a tensor image with mean and standard deviation.
+    Given mean: (R, G, B) and std: (R, G, B),
+    will normalize each channel of the torch.*Tensor, i.e.
+    channel = (channel - mean) / std
+    Args:
+        mean (sequence): Sequence of means for R, G, B channels respectively.
+        std (sequence): Sequence of standard deviations for R, G, B channels
+            respectively.
+    """
+
+    def __init__(self, mean, std):
+        self.mean = mean
+        self.std = std
+
+    def __call__(self, tensor):
+        """
+        Args:
+            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
+        Returns:
+            Tensor: Normalized image.
+        """
+        # TODO: make efficient
+        for t, m, s in zip(tensor, self.mean, self.std):
+            t.sub_(m).div_(s)
+        return tensor
+
+
+class Scale(object):
+    """Rescale the input PIL.Image to the given size.
+    Args:
+        size (sequence or int): Desired output size. If size is a sequence like
+            (w, h), output size will be matched to this. If size is an int,
+            smaller edge of the image will be matched to this number.
+            i.e, if height > width, then image will be rescaled to
+            (size * height / width, size)
+        interpolation (int, optional): Desired interpolation. Default is
+            ``PIL.Image.BILINEAR``
+    """
+
+    def __init__(self, size, interpolation=Image.BILINEAR):
+        assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
+        self.size = size
+        self.interpolation = interpolation
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL.Image): Image to be scaled.
+        Returns:
+            PIL.Image: Rescaled image.
+        """
+        if isinstance(self.size, int):
+            w, h = img.size
+            if (w <= h and w == self.size) or (h <= w and h == self.size):
+                return img
+            if w < h:
+                ow = self.size
+                oh = int(self.size * h / w)
+                return img.resize((ow, oh), self.interpolation)
+            else:
+                oh = self.size
+                ow = int(self.size * w / h)
+                return img.resize((ow, oh), self.interpolation)
+        else:
+            return img.resize(self.size, self.interpolation)
+
+
+class CenterCrop(object):
+    """Crops the given PIL.Image at the center.
+    Args:
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made.
+    """
+
+    def __init__(self, size):
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        else:
+            self.size = size
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL.Image): Image to be cropped.
+        Returns:
+            PIL.Image: Cropped image.
+        """
+        w, h = img.size
+        th, tw = self.size
+        x1 = int(round((w - tw) / 2.))
+        y1 = int(round((h - th) / 2.))
+        return img.crop((x1, y1, x1 + tw, y1 + th))

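A short sketch of composing the vendored transforms. Note that this copy of ToTensor keeps the 0-255 value range for byte images (it returns img.float() without dividing by 255), so any Normalize statistics would have to be on the same scale:

from PIL import Image

from spatial_transforms import CenterCrop, Compose, Scale, ToTensor

transform = Compose([Scale(256), CenterCrop(224), ToTensor()])

img = Image.new('RGB', (640, 360), color=(128, 64, 32))  # dummy frame
x = transform(img)
print(x.shape, x.max())  # torch.Size([3, 224, 224]) tensor(128.)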
temporal_transforms.py (+50 -0)

@@ -0,0 +1,50 @@
+import random
+import math
+
+
+class LoopPadding(object):
+    def __init__(self, size):
+        self.size = size
+
+    def __call__(self, frame_indices):
+        out = frame_indices
+
+        for index in out:
+            if len(out) >= self.size:
+                break
+            out.append(index)
+
+        return out
+
+
+class TemporalCenterCrop(object):
+    """Temporally crop the given frame indices at a center.
+    If the number of frames is less than the size,
+    loop the indices as many times as necessary to satisfy the size.
+    Args:
+        size (int): Desired output size of the crop.
+    """
+
+    def __init__(self, size):
+        self.size = size
+
+    def __call__(self, frame_indices):
+        """
+        Args:
+            frame_indices (list): frame indices to be cropped.
+        Returns:
+            list: Cropped frame indices.
+        """
+
+        center_index = len(frame_indices) // 2
+        begin_index = max(0, center_index - (self.size // 2))
+        end_index = min(begin_index + self.size, len(frame_indices))
+
+        out = frame_indices[begin_index:end_index]
+
+        for index in out:
+            if len(out) >= self.size:
+                break
+            out.append(index)
+
+        return out

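A quick check of the TemporalCenterCrop semantics: the centered window is taken first, and if the clip is shorter than the requested size the indices are looped until the size is met:

from temporal_transforms import TemporalCenterCrop

crop = TemporalCenterCrop(size=8)
print(crop(list(range(20))))  # [6, 7, 8, 9, 10, 11, 12, 13]
print(crop(list(range(5))))   # [0, 1, 2, 3, 4, 0, 1, 2]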
train.py (+52 -43)

@@ -1,16 +1,56 @@
 import copy
 import time
+from typing import List, Tuple
 
 import torch
 import torch.optim as optim
 from torch import nn
 
-# TODO: get these properly
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-feature_extract = True
-model_ft = None  # TODO
-dataloaders_dict = None  # TODO
-model_name = None  # TODO
+from data import AudioVideo
+from kissing_detector import KissingDetector
+
+
+def _get_params_to_update(model: nn.Module,
+                          feature_extract: bool) -> List[nn.parameter.Parameter]:
+    params_to_update = model.parameters()
+    if feature_extract:
+        print('Params to update')
+        params_to_update = []
+        for name, param in model.named_parameters():
+            if param.requires_grad is True:
+                params_to_update.append(param)
+                print("*", name)
+    else:
+        print('Updating ALL params')
+    return params_to_update
+
+
+def train_kd(model_name: str,
+             num_epochs: int,
+             feature_extract: bool,
+             batch_size: int,
+             num_workers: int=4,
+             shuffle: bool=True,
+             lr: float=0.001,
+             momentum: float=0.9) -> Tuple[nn.Module, List[torch.Tensor]]:
+    num_classes = 2
+    kd = KissingDetector(model_name, num_classes, feature_extract)
+    params_to_update = _get_params_to_update(kd, feature_extract)
+
+    datasets = {x: AudioVideo(x) for x in ['train', 'val']}
+    dataloaders_dict = {x: torch.utils.data.DataLoader(datasets[x],
+                                                       batch_size=batch_size,
+                                                       shuffle=shuffle, num_workers=num_workers)
+                        for x in ['train', 'val']}
+    optimizer_ft = optim.SGD(params_to_update, lr=lr, momentum=momentum)
+
+    # Setup the loss fxn
+    criterion = nn.CrossEntropyLoss()
+
+    model_ft, hist = train_model(kd,
+                                 dataloaders_dict, criterion, optimizer_ft, num_epochs=num_epochs,
+                                 is_inception=(model_name == "inception"))
+    return model_ft, hist
 
 
 def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False):
@@ -39,8 +79,9 @@ def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_ince
             running_corrects = 0
 
             # Iterate over data.
-            for inputs, labels in dataloaders[phase]:
-                inputs = inputs.to(device)
+            for a, v, labels in dataloaders[phase]:
+                a = a.to(device)
+                v = v.to(device)
                 labels = labels.to(device)
 
                 # zero the parameter gradients
@@ -55,12 +96,12 @@ def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_ince
                     #   but in testing we only consider the final output.
                     if is_inception and phase == 'train':
                         # https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
-                        outputs, aux_outputs = model(inputs)
+                        outputs, aux_outputs = model(a, v)
                         loss1 = criterion(outputs, labels)
                         loss2 = criterion(aux_outputs, labels)
                         loss = loss1 + 0.4 * loss2
                     else:
-                        outputs = model(inputs)
+                        outputs = model(a, v)
                         loss = criterion(outputs, labels)
 
                     _, preds = torch.max(outputs, 1)
@@ -71,7 +112,7 @@ def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_ince
                         optimizer.step()
 
                 # statistics
-                running_loss += loss.item() * inputs.size(0)
+                running_loss += loss.item() * a.size(0)
                 running_corrects += torch.sum(preds == labels.data)
 
             epoch_loss = running_loss / len(dataloaders[phase].dataset)
@@ -95,35 +136,3 @@ def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_ince
     # load best model weights
     model.load_state_dict(best_model_wts)
     return model, val_acc_history
-
-
-# Send the model to GPU
-model_ft = model_ft.to(device)
-
-# Gather the parameters to be optimized/updated in this run. If we are
-#  finetuning we will be updating all parameters. However, if we are
-#  doing feature extract method, we will only update the parameters
-#  that we have just initialized, i.e. the parameters with requires_grad
-#  is True.
-params_to_update = model_ft.parameters()
-print("Params to learn:")
-if feature_extract:
-    params_to_update = []
-    for name, param in model_ft.named_parameters():
-        if param.requires_grad is True:
-            params_to_update.append(param)
-            print("\t", name)
-else:
-    for name, param in model_ft.named_parameters():
-        if param.requires_grad is True:
-            print("\t", name)
-
-# Observe that all parameters are being optimized
-optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)
-
-# Setup the loss fxn
-criterion = nn.CrossEntropyLoss()
-
-# Train and evaluate
-model_ft, hist = train_model(model_ft, dataloaders_dict, criterion, optimizer_ft, num_epochs=num_epochs,
-                             is_inception=(model_name == "inception"))

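A hypothetical call into the new train_kd entry point (the 'vgg' model name, epoch count, and batch size are placeholders); it expects ./train and ./val to contain the pickle shards produced by BuildDataset. One caveat visible in the diff: the module-level device definition is removed while train_model still calls .to(device), so device still needs to be defined somewhere in train.py, e.g. torch.device('cuda:0' if torch.cuda.is_available() else 'cpu').

from train import train_kd

model, val_acc_history = train_kd(model_name='vgg',   # assumed convnet name
                                  num_epochs=15,
                                  feature_extract=True,
                                  batch_size=32)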
vggish.py (+6 -3)

@@ -1,5 +1,6 @@
 from typing import Tuple
 
+import conv
 import torch.nn as nn
 from torch import hub
 
@@ -46,7 +47,7 @@ Output:  128 Embedding
 
 
 class VGGish(nn.Module):
-    def __init__(self):
+    def __init__(self, feature_extract: bool):
         super(VGGish, self).__init__()
         self.features = nn.Sequential(
             nn.Conv2d(1, VGGishParams.NUM_BANDS, 3, 1, 1),
@@ -74,6 +75,8 @@ class VGGish(nn.Module):
             nn.Linear(4096, VGGishParams.EMBEDDING_SIZE),
             nn.ReLU(inplace=True),
         )
+        conv.set_parameter_requires_grad(self.features, feature_extract)
+        conv.set_parameter_requires_grad(self.embeddings, feature_extract)
 
     def forward(self, x):
         x = self.features(x)
@@ -82,11 +85,11 @@ class VGGish(nn.Module):
         return x
 
 
-def vggish() -> Tuple[VGGish, int]:
+def vggish(feature_extract: bool) -> Tuple[VGGish, int]:
     """
     VGGish is a PyTorch implementation of Tensorflow's VGGish architecture used to create embeddings
     for Audioset. It produces a 128-d embedding of a 96ms slice of audio. Always comes pretrained.
     """
-    model = VGGish()
+    model = VGGish(feature_extract)
     model.load_state_dict(hub.load_state_dict_from_url(VGGISH_WEIGHTS), strict=True)
     return model, VGGishParams.EMBEDDING_SIZE

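A rough check of the new feature_extract flag, assuming conv.set_parameter_requires_grad freezes parameters (sets requires_grad to False) when the flag is True, as in the standard PyTorch finetuning recipe; note that vggish.vggish() downloads the pretrained weights on first call:

import vggish

model, embedding_size = vggish.vggish(feature_extract=True)
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(embedding_size)  # 128
print(trainable)       # expected to be empty if both blocks are frozen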