
added conv3d

Amir Ziai 5 years ago
commit cda3d270f2
3 changed files with 339 additions and 1 deletion
  1. conv3d.py (+195, −0)
  2. kissing_detector.py (+46, −1)
  3. qualitative.py (+98, −0)

+ 195 - 0
conv3d.py

@@ -0,0 +1,195 @@
+# code is from https://github.com/kenshohara/3D-ResNets-PyTorch/blob/master/model.py
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def conv3x3x3(in_planes, out_planes, stride=1):
+    # 3x3x3 convolution with padding
+    return nn.Conv3d(
+        in_planes,
+        out_planes,
+        kernel_size=3,
+        stride=stride,
+        padding=1,
+        bias=False)
+
+
+def downsample_basic_block(x, planes, stride):
+    # Type-A shortcut: average-pool to the new resolution, then zero-pad the
+    # channel dimension to the block's output width (adds no parameters).
+    out = F.avg_pool3d(x, kernel_size=1, stride=stride)
+    zero_pads = torch.zeros(
+        out.size(0), planes - out.size(1), out.size(2), out.size(3),
+        out.size(4), dtype=out.dtype, device=out.device)
+    return torch.cat([out, zero_pads], dim=1)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm3d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3x3(planes, planes)
+        self.bn2 = nn.BatchNorm3d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm3d(planes)
+        self.conv2 = nn.Conv3d(
+            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm3d(planes)
+        self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm3d(planes * 4)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+
+    def __init__(self,
+                 block,
+                 layers,
+                 sample_size,
+                 sample_duration,
+                 shortcut_type='B',
+                 num_classes=400):
+        super(ResNet, self).__init__()
+        self.inplanes = 64
+        self.conv1 = nn.Conv3d(
+            3,
+            64,
+            kernel_size=7,
+            stride=(1, 2, 2),
+            padding=(3, 3, 3),
+            bias=False)
+        self.bn1 = nn.BatchNorm3d(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)
+        self.layer2 = self._make_layer(
+            block, 128, layers[1], shortcut_type, stride=2)
+        self.layer3 = self._make_layer(
+            block, 256, layers[2], shortcut_type, stride=2)
+        self.layer4 = self._make_layer(
+            block, 512, layers[3], shortcut_type, stride=2)
+        last_duration = int(math.ceil(sample_duration / 16))
+        last_size = int(math.ceil(sample_size / 32))
+        self.avgpool = nn.AvgPool3d(
+            (last_duration, last_size, last_size), stride=1)
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv3d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out')
+            elif isinstance(m, nn.BatchNorm3d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            if shortcut_type == 'A':
+                downsample = partial(
+                    downsample_basic_block,
+                    planes=planes * block.expansion,
+                    stride=stride)
+            else:
+                downsample = nn.Sequential(
+                    nn.Conv3d(
+                        self.inplanes,
+                        planes * block.expansion,
+                        kernel_size=1,
+                        stride=stride,
+                        bias=False), nn.BatchNorm3d(planes * block.expansion))
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+
+        return x
+
+
+def resnet34(**kwargs):
+    """Constructs a ResNet-34 model.
+    """
+    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
+    return model
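
A quick smoke test for the new module (a minimal sketch, not part of the commit; the sample_size/sample_duration values and the dummy clip are illustrative). resnet34 takes clips shaped (batch, channels, frames, height, width) and, with the classifier head attached, returns one logit per class:

    import torch
    import conv3d

    model = conv3d.resnet34(sample_size=112, sample_duration=16,
                            shortcut_type='B', num_classes=400)
    clip = torch.randn(2, 3, 16, 112, 112)  # (N, C, T, H, W) dummy clip
    logits = model(clip)
    print(logits.shape)  # torch.Size([2, 400])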

+ 46 - 1
kissing_detector.py

@@ -1,8 +1,11 @@
+from typing import Optional
+
 import torch
 from torch import nn
+
 import vggish
 from conv import convnet_init
-from typing import Optional
+import conv3d
 
 
 class KissingDetector(nn.Module):
@@ -45,3 +48,45 @@ class KissingDetector(nn.Module):
             combined = a if a is not None else c
 
         return self.combined(combined)
+
+
+class KissingDetector3DConv(nn.Module):
+    def __init__(self,
+                 num_classes: int,
+                 feature_extract: bool,
+                 use_vggish: bool = True):
+        super(KissingDetector3DConv, self).__init__()
+        conv_output_size = 512
+        vggish_output_size = 0
+        conv_input_size = 0
+        vggish_model = None
+
+        conv = conv3d.resnet34(
+            num_classes=num_classes,
+            shortcut_type='B',
+            sample_size=224,
+            sample_duration=10
+        )
+        conv.fc = nn.Identity()
+
+        if use_vggish:
+            vggish_model, vggish_output_size = vggish.vggish(feature_extract)
+
+        if not conv and not vggish_model:
+            raise ValueError("Use VGGish, Conv, or both")
+
+        self.conv_input_size = conv_input_size
+        self.conv = conv
+        self.vggish = vggish_model
+        self.combined = nn.Linear(vggish_output_size + conv_output_size, num_classes)
+
+    def forward(self, audio: torch.Tensor, image: torch.Tensor):
+        a = self.vggish(audio) if self.vggish else None
+        c = self.conv(image) if self.conv else None
+
+        if a is not None and c is not None:
+            combined = torch.cat((c.view(c.size(0), -1), a.view(a.size(0), -1)), dim=1)
+        else:
+            combined = a if a is not None else c
+
+        return self.combined(combined)
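
The video-only path of the new detector can be exercised like this (a sketch under assumptions: use_vggish=False sidesteps the audio branch, whose input format is defined by the vggish module; the clip shape matches the hard-coded sample_size=224 and sample_duration=10):

    import torch
    from kissing_detector import KissingDetector3DConv

    model = KissingDetector3DConv(num_classes=2, feature_extract=False,
                                  use_vggish=False)
    clip = torch.randn(1, 3, 10, 224, 224)  # (N, C, T, H, W)
    logits = model(audio=None, image=clip)
    print(logits.shape)  # torch.Size([1, 2])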

+ 98 - 0
qualitative.py

@@ -94,3 +94,101 @@ class QualitativeAnalysis:
             y = [1 if 'kissing' in vid else 0] * len(A)
             self._show_saliency_maps(A, I, y)
             print('=' * 10)
+
+    # next few methods taken from cs231n
+    @staticmethod
+    def jitter(X, ox, oy):
+        """
+        Helper function to randomly jitter an image.
+
+        Inputs:
+        - X: PyTorch Tensor of shape (N, C, H, W)
+        - ox, oy: Integers giving number of pixels to jitter along W and H axes
+
+        Returns: A new PyTorch Tensor of shape (N, C, H, W)
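+        Equivalent to torch.roll(X, shifts=(oy, ox), dims=(2, 3)).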
+        """
+        if ox != 0:
+            left = X[:, :, :, :-ox]
+            right = X[:, :, :, -ox:]
+            X = torch.cat([right, left], dim=3)
+        if oy != 0:
+            top = X[:, :, :-oy]
+            bottom = X[:, :, -oy:]
+            X = torch.cat([bottom, top], dim=2)
+        return X
+
+    @staticmethod
+    def create_class_visualization(target_y, model, dtype, **kwargs):
+        """
+        Generate an image to maximize the score of target_y under a pretrained model.
+
+        Inputs:
+        - target_y: Integer in the range [0, 1000) giving the index of the class
+        - model: A pretrained CNN that will be used to generate the image
+        - dtype: Torch datatype to use for computations
+
+        Keyword arguments:
+        - l2_reg: Strength of L2 regularization on the image
+        - learning_rate: How big of a step to take
+        - num_iterations: How many iterations to use
+        - blur_every: How often to blur the image as an implicit regularizer
+        - max_jitter: How much to jitter the image as an implicit regularizer
+        - show_every: How often to show the intermediate result
+        """
+        model.type(dtype)
+        l2_reg = kwargs.pop('l2_reg', 1e-3)
+        learning_rate = kwargs.pop('learning_rate', 25)
+        num_iterations = kwargs.pop('num_iterations', 100)
+        blur_every = kwargs.pop('blur_every', 10)
+        max_jitter = kwargs.pop('max_jitter', 16)
+        show_every = kwargs.pop('show_every', 25)
+
+        # Randomly initialize the image as a PyTorch Tensor that requires gradient.
+        img = torch.randn(1, 3, 224, 224).type(dtype).requires_grad_()
+
+        for t in range(num_iterations):
+            # Randomly jitter the image a bit; this gives slightly nicer results
+            ox, oy = random.randint(0, max_jitter), random.randint(0, max_jitter)
+            img.data.copy_(QualitativeAnalysis.jitter(img.data, ox, oy))
+
+            # Gradient ascent on the class score: backprop the target logit to
+            # the image, fold in the L2-regularization gradient, and take a
+            # normalized step on the pixels.
+            target = model(img)[0, target_y]
+            target.backward()
+            g = img.grad.data
+            g -= 2 * l2_reg * img.data
+            img.data += learning_rate * (g / g.norm())
+            img.grad.zero_()
+
+            # Undo the random jitter
+            img.data.copy_(QualitativeAnalysis.jitter(img.data, -ox, -oy))
+
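+            # NOTE: SQUEEZENET_MEAN, SQUEEZENET_STD, blur_image, deprocess and
+            # class_names below are cs231n image-utils helpers; they are
+            # assumed to be imported at the top of qualitative.py.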
+            # As regularizer, clamp and periodically blur the image
+            for c in range(3):
+                lo = float(-SQUEEZENET_MEAN[c] / SQUEEZENET_STD[c])
+                hi = float((1.0 - SQUEEZENET_MEAN[c]) / SQUEEZENET_STD[c])
+                img.data[:, c].clamp_(min=lo, max=hi)
+            if t % blur_every == 0:
+                blur_image(img.data, sigma=0.5)
+
+            # Periodically show the image
+            if t == 0 or (t + 1) % show_every == 0 or t == num_iterations - 1:
+                plt.imshow(deprocess(img.data.clone().cpu()))
+                class_name = class_names[target_y]
+                plt.title('%s\nIteration %d / %d' % (class_name, t + 1, num_iterations))
+                plt.gcf().set_size_inches(4, 4)
+                plt.axis('off')
+                plt.show()
+
+        return deprocess(img.data.cpu())
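
Assuming the cs231n helpers noted above are importable, the visualization loop can be driven end to end like this (a hedged sketch; the SqueezeNet backbone and target class index are illustrative placeholders, not part of the commit):

    import torch
    import torchvision.models as models
    from qualitative import QualitativeAnalysis

    model = models.squeezenet1_1(pretrained=True)
    model.eval()  # disable dropout so the forward pass is deterministic
    for p in model.parameters():
        p.requires_grad = False  # ascend on the input image, not the weights

    img = QualitativeAnalysis.create_class_visualization(
        target_y=76, model=model, dtype=torch.FloatTensor,
        num_iterations=50, show_every=25)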