# run_webcam.py (4.7 KB)
  1. import argparse
  2. import cv2
  3. import dlib
  4. import numpy as np
  5. import tensorflow as tf
  6. from imutils import video
  7. CROP_SIZE = 256
  8. DOWNSAMPLE_RATIO = 4
  9. def reshape_for_polyline(array):
  10. """Reshape image so that it works with polyline."""
  11. return np.array(array, np.int32).reshape((-1, 1, 2))
  12. def resize(image):
  13. """Crop and resize image for pix2pix."""
  14. height, width, _ = image.shape
  15. if height != width:
  16. # crop to correct ratio
  17. size = min(height, width)
  18. oh = (height - size) // 2
  19. ow = (width - size) // 2
  20. cropped_image = image[oh:(oh + size), ow:(ow + size)]
  21. image_resize = cv2.resize(cropped_image, (CROP_SIZE, CROP_SIZE))
  22. return image_resize
  23. def load_graph(frozen_graph_filename):
  24. """Load a (frozen) Tensorflow model into memory."""
  25. graph = tf.Graph()
  26. with graph.as_default():
  27. od_graph_def = tf.GraphDef()
  28. with tf.gfile.GFile(frozen_graph_filename, 'rb') as fid:
  29. serialized_graph = fid.read()
  30. od_graph_def.ParseFromString(serialized_graph)
  31. tf.import_graph_def(od_graph_def, name='')
  32. return graph
  33. def main():
  34. # TensorFlow
  35. graph = load_graph(args.frozen_model_file)
  36. image_tensor = graph.get_tensor_by_name('image_tensor:0')
  37. output_tensor = graph.get_tensor_by_name('generate_output/output:0')
  38. sess = tf.Session(graph=graph)
  39. # OpenCV
  40. cap = cv2.VideoCapture(args.video_source)
  41. fps = video.FPS().start()
  42. while True:
  43. ret, frame = cap.read()
  44. # resize image and detect face
  45. frame_resize = cv2.resize(frame, None, fx=1 / DOWNSAMPLE_RATIO, fy=1 / DOWNSAMPLE_RATIO)
  46. gray = cv2.cvtColor(frame_resize, cv2.COLOR_BGR2GRAY)
  47. faces = detector(gray, 1)
  48. black_image = np.zeros(frame.shape, np.uint8)
  49. for face in faces:
  50. detected_landmarks = predictor(gray, face).parts()
  51. landmarks = [[p.x * DOWNSAMPLE_RATIO, p.y * DOWNSAMPLE_RATIO] for p in detected_landmarks]
  52. jaw = reshape_for_polyline(landmarks[0:17])
  53. left_eyebrow = reshape_for_polyline(landmarks[22:27])
  54. right_eyebrow = reshape_for_polyline(landmarks[17:22])
  55. nose_bridge = reshape_for_polyline(landmarks[27:31])
  56. lower_nose = reshape_for_polyline(landmarks[30:35])
  57. left_eye = reshape_for_polyline(landmarks[42:48])
  58. right_eye = reshape_for_polyline(landmarks[36:42])
  59. outer_lip = reshape_for_polyline(landmarks[48:60])
  60. inner_lip = reshape_for_polyline(landmarks[60:68])
  61. color = (255, 255, 255)
  62. thickness = 3
  63. cv2.polylines(black_image, [jaw], False, color, thickness)
  64. cv2.polylines(black_image, [left_eyebrow], False, color, thickness)
  65. cv2.polylines(black_image, [right_eyebrow], False, color, thickness)
  66. cv2.polylines(black_image, [nose_bridge], False, color, thickness)
  67. cv2.polylines(black_image, [lower_nose], True, color, thickness)
  68. cv2.polylines(black_image, [left_eye], True, color, thickness)
  69. cv2.polylines(black_image, [right_eye], True, color, thickness)
  70. cv2.polylines(black_image, [outer_lip], True, color, thickness)
  71. cv2.polylines(black_image, [inner_lip], True, color, thickness)
  72. # generate prediction
  73. combined_image = np.concatenate([resize(black_image), resize(frame_resize)], axis=1)
  74. image_rgb = cv2.cvtColor(combined_image, cv2.COLOR_BGR2RGB) # OpenCV uses BGR instead of RGB
  75. generated_image = sess.run(output_tensor, feed_dict={image_tensor: image_rgb})
  76. image_bgr = cv2.cvtColor(np.squeeze(generated_image), cv2.COLOR_RGB2BGR)
  77. output_image = np.concatenate([resize(frame_resize), image_bgr], axis=1)
  78. cv2.imshow('frame', output_image)
  79. fps.update()
  80. if cv2.waitKey(1) & 0xFF == ord('q'):
  81. break
  82. fps.stop()
  83. print('[INFO] elapsed time (total): {:.2f}'.format(fps.elapsed()))
  84. print('[INFO] approx. FPS: {:.2f}'.format(fps.fps()))
  85. sess.close()
  86. cap.release()
  87. cv2.destroyAllWindows()
  88. if __name__ == '__main__':
  89. parser = argparse.ArgumentParser()
  90. parser.add_argument('-src', '--source', dest='video_source', type=int,
  91. default=0, help='Device index of the camera.')
  92. parser.add_argument('--landmark-model', dest='face_landmark_shape_file', type=str, help='Face landmark model file.')
  93. parser.add_argument('--tf-model', dest='frozen_model_file', type=str, help='Frozen TensorFlow model file.')
  94. args = parser.parse_args()
  95. # Create the face predictor and landmark predictor
  96. detector = dlib.get_frontal_face_detector()
  97. predictor = dlib.shape_predictor(args.face_landmark_shape_file)
  98. main()