
update conv & deconv for new pix2pix version

Servon · 6 years ago
parent commit 90c6bd6e00
1 changed file with 37 additions and 40 deletions

reduce_model.py  +37 −40

@@ -18,16 +18,13 @@ def deprocess(image):
         return (image + 1) / 2
 
 
-def conv(batch_input, out_channels, stride):
-    with tf.variable_scope('conv'):
-        in_channels = batch_input.get_shape()[3]
-        filter = tf.get_variable('filter', [4, 4, in_channels, out_channels], dtype=tf.float32,
-                                 initializer=tf.random_normal_initializer(0, 0.02))
-        # [batch, in_height, in_width, in_channels], [filter_width, filter_height, in_channels, out_channels]
-        #     => [batch, out_height, out_width, out_channels]
-        padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='CONSTANT')
-        conv = tf.nn.conv2d(padded_input, filter, [1, stride, stride, 1], padding='VALID')
-        return conv
+def gen_conv(batch_input, out_channels):
+    # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels]
+    initializer = tf.random_normal_initializer(0, 0.02)
+    # if a.separable_conv:
+    #     return tf.layers.separable_conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer)
+    # else:
+    return tf.layers.conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer)
 
 
 def lrelu(x, a):
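Note on the gen_conv change above: tf.layers.conv2d creates and initializes the filter itself, and with kernel_size=4, strides=(2, 2), padding="same" it halves the spatial dimensions, matching the old pad-then-VALID tf.nn.conv2d. A minimal shape check, as a sketch only (the 256x256x3 input and out_channels=64 are illustrative values, assuming gen_conv from this file and TF 1.x graph mode):

    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 256, 256, 3])  # hypothetical input size
    y = gen_conv(x, out_channels=64)                      # illustrative channel count
    print(y.shape)                                        # (?, 128, 128, 64): height/width halved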
@@ -42,31 +39,31 @@ def lrelu(x, a):
         return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x)
 
 
-def batchnorm(input):
-    with tf.variable_scope('batchnorm'):
-        # this block looks like it has 3 inputs on the graph unless we do this
-        input = tf.identity(input)
-
-        channels = input.get_shape()[3]
-        offset = tf.get_variable('offset', [channels], dtype=tf.float32, initializer=tf.zeros_initializer())
-        scale = tf.get_variable('scale', [channels], dtype=tf.float32,
-                                initializer=tf.random_normal_initializer(1.0, 0.02))
-        mean, variance = tf.nn.moments(input, axes=[0, 1, 2], keep_dims=False)
-        variance_epsilon = 1e-5
-        normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon)
-        return normalized
-
-
-def deconv(batch_input, out_channels):
-    with tf.variable_scope('deconv'):
-        batch, in_height, in_width, in_channels = [int(d) for d in batch_input.get_shape()]
-        filter = tf.get_variable('filter', [4, 4, out_channels, in_channels], dtype=tf.float32,
-                                 initializer=tf.random_normal_initializer(0, 0.02))
-        # [batch, in_height, in_width, in_channels], [filter_width, filter_height, out_channels, in_channels]
-        #     => [batch, out_height, out_width, out_channels]
-        conv = tf.nn.conv2d_transpose(batch_input, filter, [batch, in_height * 2, in_width * 2, out_channels],
-                                      [1, 2, 2, 1], padding='SAME')
-        return conv
+def batchnorm(inputs):
+    return tf.layers.batch_normalization(inputs, axis=3, epsilon=1e-5, momentum=0.1, training=True, gamma_initializer=tf.random_normal_initializer(1.0, 0.02))
+    # with tf.variable_scope('batchnorm'):
+    #     # this block looks like it has 3 inputs on the graph unless we do this
+    #     input = tf.identity(input)
+    #
+    #     channels = input.get_shape()[3]
+    #     offset = tf.get_variable('offset', [channels], dtype=tf.float32, initializer=tf.zeros_initializer())
+    #     scale = tf.get_variable('scale', [channels], dtype=tf.float32,
+    #                             initializer=tf.random_normal_initializer(1.0, 0.02))
+    #     mean, variance = tf.nn.moments(input, axes=[0, 1, 2], keep_dims=False)
+    #     variance_epsilon = 1e-5
+    #     normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon)
+    #     return normalized
+
+
+def gen_deconv(batch_input, out_channels):
+    # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels]
+    initializer = tf.random_normal_initializer(0, 0.02)
+    # if a.separable_conv:
+    #     _b, h, w, _c = batch_input.shape
+    #     resized_input = tf.image.resize_images(batch_input, [h * 2, w * 2], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
+    #     return tf.layers.separable_conv2d(resized_input, out_channels, kernel_size=4, strides=(1, 1), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer)
+    # else:
+    return tf.layers.conv2d_transpose(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer)
 
 
 def process_image(x):
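Note on the hunk above: batchnorm now hardcodes training=True, so normalization always uses the current batch statistics (the same behaviour as the removed tf.nn.moments code), and gen_deconv doubles height and width via tf.layers.conv2d_transpose with strides=(2, 2), padding="same". One caveat: tf.layers.batch_normalization also registers moving-average update ops under tf.GraphKeys.UPDATE_OPS; they are unused while training stays True, but if that flag ever becomes a placeholder they must run alongside the train op. A sketch of the usual TF 1.x pattern (the optimizer settings and the loss tensor are assumptions, not part of this commit):

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # illustrative optimizer; the real training script may differ
        train_op = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5).minimize(loss)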
@@ -109,7 +106,7 @@ def create_generator(generator_inputs, generator_outputs_channels):
 
     # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
     with tf.variable_scope('encoder_1'):
-        output = conv(generator_inputs, ngf, stride=2)
+        output = gen_conv(generator_inputs, ngf)
         layers.append(output)
 
     layer_specs = [
@@ -126,7 +123,7 @@ def create_generator(generator_inputs, generator_outputs_channels):
         with tf.variable_scope('encoder_%d' % (len(layers) + 1)):
             rectified = lrelu(layers[-1], 0.2)
             # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
-            convolved = conv(rectified, out_channels, stride=2)
+            convolved = gen_conv(rectified, out_channels)
             output = batchnorm(convolved)
             layers.append(output)
 
@@ -153,7 +150,7 @@ def create_generator(generator_inputs, generator_outputs_channels):
 
             rectified = tf.nn.relu(input)
             # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels]
-            output = deconv(rectified, out_channels)
+            output = gen_deconv(rectified, out_channels)
             output = batchnorm(output)
 
             if dropout > 0.0:
@@ -165,7 +162,7 @@ def create_generator(generator_inputs, generator_outputs_channels):
     with tf.variable_scope('decoder_1'):
         input = tf.concat([layers[-1], layers[0]], axis=3)
         rectified = tf.nn.relu(input)
-        output = deconv(rectified, generator_outputs_channels)
+        output = gen_deconv(rectified, generator_outputs_channels)
         output = tf.tanh(output)
         layers.append(output)
 
@@ -173,7 +170,7 @@ def create_generator(generator_inputs, generator_outputs_channels):
 
 
 def create_model(inputs, targets):
-    with tf.variable_scope('generator') as scope:
+    with tf.variable_scope('generator'): # as scope
         out_channels = int(targets.get_shape()[-1])
         outputs = create_generator(inputs, out_channels)
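Note on the final hunk: the `as scope` handle is commented out because nothing shown here references it; if the generator's variables are needed by name later (for example to build a Saver), they can still be collected by scope prefix. A minimal sketch (the Saver usage is an illustrative assumption, not part of this commit):

    gen_vars = [var for var in tf.trainable_variables()
                if var.name.startswith('generator/')]
    saver = tf.train.Saver(var_list=gen_vars)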