# PULSE.py

from stylegan import G_synthesis, G_mapping
from dataclasses import dataclass
from SphericalOptimizer import SphericalOptimizer
from pathlib import Path
import numpy as np
import time
import torch
from loss import LossBuilder
from functools import partial
from drive import open_url
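

# PULSE upsamples a low-resolution face by searching the latent space of a
# pretrained StyleGAN generator for a high-resolution image that downscales
# to the given reference image.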
class PULSE(torch.nn.Module):
    def __init__(self, cache_dir, verbose=True):
        super(PULSE, self).__init__()
        self.synthesis = G_synthesis().cuda()
        self.verbose = verbose

        cache_dir = Path(cache_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)
        if self.verbose: print("Loading Synthesis Network")
        with open_url("https://drive.google.com/uc?id=1TCViX1YpQyRsklTVYEJwdbmK91vklCo8", cache_dir=cache_dir, verbose=verbose) as f:
            self.synthesis.load_state_dict(torch.load(f))
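
        # The pretrained generator is a fixed prior; only the latent and noise
        # inputs passed to it are optimized in forward()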
        for param in self.synthesis.parameters():
            param.requires_grad = False

        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2)
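
        # Fit a per-dimension Gaussian to the mapping network's outputs so that
        # optimized latents can be affinely mapped into a realistic region of
        # W space; the fit is computed once and cached to disk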
        if Path("gaussian_fit.pt").exists():
            self.gaussian_fit = torch.load("gaussian_fit.pt")
        else:
            if self.verbose: print("\tLoading Mapping Network")
            mapping = G_mapping().cuda()

            with open_url("https://drive.google.com/uc?id=14R6iHGf5iuVx3DMNsACAl7eBr7Vdpd0k", cache_dir=cache_dir, verbose=verbose) as f:
                mapping.load_state_dict(torch.load(f))

            if self.verbose: print("\tRunning Mapping Network")
            with torch.no_grad():
                torch.manual_seed(0)
                latent = torch.randn((1000000, 512), dtype=torch.float32, device="cuda")
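                # LeakyReLU(5) is the inverse of LeakyReLU(0.2), so the fit is
                # over pre-activation values; forward() re-applies the 0.2
                # leaky ReLU after the affine transform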
                latent_out = torch.nn.LeakyReLU(5)(mapping(latent))
                self.gaussian_fit = {"mean": latent_out.mean(0), "std": latent_out.std(0)}
                torch.save(self.gaussian_fit, "gaussian_fit.pt")
                if self.verbose: print("\tSaved \"gaussian_fit.pt\"")
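
    # forward() is a generator: it yields (HR, LR) image pairs, where LR is
    # the HR image passed through the loss builder's downscaling operator D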
    def forward(self, ref_im,
                seed,
                loss_str,
                eps,
                noise_type,
                num_trainable_noise_layers,
                tile_latent,
                bad_noise_layers,
                opt_name,
                learning_rate,
                steps,
                lr_schedule,
                save_intermediate,
                **kwargs):
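
        # Optional deterministic seeding (a seed of 0 is falsy and is treated
        # as "no seed")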
        if seed:
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.deterministic = True

        batch_size = ref_im.shape[0]

        # Generate latent tensor
        if tile_latent:
            latent = torch.randn(
                (batch_size, 1, 512), dtype=torch.float, requires_grad=True, device='cuda')
        else:
            latent = torch.randn(
                (batch_size, 18, 512), dtype=torch.float, requires_grad=True, device='cuda')
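
        # The 1024x1024 synthesis network takes 18 noise inputs, two per
        # resolution from 4x4 (i = 0, 1) up to 1024x1024 (i = 16, 17)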
        # Generate list of noise tensors
        noise = []       # stores all of the noise tensors
        noise_vars = []  # stores the noise tensors that we want to optimize on

        for i in range(18):
            # dimension of the ith noise tensor
            res = (batch_size, 1, 2**(i//2+2), 2**(i//2+2))

            if noise_type == 'zero' or i in [int(layer) for layer in bad_noise_layers.split('.')]:
                new_noise = torch.zeros(res, dtype=torch.float, device='cuda')
                new_noise.requires_grad = False
            elif noise_type == 'fixed':
                new_noise = torch.randn(res, dtype=torch.float, device='cuda')
                new_noise.requires_grad = False
            elif noise_type == 'trainable':
                new_noise = torch.randn(res, dtype=torch.float, device='cuda')
                if i < num_trainable_noise_layers:
                    new_noise.requires_grad = True
                    noise_vars.append(new_noise)
                else:
                    new_noise.requires_grad = False
            else:
                raise Exception("unknown noise type")

            noise.append(new_noise)

        var_list = [latent] + noise_vars

        opt_dict = {
            'sgd': torch.optim.SGD,
            'adam': torch.optim.Adam,
            'sgdm': partial(torch.optim.SGD, momentum=0.9),
            'adamax': torch.optim.Adamax
        }
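
        # SphericalOptimizer wraps the chosen optimizer and re-projects each
        # variable onto a sphere after every step, keeping the search in the
        # high-density region of the Gaussian prior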
        opt_func = opt_dict[opt_name]
        opt = SphericalOptimizer(opt_func, var_list, lr=learning_rate)
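
        # Learning-rate schedules: constant, a triangular one-cycle ramp, or a
        # one-cycle ramp followed by a linear decay over the final 10% of steps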
        schedule_dict = {
            'fixed': lambda x: 1,
            'linear1cycle': lambda x: (9*(1-np.abs(x/steps-1/2)*2)+1)/10,
            'linear1cycledrop': lambda x: (9*(1-np.abs(x/(0.9*steps)-1/2)*2)+1)/10 if x < 0.9*steps else 1/10 + (x-0.9*steps)/(0.1*steps)*(1/1000-1/10),
        }
        schedule_func = schedule_dict[lr_schedule]
        scheduler = torch.optim.lr_scheduler.LambdaLR(opt.opt, schedule_func)

        loss_builder = LossBuilder(ref_im, loss_str, eps).cuda()

        min_loss = np.inf
        min_l2 = np.inf
        best_summary = ""
        start_t = time.time()
        gen_im = None

        if self.verbose: print("Optimizing")
        for j in range(steps):
            opt.opt.zero_grad()

            # Duplicate latent in case tile_latent = True
            if tile_latent:
                latent_in = latent.expand(-1, 18, -1)
            else:
                latent_in = latent

            # Apply learned linear mapping to match latent distribution to that of the mapping network
            latent_in = self.lrelu(latent_in*self.gaussian_fit["std"] + self.gaussian_fit["mean"])

            # Normalize image to [0,1] instead of [-1,1]
            gen_im = (self.synthesis(latent_in, noise)+1)/2

            # Calculate Losses
            loss, loss_dict = loss_builder(latent_in, gen_im)
            loss_dict['TOTAL'] = loss

            # Save best summary for log
            if loss < min_loss:
                min_loss = loss
                best_summary = f'BEST ({j+1}) | ' + ' | '.join(
                    [f'{x}: {y:.4f}' for x, y in loss_dict.items()])
                best_im = gen_im.clone()

            loss_l2 = loss_dict['L2']
            if loss_l2 < min_l2:
                min_l2 = loss_l2

            # Save intermediate HR and LR images
            if save_intermediate:
                yield (best_im.cpu().detach().clamp(0, 1), loss_builder.D(best_im).cpu().detach().clamp(0, 1))

            loss.backward()
            opt.step()
            scheduler.step()

        total_t = time.time() - start_t
        current_info = f' | time: {total_t:.1f} | it/s: {(j+1)/total_t:.2f} | batchsize: {batch_size}'
        if self.verbose: print(best_summary + current_info)

        if min_l2 <= eps:
            # Yield the best image found together with its downscaled version
            yield (best_im.cpu().detach().clamp(0, 1), loss_builder.D(best_im).cpu().detach().clamp(0, 1))
        else:
            print("Could not find a face that downscales correctly within epsilon")
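

# Minimal usage sketch (illustrative; the repository's run.py is the actual
# driver). Assumes `ref_im` is a batch of low-resolution face images in
# [0, 1] already on the GPU; the argument values below are placeholders.
#
#   model = PULSE(cache_dir="cache")
#   for hr, lr in model(ref_im, seed=None, loss_str="100*L2+0.05*GEOCROSS",
#                       eps=2e-3, noise_type="trainable",
#                       num_trainable_noise_layers=5, tile_latent=False,
#                       bad_noise_layers="17", opt_name="adam",
#                       learning_rate=0.4, steps=100,
#                       lr_schedule="linear1cycledrop", save_intermediate=False):
#       pass  # hr: 1024x1024 upsampled faces; lr: their downscaled versions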