conv.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. from torch import nn
  2. from torchvision import models
  3. import torch
  4. def set_parameter_requires_grad(model, feature_extracting):
  5. if feature_extracting:
  6. for param in model.parameters():
  7. param.requires_grad = False
  8. def convnet_init(model_name, num_classes, feature_extract, use_pretrained=True):
  9. # Initialize these variables which will be set in this if statement. Each of these
  10. # variables is model specific.
  11. model_ft = None
  12. input_size = 0
  13. output_size = 0
  14. if model_name == "resnet":
  15. """ Resnet18
  16. """
  17. model_ft = models.resnet18(pretrained=use_pretrained)
  18. set_parameter_requires_grad(model_ft, feature_extract)
  19. num_ftrs = model_ft.fc.in_features
  20. # model_ft.fc = nn.Linear(num_ftrs, num_classes)
  21. model_ft.fc = nn.Identity()
  22. input_size = 224
  23. output_size = model_ft(torch.rand((1, 3, input_size, input_size))).shape[1]
  24. elif model_name == "alexnet":
  25. """ Alexnet
  26. """
  27. model_ft = models.alexnet(pretrained=use_pretrained)
  28. set_parameter_requires_grad(model_ft, feature_extract)
  29. num_ftrs = model_ft.classifier[6].in_features
  30. model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
  31. input_size = 224
  32. elif model_name == "vgg":
  33. """ VGG11_bn
  34. """
  35. model_ft = models.vgg11_bn(pretrained=use_pretrained)
  36. set_parameter_requires_grad(model_ft, feature_extract)
  37. num_ftrs = model_ft.classifier[6].in_features
  38. model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
  39. input_size = 224
  40. elif model_name == "squeezenet":
  41. """ Squeezenet
  42. """
  43. model_ft = models.squeezenet1_0(pretrained=use_pretrained)
  44. set_parameter_requires_grad(model_ft, feature_extract)
  45. # TODO: this is my attempt to remove the last FC layer, doesn't seem to work for SqueezeNet
  46. # model_ft.classifier = nn.Identity()
  47. # model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
  48. model_ft.classifier[1] = nn.Identity()
  49. # model_ft.num_classes = num_classes
  50. input_size = 224
  51. elif model_name == "densenet":
  52. """ Densenet
  53. """
  54. model_ft = models.densenet121(pretrained=use_pretrained)
  55. set_parameter_requires_grad(model_ft, feature_extract)
  56. num_ftrs = model_ft.classifier.in_features
  57. model_ft.classifier = nn.Linear(num_ftrs, num_classes)
  58. input_size = 224
  59. elif model_name == "inception":
  60. """ Inception v3
  61. Be careful, expects (299,299) sized images and has auxiliary output
  62. """
  63. model_ft = models.inception_v3(pretrained=use_pretrained)
  64. set_parameter_requires_grad(model_ft, feature_extract)
  65. # Handle the auxiliary net
  66. num_ftrs = model_ft.AuxLogits.fc.in_features
  67. model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
  68. # Handle the primary net
  69. num_ftrs = model_ft.fc.in_features
  70. model_ft.fc = nn.Linear(num_ftrs, num_classes)
  71. input_size = 299
  72. else:
  73. print("Invalid model name, exiting...")
  74. exit()
  75. return model_ft, input_size, output_size