# stylegan.py
# Modified from https://github.com/lernapparat/lernapparat/
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from collections import OrderedDict
  6. import pickle
  7. import numpy as np
  8. class MyLinear(nn.Module):
  9. """Linear layer with equalized learning rate and custom learning rate multiplier."""
  10. def __init__(self, input_size, output_size, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True):
  11. super().__init__()
  12. he_std = gain * input_size**(-0.5) # He init
  13. # Equalized learning rate and custom learning rate multiplier.
  14. if use_wscale:
  15. init_std = 1.0 / lrmul
  16. self.w_mul = he_std * lrmul
  17. else:
  18. init_std = he_std / lrmul
  19. self.w_mul = lrmul
  20. self.weight = torch.nn.Parameter(
  21. torch.randn(output_size, input_size) * init_std)
  22. if bias:
  23. self.bias = torch.nn.Parameter(torch.zeros(output_size))
  24. self.b_mul = lrmul
  25. else:
  26. self.bias = None
  27. def forward(self, x):
  28. bias = self.bias
  29. if bias is not None:
  30. bias = bias * self.b_mul
  31. return F.linear(x, self.weight * self.w_mul, bias)
  32. class MyConv2d(nn.Module):
  33. """Conv layer with equalized learning rate and custom learning rate multiplier."""
  34. def __init__(self, input_channels, output_channels, kernel_size, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True,
  35. intermediate=None, upscale=False):
  36. super().__init__()
  37. if upscale:
  38. self.upscale = Upscale2d()
  39. else:
  40. self.upscale = None
  41. he_std = gain * (input_channels * kernel_size **
  42. 2) ** (-0.5) # He init
  43. self.kernel_size = kernel_size
  44. if use_wscale:
  45. init_std = 1.0 / lrmul
  46. self.w_mul = he_std * lrmul
  47. else:
  48. init_std = he_std / lrmul
  49. self.w_mul = lrmul
  50. self.weight = torch.nn.Parameter(torch.randn(
  51. output_channels, input_channels, kernel_size, kernel_size) * init_std)
  52. if bias:
  53. self.bias = torch.nn.Parameter(torch.zeros(output_channels))
  54. self.b_mul = lrmul
  55. else:
  56. self.bias = None
  57. self.intermediate = intermediate
  58. def forward(self, x):
  59. bias = self.bias
  60. if bias is not None:
  61. bias = bias * self.b_mul
  62. have_convolution = False
  63. if self.upscale is not None and min(x.shape[2:]) * 2 >= 128:
  64. # this is the fused upscale + conv from StyleGAN, sadly this seems incompatible with the non-fused way
  65. # this really needs to be cleaned up and go into the conv...
  66. w = self.weight * self.w_mul
  67. w = w.permute(1, 0, 2, 3)
  68. # probably applying a conv on w would be more efficient. also this quadruples the weight (average)?!
  69. w = F.pad(w, (1, 1, 1, 1))
  70. w = w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + \
  71. w[:, :, 1:, :-1] + w[:, :, :-1, :-1]
  72. x = F.conv_transpose2d(
  73. x, w, stride=2, padding=int((w.size(-1)-1)//2))
  74. have_convolution = True
  75. elif self.upscale is not None:
  76. x = self.upscale(x)
  77. if not have_convolution and self.intermediate is None:
  78. return F.conv2d(x, self.weight * self.w_mul, bias, padding=int(self.kernel_size//2))
  79. elif not have_convolution:
  80. x = F.conv2d(x, self.weight * self.w_mul, None,
  81. padding=int(self.kernel_size//2))
  82. if self.intermediate is not None:
  83. x = self.intermediate(x)
  84. if bias is not None:
  85. x = x + bias.view(1, -1, 1, 1)
  86. return x
  87. class NoiseLayer(nn.Module):
  88. """adds noise. noise is per pixel (constant over channels) with per-channel weight"""
  89. def __init__(self, channels):
  90. super().__init__()
  91. self.weight = nn.Parameter(torch.zeros(channels))
  92. self.noise = None
  93. def forward(self, x, noise=None):
  94. if noise is None and self.noise is None:
  95. noise = torch.randn(x.size(0), 1, x.size(
  96. 2), x.size(3), device=x.device, dtype=x.dtype)
  97. elif noise is None:
  98. # here is a little trick: if you get all the noiselayers and set each
  99. # modules .noise attribute, you can have pre-defined noise.
  100. # Very useful for analysis
  101. noise = self.noise
  102. x = x + self.weight.view(1, -1, 1, 1) * noise
  103. return x
  104. class StyleMod(nn.Module):
  105. def __init__(self, latent_size, channels, use_wscale):
  106. super(StyleMod, self).__init__()
  107. self.lin = MyLinear(latent_size,
  108. channels * 2,
  109. gain=1.0, use_wscale=use_wscale)
  110. def forward(self, x, latent):
  111. style = self.lin(latent) # style => [batch_size, n_channels*2]
  112. shape = [-1, 2, x.size(1)] + (x.dim() - 2) * [1]
  113. style = style.view(shape) # [batch_size, 2, n_channels, ...]
  114. x = x * (style[:, 0] + 1.) + style[:, 1]
  115. return x
  116. class PixelNormLayer(nn.Module):
  117. def __init__(self, epsilon=1e-8):
  118. super().__init__()
  119. self.epsilon = epsilon
  120. def forward(self, x):
  121. return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + self.epsilon)
  122. class BlurLayer(nn.Module):
  123. def __init__(self, kernel=[1, 2, 1], normalize=True, flip=False, stride=1):
  124. super(BlurLayer, self).__init__()
  125. kernel = [1, 2, 1]
  126. kernel = torch.tensor(kernel, dtype=torch.float32)
  127. kernel = kernel[:, None] * kernel[None, :]
  128. kernel = kernel[None, None]
  129. if normalize:
  130. kernel = kernel / kernel.sum()
  131. if flip:
  132. kernel = kernel[:, :, ::-1, ::-1]
  133. self.register_buffer('kernel', kernel)
  134. self.stride = stride
  135. def forward(self, x):
  136. # expand kernel channels
  137. kernel = self.kernel.expand(x.size(1), -1, -1, -1)
  138. x = F.conv2d(
  139. x,
  140. kernel,
  141. stride=self.stride,
  142. padding=int((self.kernel.size(2)-1)/2),
  143. groups=x.size(1)
  144. )
  145. return x
  146. def upscale2d(x, factor=2, gain=1):
  147. assert x.dim() == 4
  148. if gain != 1:
  149. x = x * gain
  150. if factor != 1:
  151. shape = x.shape
  152. x = x.view(shape[0], shape[1], shape[2], 1, shape[3],
  153. 1).expand(-1, -1, -1, factor, -1, factor)
  154. x = x.contiguous().view(
  155. shape[0], shape[1], factor * shape[2], factor * shape[3])
  156. return x
  157. class Upscale2d(nn.Module):
  158. def __init__(self, factor=2, gain=1):
  159. super().__init__()
  160. assert isinstance(factor, int) and factor >= 1
  161. self.gain = gain
  162. self.factor = factor
  163. def forward(self, x):
  164. return upscale2d(x, factor=self.factor, gain=self.gain)
  165. class G_mapping(nn.Sequential):
  166. def __init__(self, nonlinearity='lrelu', use_wscale=True):
  167. act, gain = {'relu': (torch.relu, np.sqrt(2)),
  168. 'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
  169. layers = [
  170. ('pixel_norm', PixelNormLayer()),
  171. ('dense0', MyLinear(512, 512, gain=gain,
  172. lrmul=0.01, use_wscale=use_wscale)),
  173. ('dense0_act', act),
  174. ('dense1', MyLinear(512, 512, gain=gain,
  175. lrmul=0.01, use_wscale=use_wscale)),
  176. ('dense1_act', act),
  177. ('dense2', MyLinear(512, 512, gain=gain,
  178. lrmul=0.01, use_wscale=use_wscale)),
  179. ('dense2_act', act),
  180. ('dense3', MyLinear(512, 512, gain=gain,
  181. lrmul=0.01, use_wscale=use_wscale)),
  182. ('dense3_act', act),
  183. ('dense4', MyLinear(512, 512, gain=gain,
  184. lrmul=0.01, use_wscale=use_wscale)),
  185. ('dense4_act', act),
  186. ('dense5', MyLinear(512, 512, gain=gain,
  187. lrmul=0.01, use_wscale=use_wscale)),
  188. ('dense5_act', act),
  189. ('dense6', MyLinear(512, 512, gain=gain,
  190. lrmul=0.01, use_wscale=use_wscale)),
  191. ('dense6_act', act),
  192. ('dense7', MyLinear(512, 512, gain=gain,
  193. lrmul=0.01, use_wscale=use_wscale)),
  194. ('dense7_act', act)
  195. ]
  196. super().__init__(OrderedDict(layers))
  197. def forward(self, x):
  198. x = super().forward(x)
  199. return x
  200. class Truncation(nn.Module):
  201. def __init__(self, avg_latent, max_layer=8, threshold=0.7):
  202. super().__init__()
  203. self.max_layer = max_layer
  204. self.threshold = threshold
  205. self.register_buffer('avg_latent', avg_latent)
  206. def forward(self, x):
  207. assert x.dim() == 3
  208. interp = torch.lerp(self.avg_latent, x, self.threshold)
  209. do_trunc = (torch.arange(x.size(1)) < self.max_layer).view(1, -1, 1)
  210. return torch.where(do_trunc, interp, x)
  211. class LayerEpilogue(nn.Module):
  212. """Things to do at the end of each layer."""
  213. def __init__(self, channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
  214. super().__init__()
  215. layers = []
  216. if use_noise:
  217. self.noise = NoiseLayer(channels)
  218. else:
  219. self.noise = None
  220. layers.append(('activation', activation_layer))
  221. if use_pixel_norm:
  222. layers.append(('pixel_norm', PixelNormLayer()))
  223. if use_instance_norm:
  224. layers.append(('instance_norm', nn.InstanceNorm2d(channels)))
  225. self.top_epi = nn.Sequential(OrderedDict(layers))
  226. if use_styles:
  227. self.style_mod = StyleMod(
  228. dlatent_size, channels, use_wscale=use_wscale)
  229. else:
  230. self.style_mod = None
  231. def forward(self, x, dlatents_in_slice=None, noise_in_slice=None):
  232. if(self.noise is not None):
  233. x = self.noise(x, noise=noise_in_slice)
  234. x = self.top_epi(x)
  235. if self.style_mod is not None:
  236. x = self.style_mod(x, dlatents_in_slice)
  237. else:
  238. assert dlatents_in_slice is None
  239. return x
  240. class InputBlock(nn.Module):
  241. def __init__(self, nf, dlatent_size, const_input_layer, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
  242. super().__init__()
  243. self.const_input_layer = const_input_layer
  244. self.nf = nf
  245. if self.const_input_layer:
  246. # called 'const' in tf
  247. self.const = nn.Parameter(torch.ones(1, nf, 4, 4))
  248. self.bias = nn.Parameter(torch.ones(nf))
  249. else:
  250. # tweak gain to match the official implementation of Progressing GAN
  251. self.dense = MyLinear(dlatent_size, nf*16,
  252. gain=gain/4, use_wscale=use_wscale)
  253. self.epi1 = LayerEpilogue(nf, dlatent_size, use_wscale, use_noise,
  254. use_pixel_norm, use_instance_norm, use_styles, activation_layer)
  255. self.conv = MyConv2d(nf, nf, 3, gain=gain, use_wscale=use_wscale)
  256. self.epi2 = LayerEpilogue(nf, dlatent_size, use_wscale, use_noise,
  257. use_pixel_norm, use_instance_norm, use_styles, activation_layer)
  258. def forward(self, dlatents_in_range, noise_in_range):
  259. batch_size = dlatents_in_range.size(0)
  260. if self.const_input_layer:
  261. x = self.const.expand(batch_size, -1, -1, -1)
  262. x = x + self.bias.view(1, -1, 1, 1)
  263. else:
  264. x = self.dense(dlatents_in_range[:, 0]).view(
  265. batch_size, self.nf, 4, 4)
  266. x = self.epi1(x, dlatents_in_range[:, 0], noise_in_range[0])
  267. x = self.conv(x)
  268. x = self.epi2(x, dlatents_in_range[:, 1], noise_in_range[1])
  269. return x
  270. class GSynthesisBlock(nn.Module):
  271. def __init__(self, in_channels, out_channels, blur_filter, dlatent_size, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
  272. # 2**res x 2**res # res = 3..resolution_log2
  273. super().__init__()
  274. if blur_filter:
  275. blur = BlurLayer(blur_filter)
  276. else:
  277. blur = None
  278. self.conv0_up = MyConv2d(in_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale,
  279. intermediate=blur, upscale=True)
  280. self.epi1 = LayerEpilogue(out_channels, dlatent_size, use_wscale, use_noise,
  281. use_pixel_norm, use_instance_norm, use_styles, activation_layer)
  282. self.conv1 = MyConv2d(out_channels, out_channels,
  283. kernel_size=3, gain=gain, use_wscale=use_wscale)
  284. self.epi2 = LayerEpilogue(out_channels, dlatent_size, use_wscale, use_noise,
  285. use_pixel_norm, use_instance_norm, use_styles, activation_layer)
  286. def forward(self, x, dlatents_in_range, noise_in_range):
  287. x = self.conv0_up(x)
  288. x = self.epi1(x, dlatents_in_range[:, 0], noise_in_range[0])
  289. x = self.conv1(x)
  290. x = self.epi2(x, dlatents_in_range[:, 1], noise_in_range[1])
  291. return x
  292. class G_synthesis(nn.Module):
  293. def __init__(self,
  294. # Disentangled latent (W) dimensionality.
  295. dlatent_size=512,
  296. num_channels=3, # Number of output color channels.
  297. resolution=1024, # Output resolution.
  298. # Overall multiplier for the number of feature maps.
  299. fmap_base=8192,
  300. # log2 feature map reduction when doubling the resolution.
  301. fmap_decay=1.0,
  302. # Maximum number of feature maps in any layer.
  303. fmap_max=512,
  304. use_styles=True, # Enable style inputs?
  305. const_input_layer=True, # First layer is a learned constant?
  306. use_noise=True, # Enable noise inputs?
  307. # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables.
  308. randomize_noise=True,
  309. nonlinearity='lrelu', # Activation function: 'relu', 'lrelu'
  310. use_wscale=True, # Enable equalized learning rate?
  311. use_pixel_norm=False, # Enable pixelwise feature vector normalization?
  312. use_instance_norm=True, # Enable instance normalization?
  313. # Data type to use for activations and outputs.
  314. dtype=torch.float32,
  315. # Low-pass filter to apply when resampling activations. None = no filtering.
  316. blur_filter=[1, 2, 1],
  317. ):
  318. super().__init__()
  319. def nf(stage):
  320. return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)
  321. self.dlatent_size = dlatent_size
  322. resolution_log2 = int(np.log2(resolution))
  323. assert resolution == 2**resolution_log2 and resolution >= 4
  324. act, gain = {'relu': (torch.relu, np.sqrt(2)),
  325. 'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
  326. num_layers = resolution_log2 * 2 - 2
  327. num_styles = num_layers if use_styles else 1
  328. torgbs = []
  329. blocks = []
  330. for res in range(2, resolution_log2 + 1):
  331. channels = nf(res-1)
  332. name = '{s}x{s}'.format(s=2**res)
  333. if res == 2:
  334. blocks.append((name,
  335. InputBlock(channels, dlatent_size, const_input_layer, gain, use_wscale,
  336. use_noise, use_pixel_norm, use_instance_norm, use_styles, act)))
  337. else:
  338. blocks.append((name,
  339. GSynthesisBlock(last_channels, channels, blur_filter, dlatent_size, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, act)))
  340. last_channels = channels
  341. self.torgb = MyConv2d(channels, num_channels, 1,
  342. gain=1, use_wscale=use_wscale)
  343. self.blocks = nn.ModuleDict(OrderedDict(blocks))
  344. def forward(self, dlatents_in, noise_in):
  345. # Input: Disentangled latents (W) [minibatch, num_layers, dlatent_size].
  346. # lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0), trainable=False), dtype)
  347. batch_size = dlatents_in.size(0)
  348. for i, m in enumerate(self.blocks.values()):
  349. if i == 0:
  350. x = m(dlatents_in[:, 2*i:2*i+2], noise_in[2*i:2*i+2])
  351. else:
  352. x = m(x, dlatents_in[:, 2*i:2*i+2], noise_in[2*i:2*i+2])
  353. rgb = self.torgb(x)
  354. return rgb