conv.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. # adapted from PyTorch tutorials
  2. import torch
  3. from torch import nn
  4. from torchvision import models
  5. def set_parameter_requires_grad(model, feature_extracting):
  6. if feature_extracting:
  7. for param in model.parameters():
  8. param.requires_grad = False
  9. def convnet_init(model_name: str,
  10. num_classes: int,
  11. feature_extract: bool,
  12. use_pretrained: bool = True):
  13. # Initialize these variables which will be set in this if statement. Each of these
  14. # variables is model specific.
  15. model_ft = None
  16. input_size = 0
  17. if model_name == "resnet":
  18. """ Resnet18
  19. """
  20. model_ft = models.resnet18(pretrained=use_pretrained)
  21. set_parameter_requires_grad(model_ft, feature_extract)
  22. # num_ftrs = model_ft.fc.in_features
  23. # model_ft.fc = nn.Linear(num_ftrs, num_classes)
  24. model_ft.fc = nn.Identity()
  25. input_size = 224
  26. elif model_name == "alexnet":
  27. """ Alexnet
  28. """
  29. model_ft = models.alexnet(pretrained=use_pretrained)
  30. set_parameter_requires_grad(model_ft, feature_extract)
  31. # num_ftrs = model_ft.classifier[6].in_features
  32. # model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
  33. model_ft.classifier = nn.Identity()
  34. input_size = 224
  35. elif model_name == "vgg":
  36. """ VGG11_bn
  37. """
  38. model_ft = models.vgg11_bn(pretrained=use_pretrained)
  39. set_parameter_requires_grad(model_ft, feature_extract)
  40. # num_ftrs = model_ft.classifier[6].in_features
  41. # model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
  42. model_ft.fc = nn.Identity()
  43. input_size = 224
  44. elif model_name == "squeezenet":
  45. """ Squeezenet
  46. """
  47. model_ft = models.squeezenet1_0(pretrained=use_pretrained)
  48. set_parameter_requires_grad(model_ft, feature_extract)
  49. # TODO: this is my attempt to remove the last FC layer, doesn't seem to work for SqueezeNet
  50. # model_ft.classifier = nn.Identity()
  51. # model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
  52. model_ft.classifier = nn.Identity()
  53. # model_ft.num_classes = num_classes
  54. input_size = 224
  55. elif model_name == "densenet":
  56. """ Densenet
  57. """
  58. model_ft = models.densenet121(pretrained=use_pretrained)
  59. set_parameter_requires_grad(model_ft, feature_extract)
  60. num_ftrs = model_ft.classifier.in_features
  61. # model_ft.classifier = nn.Linear(num_ftrs, num_classes)
  62. model_ft.classifier = nn.Identity()
  63. input_size = 224
  64. # elif model_name == "inception":
  65. # """ Inception v3
  66. # Be careful, expects (299,299) sized images and has auxiliary output
  67. # """
  68. # model_ft = models.inception_v3(pretrained=use_pretrained)
  69. # set_parameter_requires_grad(model_ft, feature_extract)
  70. # # Handle the auxiliary net
  71. # num_ftrs = model_ft.AuxLogits.fc.in_features
  72. # model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
  73. # # Handle the primary net
  74. # num_ftrs = model_ft.fc.in_features
  75. # # model_ft.fc = nn.Linear(num_ftrs, num_classes)
  76. # model_ft.fc = nn.Identity()
  77. # input_size = 299
  78. else:
  79. print("Invalid model name, exiting...")
  80. exit()
  81. output_size = model_ft(torch.rand((1, 3, input_size, input_size))).shape[1]
  82. return model_ft, input_size, output_size