{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3DC: 3D-convolution ResNet experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "import math\n",
    "from functools import partial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv3x3x3(in_planes, out_planes, stride=1):\n",
    "    # 3x3x3 convolution with padding\n",
    "    return nn.Conv3d(\n",
    "        in_planes,\n",
    "        out_planes,\n",
    "        kernel_size=3,\n",
    "        stride=stride,\n",
    "        padding=1,\n",
    "        bias=False)\n",
    "\n",
    "\n",
    "def downsample_basic_block(x, planes, stride):\n",
    "    out = F.avg_pool3d(x, kernel_size=1, stride=stride)\n",
    "    zero_pads = torch.Tensor(\n",
    "        out.size(0), planes - out.size(1), out.size(2), out.size(3),\n",
    "        out.size(4)).zero_()\n",
    "    if isinstance(out.data, torch.cuda.FloatTensor):\n",
    "        zero_pads = zero_pads.cuda()\n",
    "\n",
    "    out = Variable(torch.cat([out.data, zero_pads], dim=1))\n",
    "\n",
    "    return out\n",
    "\n",
    "\n",
    "class BasicBlock(nn.Module):\n",
    "    \"\"\"Residual block made of two 3x3x3 convolutions, each followed by batch norm.\n",
    "\n",
    "    ``expansion`` is the ratio of output channels to ``planes`` (1 here).\n",
    "    \"\"\"\n",
    "    expansion = 1\n",
    "\n",
    "    def __init__(self, inplanes, planes, stride=1, downsample=None):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            inplanes: channels of the incoming tensor.\n",
    "            planes: channels produced by both convolutions.\n",
    "            stride: stride of the first convolution.\n",
    "            downsample: optional callable that turns the block input into the\n",
    "                shortcut; when ``None`` the input itself is the shortcut.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.stride = stride\n",
    "        # Sub-modules are registered in this order: conv1, bn1, relu, conv2,\n",
    "        # bn2, downsample.\n",
    "        self.conv1 = conv3x3x3(inplanes, planes, stride)\n",
    "        self.bn1 = nn.BatchNorm3d(planes)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.conv2 = conv3x3x3(planes, planes)\n",
    "        self.bn2 = nn.BatchNorm3d(planes)\n",
    "        self.downsample = downsample\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return ``relu(main_path(x) + shortcut(x))``.\"\"\"\n",
    "        out = self.relu(self.bn1(self.conv1(x)))\n",
    "        out = self.bn2(self.conv2(out))\n",
    "\n",
    "        # Identity shortcut unless a downsample callable was supplied.\n",
    "        shortcut = x if self.downsample is None else self.downsample(x)\n",
    "\n",
    "        out += shortcut\n",
    "        return self.relu(out)\n",
    "\n",
    "\n",
    "class Bottleneck(nn.Module):\n",
    "    expansion = 4\n",
    "\n",
    "    def __init__(self, inplanes, planes, stride=1, downsample=None):\n",
    "        super(Bottleneck, self).__init__()\n",
    "        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm3d(planes)\n",
    "        self.conv2 = nn.Conv3d(\n",
    "            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm3d(planes)\n",
    "        self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)\n",
    "        self.bn3 = nn.BatchNorm3d(planes * 4)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.downsample = downsample\n",
    "        self.stride = stride\n",
    "\n",
    "    def forward(self, x):\n",
    "        residual = x\n",
    "\n",
    "        out = self.conv1(x)\n",
    "        out = self.bn1(out)\n",
    "        out = self.relu(out)\n",
    "\n",
    "        out = self.conv2(out)\n",
    "        out = self.bn2(out)\n",
    "        out = self.relu(out)\n",
    "\n",
    "        out = self.conv3(out)\n",
    "        out = self.bn3(out)\n",
    "\n",
    "        if self.downsample is not None:\n",
    "            residual = self.downsample(x)\n",
    "\n",
    "        out += residual\n",
    "        out = self.relu(out)\n",
    "\n",
    "        return out\n",
    "\n",
    "\n",
    "class ResNet(nn.Module):\n",
    "\n",
    "    def __init__(self,\n",
    "                 block,\n",
    "                 layers,\n",
    "                 sample_size,\n",
    "                 sample_duration,\n",
    "                 shortcut_type='B',\n",
    "                 num_classes=400):\n",
    "        self.inplanes = 64\n",
    "        super(ResNet, self).__init__()\n",
    "        self.conv1 = nn.Conv3d(\n",
    "            3,\n",
    "            64,\n",
    "            kernel_size=7,\n",
    "            stride=(1, 2, 2),\n",
    "            padding=(3, 3, 3),\n",
    "            bias=False)\n",
    "        self.bn1 = nn.BatchNorm3d(64)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)\n",
    "        self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)\n",
    "        self.layer2 = self._make_layer(\n",
    "            block, 128, layers[1], shortcut_type, stride=2)\n",
    "        self.layer3 = self._make_layer(\n",
    "            block, 256, layers[2], shortcut_type, stride=2)\n",
    "        self.layer4 = self._make_layer(\n",
    "            block, 512, layers[3], shortcut_type, stride=2)\n",
    "        last_duration = int(math.ceil(sample_duration / 16))\n",
    "        last_size = int(math.ceil(sample_size / 32))\n",
    "        self.avgpool = nn.AvgPool3d(\n",
    "            (last_duration, last_size, last_size), stride=1)\n",
    "        self.fc = nn.Linear(512 * block.expansion, num_classes)\n",
    "\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Conv3d):\n",
    "                m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out')\n",
    "            elif isinstance(m, nn.BatchNorm3d):\n",
    "                m.weight.data.fill_(1)\n",
    "                m.bias.data.zero_()\n",
    "\n",
    "    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):\n",
    "        downsample = None\n",
    "        if stride != 1 or self.inplanes != planes * block.expansion:\n",
    "            if shortcut_type == 'A':\n",
    "                downsample = partial(\n",
    "                    downsample_basic_block,\n",
    "                    planes=planes * block.expansion,\n",
    "                    stride=stride)\n",
    "            else:\n",
    "                downsample = nn.Sequential(\n",
    "                    nn.Conv3d(\n",
    "                        self.inplanes,\n",
    "                        planes * block.expansion,\n",
    "                        kernel_size=1,\n",
    "                        stride=stride,\n",
    "                        bias=False), nn.BatchNorm3d(planes * block.expansion))\n",
    "\n",
    "        layers = []\n",
    "        layers.append(block(self.inplanes, planes, stride, downsample))\n",
    "        self.inplanes = planes * block.expansion\n",
    "        for i in range(1, blocks):\n",
    "            layers.append(block(self.inplanes, planes))\n",
    "\n",
    "        return nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = self.bn1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.maxpool(x)\n",
    "\n",
    "        x = self.layer1(x)\n",
    "        x = self.layer2(x)\n",
    "        x = self.layer3(x)\n",
    "        x = self.layer4(x)\n",
    "\n",
    "        x = self.avgpool(x)\n",
    "\n",
    "        x = x.view(x.size(0), -1)\n",
    "        x = self.fc(x)\n",
    "\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def resnet34(**kwargs):\n",
    "    \"\"\"Build a ResNet-34: ``ResNet`` with ``BasicBlock`` and stage depths [3, 4, 6, 3].\n",
    "\n",
    "    All keyword arguments are forwarded unchanged to the ``ResNet`` constructor.\n",
    "    \"\"\"\n",
    "    stage_depths = [3, 4, 6, 3]\n",
    "    return ResNet(BasicBlock, stage_depths, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_classes = 2\n",
    "resnet_shortcut = 'B'\n",
    "sample_size = 224\n",
    "sample_duration = 16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:132: UserWarning: nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.\n"
     ]
    }
   ],
   "source": [
    "# Build the network from the settings defined above instead of repeating literals.\n",
    "model = resnet34(\n",
    "    num_classes=num_classes,\n",
    "    shortcut_type=resnet_shortcut,\n",
    "    sample_size=sample_size,\n",
    "    sample_duration=sample_duration)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ResNet(\n",
       "  (conv1): Conv3d(3, 64, kernel_size=(7, 7, 7), stride=(1, 2, 2), padding=(3, 3, 3), bias=False)\n",
       "  (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (relu): ReLU(inplace)\n",
       "  (maxpool): MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "  (layer1): Sequential(\n",
       "    (0): BasicBlock(\n",
       "      (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (1): BasicBlock(\n",
       "      (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (2): BasicBlock(\n",
       "      (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "  )\n",
       "  (layer2): Sequential(\n",
       "    (0): BasicBlock(\n",
       "      (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)\n",
       "      (bn1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn2): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv3d(64, 128, kernel_size=(1, 1, 1), stride=(2, 2, 2), bias=False)\n",
       "        (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (1): BasicBlock(\n",
       "      (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn2): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (2): BasicBlock(\n",
       "      (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn2): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (3): BasicBlock(\n",
       "      (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn2): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "  )\n",
       "  (layer3): Sequential(\n",
       "    (0): BasicBlock(\n",
       "      (conv1): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)\n",
       "      (bn1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn2): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv3d(128, 256, kernel_size=(1, 1, 1), stride=(2, 2, 2), bias=False)\n",
       "        (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (1): BasicBlock(\n",
       "      (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn2): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (2): BasicBlock(\n",
       "      (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn2): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (3): BasicBlock(\n",
       "      (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn2): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (4): BasicBlock(\n",
       "      (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn2): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (5): BasicBlock(\n",
       "      (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn2): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "  )\n",
       "  (layer4): Sequential(\n",
       "    (0): BasicBlock(\n",
       "      (conv1): Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)\n",
       "      (bn1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn2): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv3d(256, 512, kernel_size=(1, 1, 1), stride=(2, 2, 2), bias=False)\n",
       "        (1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (1): BasicBlock(\n",
       "      (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn2): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (2): BasicBlock(\n",
       "      (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
       "      (bn2): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "  )\n",
       "  (avgpool): AvgPool3d(kernel_size=(1, 7, 7), stride=1, padding=0)\n",
       "  (fc): Linear(in_features=512, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random batch of 2 three-channel inputs, sized from the settings above so it\n",
    "# stays consistent with the pooling window the model was built with.\n",
    "q = torch.rand((2, 3, sample_duration, sample_size, sample_size))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.3253, 0.5383],\n",
       "        [1.2986, 0.5157]], grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model(q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from kissing_detector import KissingDetector3DConv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/aziai/Dropbox/github/cs231n-project/conv3d.py:142: UserWarning: nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.\n",
      "  m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out')\n"
     ]
    }
   ],
   "source": [
    "model = KissingDetector3DConv(\n",
    "    num_classes=2,\n",
    "    feature_extract=True,\n",
    "    use_vggish=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.1913, -0.5335]], grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model(\n",
    "    torch.rand((1, 1, 96, 64)),\n",
    "    torch.rand((1, 3, 16, 224, 224))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "xs = [\n",
    "    torch.rand((3, 224, 224))\n",
    "    for _ in range(10)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(xs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 224, 224])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xs[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 16, 224, 224])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "e.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one zero-padded 16-slot clip per index of xs and collect them in `out`.\n",
    "out = []\n",
    "for i in range(len(xs)):\n",
    "    # e is a 16-slot buffer; each slot takes one (3, 224, 224) entry of xs\n",
    "    e = torch.zeros((16, 3, 224, 224))\n",
    "    \n",
    "    # the last slot always receives xs[0], so the clip for i == 0 is not all zeros\n",
    "    e[-1] = xs[0]\n",
    "    \n",
    "    # fill from the back: xs[0] in the last slot, xs[1] in the one before, ...\n",
    "    # NOTE(review): xs[i] itself is never written and entries land in reverse\n",
    "    # order along dim 0 -- confirm this is intended (maybe e[-1 - j] = xs[i - j]?)\n",
    "    for j in range(i):\n",
    "        e[- 1 - j] = xs[j]\n",
    "    \n",
    "    # swap dims 0 and 1: (16, 3, 224, 224) -> (3, 16, 224, 224)\n",
    "    ee = e.permute((1, 0, 2, 3))\n",
    "    out.append(ee)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 16, 224, 224])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.2543, 0.3794, 0.6759,  ..., 0.4921, 0.7928, 0.3837],\n",
       "        [0.9456, 0.6443, 0.3277,  ..., 0.7827, 0.9809, 0.2572],\n",
       "        [0.8232, 0.0090, 0.7486,  ..., 0.9921, 0.6396, 0.8412],\n",
       "        ...,\n",
       "        [0.8520, 0.2385, 0.7409,  ..., 0.9279, 0.7809, 0.0244],\n",
       "        [0.9987, 0.8331, 0.8345,  ..., 0.8206, 0.7880, 0.4744],\n",
       "        [0.5744, 0.5991, 0.7520,  ..., 0.5209, 0.7709, 0.7906]])"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out[1][0, -1, :, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1- broken 2d?\n",
    "from experiments import ExperimentRunner\n",
    "import params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "ex = ExperimentRunner(params.experiment_test_3d, n_jobs=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Running param set: {'data_path_base': 'vtest_new2', 'conv_model_name': 'resnet', 'num_epochs': 10, 'feature_extract': True, 'batch_size': 64, 'lr': 0.001, 'use_vggish': True, 'momentum': 0.9, 'use_3d': True}\n",
      "/Users/aziai/Dropbox/github/cs231n-project/conv3d.py:142: UserWarning: nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.\n",
      "  m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Params to update\n",
      "* combined.weight\n",
      "* combined.bias\n",
      "Epoch 0/9\n",
      "----------\n",
      "train Loss: 1.0060 F1: 0.6833 Acc: 0.5190\n"
     ]
    }
   ],
   "source": [
    "ex.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "from train import train_kd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_kd()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}