reduce_model.py

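"""Rebuild the generator network (a pix2pix-style U-Net) for inference only,
restore the weights from the latest checkpoint in --model-input, and re-save
them as a reduced, prediction-only model in --model-output."""
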
import argparse

import tensorflow as tf

CROP_SIZE = 256  # scale_size = CROP_SIZE
ngf = 64  # number of generator filters in the first conv layer
ndf = 64  # number of discriminator filters (unused in this inference-only graph)

def preprocess(image):
    with tf.name_scope('preprocess'):
        # [0, 1] => [-1, 1]
        return image * 2 - 1


def deprocess(image):
    with tf.name_scope('deprocess'):
        # [-1, 1] => [0, 1]
        return (image + 1) / 2

def conv(batch_input, out_channels, stride):
    with tf.variable_scope('conv'):
        in_channels = batch_input.get_shape()[3]
        filter = tf.get_variable('filter', [4, 4, in_channels, out_channels], dtype=tf.float32,
                                 initializer=tf.random_normal_initializer(0, 0.02))
        # [batch, in_height, in_width, in_channels], [filter_height, filter_width, in_channels, out_channels]
        # => [batch, out_height, out_width, out_channels]
        padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='CONSTANT')
        conv = tf.nn.conv2d(padded_input, filter, [1, stride, stride, 1], padding='VALID')
        return conv

def lrelu(x, a):
    with tf.name_scope('lrelu'):
        # adding these together creates the leak part and linear part
        # then cancels them out by subtracting/adding an absolute value term
        # leak: a*x/2 - a*abs(x)/2
        # linear: x/2 + abs(x)/2
        # this block looks like it has 2 inputs on the graph unless we do this
        x = tf.identity(x)
        return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x)
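        # sanity check of the identity above:
        #   x >= 0: 0.5*(1 + a)*x + 0.5*(1 - a)*x = x    (linear part)
        #   x <  0: 0.5*(1 + a)*x - 0.5*(1 - a)*x = a*x  (leak part)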

def batchnorm(input):
    with tf.variable_scope('batchnorm'):
        # this block looks like it has 3 inputs on the graph unless we do this
        input = tf.identity(input)
        channels = input.get_shape()[3]
        offset = tf.get_variable('offset', [channels], dtype=tf.float32, initializer=tf.zeros_initializer())
        scale = tf.get_variable('scale', [channels], dtype=tf.float32,
                                initializer=tf.random_normal_initializer(1.0, 0.02))
        mean, variance = tf.nn.moments(input, axes=[0, 1, 2], keep_dims=False)
        variance_epsilon = 1e-5
        normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon)
        return normalized
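        # note: tf.nn.moments computes statistics from the current batch, so
        # normalization uses per-batch statistics even at inference time;
        # no moving averages are stored in (or restored from) the checkpoint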

def deconv(batch_input, out_channels):
    with tf.variable_scope('deconv'):
        batch, in_height, in_width, in_channels = [int(d) for d in batch_input.get_shape()]
        filter = tf.get_variable('filter', [4, 4, out_channels, in_channels], dtype=tf.float32,
                                 initializer=tf.random_normal_initializer(0, 0.02))
        # [batch, in_height, in_width, in_channels], [filter_height, filter_width, out_channels, in_channels]
        # => [batch, out_height, out_width, out_channels]
        conv = tf.nn.conv2d_transpose(batch_input, filter, [batch, in_height * 2, in_width * 2, out_channels],
                                      [1, 2, 2, 1], padding='SAME')
        return conv
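
# process_image expects a side-by-side A|B image pair (matching the 256x512x3
# placeholder defined in __main__): the left half becomes the model input and
# the right half the target.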
def process_image(x):
    with tf.name_scope('load_images'):
        raw_input = tf.image.convert_image_dtype(x, dtype=tf.float32)
        raw_input.set_shape([None, None, 3])

        # break apart image pair and move to range [-1, 1]
        width = tf.shape(raw_input)[1]  # [height, width, channels]
        a_images = preprocess(raw_input[:, :width // 2, :])
        b_images = preprocess(raw_input[:, width // 2:, :])

    inputs, targets = [a_images, b_images]

    # synchronize seed for image operations so that we do the same operations to both
    # input and output images
    def transform(image):
        r = image
        # area produces a nice downscaling, but does nearest neighbor for upscaling
        # assume we're going to be doing downscaling here
        r = tf.image.resize_images(r, [CROP_SIZE, CROP_SIZE], method=tf.image.ResizeMethod.AREA)
        return r

    with tf.name_scope('input_images'):
        input_images = tf.expand_dims(transform(inputs), 0)

    with tf.name_scope('target_images'):
        target_images = tf.expand_dims(transform(targets), 0)

    return input_images, target_images

# Tensor('batch:1', shape=(1, 256, 256, 3), dtype=float32) -> 1 batch size
def create_generator(generator_inputs, generator_outputs_channels):
    layers = []

    # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
    with tf.variable_scope('encoder_1'):
        output = conv(generator_inputs, ngf, stride=2)
        layers.append(output)

    layer_specs = [
        ngf * 2,  # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
        ngf * 4,  # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
        ngf * 8,  # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
        ngf * 8,  # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
        ngf * 8,  # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
        ngf * 8,  # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
        ngf * 8,  # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
    ]

    for out_channels in layer_specs:
        with tf.variable_scope('encoder_%d' % (len(layers) + 1)):
            rectified = lrelu(layers[-1], 0.2)
            # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
            convolved = conv(rectified, out_channels, stride=2)
            output = batchnorm(convolved)
            layers.append(output)

    layer_specs = [
        (ngf * 8, 0.5),  # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
        (ngf * 8, 0.5),  # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
        (ngf * 8, 0.5),  # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
        (ngf * 8, 0.0),  # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
        (ngf * 4, 0.0),  # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
        (ngf * 2, 0.0),  # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
        (ngf, 0.0),      # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
    ]
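
    # each decoder_N is concatenated with the output of encoder_N (a U-Net
    # skip connection), which is why the input channel counts above are
    # doubled; skip_layer below computes the index of that encoder output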
    num_encoder_layers = len(layers)
    for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
        skip_layer = num_encoder_layers - decoder_layer - 1
        with tf.variable_scope('decoder_%d' % (skip_layer + 1)):
            if decoder_layer == 0:
                # first decoder layer doesn't have skip connections
                # since it is directly connected to the skip_layer
                input = layers[-1]
            else:
                input = tf.concat([layers[-1], layers[skip_layer]], axis=3)

            rectified = tf.nn.relu(input)
            # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels]
            output = deconv(rectified, out_channels)
            output = batchnorm(output)

            if dropout > 0.0:
                output = tf.nn.dropout(output, keep_prob=1 - dropout)

            layers.append(output)

    # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
    with tf.variable_scope('decoder_1'):
        input = tf.concat([layers[-1], layers[0]], axis=3)
        rectified = tf.nn.relu(input)
        output = deconv(rectified, generator_outputs_channels)
        output = tf.tanh(output)
        layers.append(output)

    return layers[-1]

def create_model(inputs, targets):
    with tf.variable_scope('generator') as scope:
        out_channels = int(targets.get_shape()[-1])
        outputs = create_generator(inputs, out_channels)
        return outputs

def convert(image):
    return tf.image.convert_image_dtype(image, dtype=tf.uint8, saturate=True, name='output')  # output tensor

def generate_output(x):
    with tf.name_scope('generate_output'):
        test_inputs, test_targets = process_image(x)

        # inputs and targets are [batch_size, height, width, channels]
        model = create_model(test_inputs, test_targets)

        # deprocess the generator output from [-1, 1] back to [0, 1]
        outputs = deprocess(model)

        # reverse any processing on images so they can be written to disk or displayed to user
        converted_outputs = convert(outputs)
        return converted_outputs

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-input', dest='input_folder', type=str, help='Model folder to import.')
    parser.add_argument('--model-output', dest='output_folder', type=str, help='Model (reduced) folder to export.')
    args = parser.parse_args()

    x = tf.placeholder(tf.uint8, shape=(256, 512, 3), name='image_tensor')  # input tensor
    y = generate_output(x)

    with tf.Session() as sess:
        # Restore the original model from its latest checkpoint
        saver = tf.train.Saver()
        checkpoint = tf.train.latest_checkpoint(args.input_folder)
        saver.restore(sess, checkpoint)

        # Export the reduced model used for prediction
        save_path = saver.save(sess, '{}/reduced_model'.format(args.output_folder))
        print('Model exported to {}'.format(save_path))
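
# Example invocation (paths are illustrative):
#   python reduce_model.py --model-input ./original-model --model-output ./reduced-model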