# conv3d.py

# Code adapted from https://github.com/kenshohara/3D-ResNets-PyTorch/blob/master/model.py
import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F


def conv3x3x3(in_planes, out_planes, stride=1):
    """3x3x3 convolution with padding."""
    return nn.Conv3d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False)


def downsample_basic_block(x, planes, stride):
    # Type-'A' shortcut: downsample with strided average pooling, then
    # zero-pad the channel dimension to match the residual branch.
    out = F.avg_pool3d(x, kernel_size=1, stride=stride)
    zero_pads = torch.zeros(
        out.size(0), planes - out.size(1), out.size(2), out.size(3),
        out.size(4), dtype=out.dtype, device=out.device)
    out = torch.cat([out, zero_pads], dim=1)
    return out
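
# A minimal sketch of how the type-'A' shortcut behaves, with assumed example
# shapes (not from the original file): a (2, 64, 16, 56, 56) input pooled with
# stride 2 and zero-padded from 64 to 128 channels yields (2, 128, 8, 28, 28),
# ready to be added to the output of a stride-2 residual branch.
#
#   x = torch.randn(2, 64, 16, 56, 56)
#   y = downsample_basic_block(x, planes=128, stride=2)
#   assert y.shape == (2, 128, 8, 28, 28)

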
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self,
                 block,
                 layers,
                 sample_size,
                 sample_duration,
                 shortcut_type='B',
                 num_classes=400):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv3d(
            3,
            64,
            kernel_size=7,
            stride=(1, 2, 2),
            padding=(3, 3, 3),
            bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)
        self.layer2 = self._make_layer(
            block, 128, layers[1], shortcut_type, stride=2)
        self.layer3 = self._make_layer(
            block, 256, layers[2], shortcut_type, stride=2)
        self.layer4 = self._make_layer(
            block, 512, layers[3], shortcut_type, stride=2)
        last_duration = int(math.ceil(sample_duration / 16))
        last_size = int(math.ceil(sample_size / 32))
        self.avgpool = nn.AvgPool3d(
            (last_duration, last_size, last_size), stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            if shortcut_type == 'A':
                downsample = partial(
                    downsample_basic_block,
                    planes=planes * block.expansion,
                    stride=stride)
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(
                        self.inplanes,
                        planes * block.expansion,
                        kernel_size=1,
                        stride=stride,
                        bias=False), nn.BatchNorm3d(planes * block.expansion))

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)

        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet34(**kwargs):
    """Constructs a ResNet-34 model."""
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    return model
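

# A minimal usage sketch, assuming the default clip shape used by the original
# repository (RGB clips with sample_duration=16 frames at sample_size=112);
# these two arguments must match the real input so that the final average pool
# collapses the feature map to 1x1x1 before the fully connected layer.
if __name__ == "__main__":
    model = resnet34(sample_size=112, sample_duration=16, num_classes=400)
    clip = torch.randn(1, 3, 16, 112, 112)  # (batch, channels, frames, H, W)
    logits = model(clip)
    print(logits.shape)  # expected: torch.Size([1, 400])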