reduce_model.py

import argparse
import tensorflow as tf

CROP_SIZE = 256  # scale_size = CROP_SIZE
ngf = 64  # generator filters in the first conv layer
ndf = 64  # discriminator filters in the original pix2pix; unused in this inference-only graph


def preprocess(image):
    with tf.name_scope('preprocess'):
        # [0, 1] => [-1, 1]
        return image * 2 - 1


def deprocess(image):
    with tf.name_scope('deprocess'):
        # [-1, 1] => [0, 1]
        return (image + 1) / 2
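
# Quick round-trip check of the two helpers above (a sanity illustration,
# assuming pixel values in [0, 1]):
#   preprocess(0.75) == 0.75 * 2 - 1 == 0.5
#   deprocess(0.5)   == (0.5 + 1) / 2 == 0.75
# so deprocess(preprocess(image)) recovers the original image.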

def gen_conv(batch_input, out_channels):
    # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels]
    initializer = tf.random_normal_initializer(0, 0.02)
    # if a.separable_conv:
    #     return tf.layers.separable_conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer)
    # else:
    return tf.layers.conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer)


def lrelu(x, a):
    with tf.name_scope('lrelu'):
        # adding these together creates the leak part and linear part
        # then cancels them out by subtracting/adding an absolute value term
        # leak: a*x/2 - a*abs(x)/2
        # linear: x/2 + abs(x)/2

        # this block looks like it has 2 inputs on the graph unless we do this
        x = tf.identity(x)

        return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x)
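
# Sanity check of the lrelu identity (illustration only): for a = 0.2 the
# expression reduces to max(x, a * x), the usual leaky ReLU. With
# 0.5 * (1 + a) = 0.6 and 0.5 * (1 - a) = 0.4:
#   x =  3:  0.6 *  3 + 0.4 * 3 =  1.8 + 1.2 =  3.0   (= x)
#   x = -3:  0.6 * -3 + 0.4 * 3 = -1.8 + 1.2 = -0.6   (= a * x)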

def batchnorm(inputs):
    return tf.layers.batch_normalization(inputs, axis=3, epsilon=1e-5, momentum=0.1, training=True, gamma_initializer=tf.random_normal_initializer(1.0, 0.02))
    # with tf.variable_scope('batchnorm'):
    #     # this block looks like it has 3 inputs on the graph unless we do this
    #     input = tf.identity(input)
    #
    #     channels = input.get_shape()[3]
    #     offset = tf.get_variable('offset', [channels], dtype=tf.float32, initializer=tf.zeros_initializer())
    #     scale = tf.get_variable('scale', [channels], dtype=tf.float32,
    #                             initializer=tf.random_normal_initializer(1.0, 0.02))
    #     mean, variance = tf.nn.moments(input, axes=[0, 1, 2], keep_dims=False)
    #     variance_epsilon = 1e-5
    #     normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon)
    #     return normalized

def gen_deconv(batch_input, out_channels):
    # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels]
    initializer = tf.random_normal_initializer(0, 0.02)
    # if a.separable_conv:
    #     _b, h, w, _c = batch_input.shape
    #     resized_input = tf.image.resize_images(batch_input, [h * 2, w * 2], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    #     return tf.layers.separable_conv2d(resized_input, out_channels, kernel_size=4, strides=(1, 1), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer)
    # else:
    return tf.layers.conv2d_transpose(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer)
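
# Shape sketch (illustration only): with kernel 4, stride 2 and 'same'
# padding, gen_conv halves and gen_deconv doubles the spatial size, e.g.
#   gen_conv([1, 256, 256, 3], ngf)    -> [1, 128, 128, 64]
#   gen_deconv([1, 128, 128, 128], 3)  -> [1, 256, 256, 3]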

def process_image(x):
    with tf.name_scope('load_images'):
        raw_input = tf.image.convert_image_dtype(x, dtype=tf.float32)
        raw_input.set_shape([None, None, 3])

        # break apart image pair and move to range [-1, 1]
        width = tf.shape(raw_input)[1]  # [height, width, channels]
        a_images = preprocess(raw_input[:, :width // 2, :])
        b_images = preprocess(raw_input[:, width // 2:, :])

    inputs, targets = [a_images, b_images]

    # synchronize seed for image operations so that we do the same operations to both
    # input and output images
    def transform(image):
        r = image
        # area produces a nice downscaling, but does nearest neighbor for upscaling
        # assume we're going to be doing downscaling here
        r = tf.image.resize_images(r, [CROP_SIZE, CROP_SIZE], method=tf.image.ResizeMethod.AREA)
        return r

    with tf.name_scope('input_images'):
        input_images = tf.expand_dims(transform(inputs), 0)

    with tf.name_scope('target_images'):
        target_images = tf.expand_dims(transform(targets), 0)

    return input_images, target_images
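
# Shape walk-through (illustration, matching the (256, 512, 3) placeholder
# defined in __main__ below): the graph expects one side-by-side pix2pix
# pair, input (A) on the left half and target (B) on the right half:
#   x             (256, 512, 3)     uint8
#   a_images      (256, 256, 3)     float32 in [-1, 1]
#   b_images      (256, 256, 3)     float32 in [-1, 1]
#   input_images  (1, 256, 256, 3)  after transform() and tf.expand_dims()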

# Tensor('batch:1', shape=(1, 256, 256, 3), dtype=float32) -> 1 batch size
def create_generator(generator_inputs, generator_outputs_channels):
    layers = []

    # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
    with tf.variable_scope('encoder_1'):
        output = gen_conv(generator_inputs, ngf)
        layers.append(output)

    layer_specs = [
        ngf * 2,  # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
        ngf * 4,  # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
        ngf * 8,  # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
        ngf * 8,  # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
        ngf * 8,  # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
        ngf * 8,  # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
        ngf * 8,  # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
    ]

    for out_channels in layer_specs:
        with tf.variable_scope('encoder_%d' % (len(layers) + 1)):
            rectified = lrelu(layers[-1], 0.2)
            # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
            convolved = gen_conv(rectified, out_channels)
            output = batchnorm(convolved)
            layers.append(output)

    layer_specs = [
        (ngf * 8, 0.5),  # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
        (ngf * 8, 0.5),  # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
        (ngf * 8, 0.5),  # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
        (ngf * 8, 0.0),  # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
        (ngf * 4, 0.0),  # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
        (ngf * 2, 0.0),  # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
        (ngf, 0.0),      # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
    ]

    num_encoder_layers = len(layers)
    for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
        skip_layer = num_encoder_layers - decoder_layer - 1
        with tf.variable_scope('decoder_%d' % (skip_layer + 1)):
            if decoder_layer == 0:
                # first decoder layer doesn't have skip connections
                # since it is directly connected to the skip_layer
                input = layers[-1]
            else:
                input = tf.concat([layers[-1], layers[skip_layer]], axis=3)

            rectified = tf.nn.relu(input)
            # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels]
            output = gen_deconv(rectified, out_channels)
            output = batchnorm(output)

            if dropout > 0.0:
                output = tf.nn.dropout(output, keep_prob=1 - dropout)

            layers.append(output)

    # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
    with tf.variable_scope('decoder_1'):
        input = tf.concat([layers[-1], layers[0]], axis=3)
        rectified = tf.nn.relu(input)
        output = gen_deconv(rectified, generator_outputs_channels)
        output = tf.tanh(output)
        layers.append(output)

    return layers[-1]
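
# Note on the '* 2' channel counts in the decoder comments above: every
# decoder stage after the first concatenates the mirrored encoder activation
# along axis 3, doubling the channels. For example, at decoder_7 (ngf = 64):
#   layers[-1]              [1, 2, 2, 512]
#   layers[skip_layer]      [1, 2, 2, 512]
#   tf.concat(..., axis=3)  [1, 2, 2, 1024]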

def create_model(inputs, targets):
    with tf.variable_scope('generator'):  # as scope
        out_channels = int(targets.get_shape()[-1])
        outputs = create_generator(inputs, out_channels)

    return outputs


def convert(image):
    return tf.image.convert_image_dtype(image, dtype=tf.uint8, saturate=True, name='output')  # output tensor


def generate_output(x):
    with tf.name_scope('generate_output'):
        test_inputs, test_targets = process_image(x)

        # inputs and targets are [batch_size, height, width, channels]
        model = create_model(test_inputs, test_targets)

        # deprocess the generator output from [-1, 1] back to [0, 1]
        outputs = deprocess(model)

        # reverse any processing on images so they can be written to disk or displayed to user
        converted_outputs = convert(outputs)

    return converted_outputs

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-input', dest='input_folder', type=str, help='Model folder to import.')
    parser.add_argument('--model-output', dest='output_folder', type=str, help='Model (reduced) folder to export.')
    args = parser.parse_args()

    x = tf.placeholder(tf.uint8, shape=(256, 512, 3), name='image_tensor')  # input tensor
    y = generate_output(x)

    with tf.Session() as sess:
        # Restore the original model from its latest training checkpoint
        saver = tf.train.Saver()
        checkpoint = tf.train.latest_checkpoint(args.input_folder)
        saver.restore(sess, checkpoint)

        # Export the reduced model used for prediction
        export_path = '{}/reduced_model'.format(args.output_folder)
        saver.save(sess, export_path)
        print('Model is exported to {}'.format(export_path))
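
# Usage sketch (folder names are hypothetical):
#   python reduce_model.py --model-input ./checkpoints --model-output ./reduced
#
# To predict with the reduced checkpoint, rebuild the same inference graph
# (the x / y pair above) and feed the named tensors; an untested sketch:
#   with tf.Session() as sess:
#       tf.train.Saver().restore(sess, './reduced/reduced_model')
#       result = sess.run('generate_output/output:0',
#                         feed_dict={'image_tensor:0': paired_image})  # (256, 512, 3) uint8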