# align_face.py — PULSE face-alignment preprocessing script
  1. import numpy as np
  2. import PIL
  3. import PIL.Image
  4. import sys
  5. import os
  6. import glob
  7. import scipy
  8. import scipy.ndimage
  9. import dlib
  10. from drive import open_url
  11. from pathlib import Path
  12. import argparse
  13. from bicubic import BicubicDownSample
  14. import torchvision
  15. from shape_predictor import align_face
  16. parser = argparse.ArgumentParser(description='PULSE')
  17. parser.add_argument('-input_dir', type=str, default='realpics', help='directory with unprocessed images')
  18. parser.add_argument('-output_dir', type=str, default='input', help='output directory')
  19. parser.add_argument('-output_size', type=int, default=32, help='size to downscale the input images to, must be power of 2')
  20. parser.add_argument('-seed', type=int, help='manual seed to use')
  21. parser.add_argument('-cache_dir', type=str, default='cache', help='cache directory for model weights')
  22. args = parser.parse_args()
  23. cache_dir = Path(args.cache_dir)
  24. cache_dir.mkdir(parents=True, exist_ok=True)
  25. output_dir = Path(args.output_dir)
  26. output_dir.mkdir(parents=True,exist_ok=True)
  27. print("Downloading Shape Predictor")
  28. f=open_url("https://drive.google.com/uc?id=1huhv8PYpNNKbGCLOaYUjOgR1pY5pmbJx", cache_dir=cache_dir, return_path=True)
  29. predictor = dlib.shape_predictor(f)
  30. for im in Path(args.input_dir).glob("*.*"):
  31. faces = align_face(str(im),predictor)
  32. for i,face in enumerate(faces):
  33. if(args.output_size):
  34. factor = 1024//args.output_size
  35. assert args.output_size*factor == 1024
  36. D = BicubicDownSample(factor=factor)
  37. face_tensor = torchvision.transforms.ToTensor()(face).unsqueeze(0).cuda()
  38. face_tensor_lr = D(face_tensor)[0].cpu().detach().clamp(0, 1)
  39. face = torchvision.transforms.ToPILImage()(face_tensor_lr)
  40. face.save(Path(args.output_dir) / (im.stem+f"_{i}.png"))