conv.py 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. import torch
  2. from torch import nn
  3. from torchvision import models
  4. def set_parameter_requires_grad(model, feature_extracting):
  5. if feature_extracting:
  6. for param in model.parameters():
  7. param.requires_grad = False
  8. def convnet_init(model_name: str,
  9. num_classes: int,
  10. feature_extract: bool,
  11. use_pretrained: bool = True):
  12. # Initialize these variables which will be set in this if statement. Each of these
  13. # variables is model specific.
  14. model_ft = None
  15. input_size = 0
  16. if model_name == "resnet":
  17. """ Resnet18
  18. """
  19. model_ft = models.resnet18(pretrained=use_pretrained)
  20. set_parameter_requires_grad(model_ft, feature_extract)
  21. # num_ftrs = model_ft.fc.in_features
  22. # model_ft.fc = nn.Linear(num_ftrs, num_classes)
  23. model_ft.fc = nn.Identity()
  24. input_size = 224
  25. elif model_name == "alexnet":
  26. """ Alexnet
  27. """
  28. model_ft = models.alexnet(pretrained=use_pretrained)
  29. set_parameter_requires_grad(model_ft, feature_extract)
  30. # num_ftrs = model_ft.classifier[6].in_features
  31. # model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
  32. model_ft.classifier = nn.Identity()
  33. input_size = 224
  34. elif model_name == "vgg":
  35. """ VGG11_bn
  36. """
  37. model_ft = models.vgg11_bn(pretrained=use_pretrained)
  38. set_parameter_requires_grad(model_ft, feature_extract)
  39. # num_ftrs = model_ft.classifier[6].in_features
  40. # model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
  41. model_ft.fc = nn.Identity()
  42. input_size = 224
  43. elif model_name == "squeezenet":
  44. """ Squeezenet
  45. """
  46. model_ft = models.squeezenet1_0(pretrained=use_pretrained)
  47. set_parameter_requires_grad(model_ft, feature_extract)
  48. # TODO: this is my attempt to remove the last FC layer, doesn't seem to work for SqueezeNet
  49. # model_ft.classifier = nn.Identity()
  50. # model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
  51. model_ft.classifier = nn.Identity()
  52. # model_ft.num_classes = num_classes
  53. input_size = 224
  54. elif model_name == "densenet":
  55. """ Densenet
  56. """
  57. model_ft = models.densenet121(pretrained=use_pretrained)
  58. set_parameter_requires_grad(model_ft, feature_extract)
  59. num_ftrs = model_ft.classifier.in_features
  60. # model_ft.classifier = nn.Linear(num_ftrs, num_classes)
  61. model_ft.classifier = nn.Identity()
  62. input_size = 224
  63. # elif model_name == "inception":
  64. # """ Inception v3
  65. # Be careful, expects (299,299) sized images and has auxiliary output
  66. # """
  67. # model_ft = models.inception_v3(pretrained=use_pretrained)
  68. # set_parameter_requires_grad(model_ft, feature_extract)
  69. # # Handle the auxiliary net
  70. # num_ftrs = model_ft.AuxLogits.fc.in_features
  71. # model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
  72. # # Handle the primary net
  73. # num_ftrs = model_ft.fc.in_features
  74. # # model_ft.fc = nn.Linear(num_ftrs, num_classes)
  75. # model_ft.fc = nn.Identity()
  76. # input_size = 299
  77. else:
  78. print("Invalid model name, exiting...")
  79. exit()
  80. output_size = model_ft(torch.rand((1, 3, input_size, input_size))).shape[1]
  81. return model_ft, input_size, output_size