commit fedefa11a1 — Your Name, 3 years ago
13 changed files with 1254 additions and 0 deletions

  1. .gitignore             +7    -0
  2. PULSE.py               +179  -0
  3. README.md              +74   -0
  4. SphericalOptimizer.py  +26   -0
  5. align_face.py          +49   -0
  6. bicubic.py             +75   -0
  7. drive.py               +94   -0
  8. gaussian_fit.pt        BIN
  9. loss.py                +57   -0
  10. pulse.yml              +63   -0
  11. run.py                 +82   -0
  12. shape_predictor.py     +138  -0
  13. stylegan.py            +410  -0

+ 7 - 0
.gitignore

@@ -0,0 +1,7 @@
+.DS_Store
+__pycache__/*
+.idea/*
+runs/*
+input/*
+cache/*
+realpics/*

+ 179 - 0
PULSE.py

@@ -0,0 +1,179 @@
+from stylegan import G_synthesis,G_mapping
+from dataclasses import dataclass
+from SphericalOptimizer import SphericalOptimizer
+from pathlib import Path
+import numpy as np
+import time
+import torch
+from loss import LossBuilder
+from functools import partial
+from drive import open_url
+
+
+class PULSE(torch.nn.Module):
+    def __init__(self, cache_dir, verbose=True):
+        super(PULSE, self).__init__()
+
+        self.synthesis = G_synthesis().cuda()
+        self.verbose = verbose
+
+        cache_dir = Path(cache_dir)
+        cache_dir.mkdir(parents=True, exist_ok = True)
+        if self.verbose: print("Loading Synthesis Network")
+        with open_url("https://drive.google.com/uc?id=1TCViX1YpQyRsklTVYEJwdbmK91vklCo8", cache_dir=cache_dir, verbose=verbose) as f:
+            self.synthesis.load_state_dict(torch.load(f))
+
+        for param in self.synthesis.parameters():
+            param.requires_grad = False
+
+        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2)
+
+        if Path("gaussian_fit.pt").exists():
+            self.gaussian_fit = torch.load("gaussian_fit.pt")
+        else:
+            if self.verbose: print("\tLoading Mapping Network")
+            mapping = G_mapping().cuda()
+
+            with open_url("https://drive.google.com/uc?id=14R6iHGf5iuVx3DMNsACAl7eBr7Vdpd0k", cache_dir=cache_dir, verbose=verbose) as f:
+                    mapping.load_state_dict(torch.load(f))
+
+            if self.verbose: print("\tRunning Mapping Network")
+            with torch.no_grad():
+                torch.manual_seed(0)
+                latent = torch.randn((1000000,512),dtype=torch.float32, device="cuda")
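+                # Note: LeakyReLU with negative_slope=5 inverts the LeakyReLU(0.2) applied in forward()
+                # (1/0.2 == 5), so the mean/std below are fitted in that pre-activation space.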
+                latent_out = torch.nn.LeakyReLU(5)(mapping(latent))
+                self.gaussian_fit = {"mean": latent_out.mean(0), "std": latent_out.std(0)}
+                torch.save(self.gaussian_fit,"gaussian_fit.pt")
+                if self.verbose: print("\tSaved \"gaussian_fit.pt\"")
+
+    def forward(self, ref_im,
+                seed,
+                loss_str,
+                eps,
+                noise_type,
+                num_trainable_noise_layers,
+                tile_latent,
+                bad_noise_layers,
+                opt_name,
+                learning_rate,
+                steps,
+                lr_schedule,
+                save_intermediate,
+                **kwargs):
+
+        if seed:
+            torch.manual_seed(seed)
+            torch.cuda.manual_seed(seed)
+            torch.backends.cudnn.deterministic = True
+
+        batch_size = ref_im.shape[0]
+
+        # Generate latent tensor
+        if(tile_latent):
+            latent = torch.randn(
+                (batch_size, 1, 512), dtype=torch.float, requires_grad=True, device='cuda')
+        else:
+            latent = torch.randn(
+                (batch_size, 18, 512), dtype=torch.float, requires_grad=True, device='cuda')
+
+        # Generate list of noise tensors
+        noise = [] # stores all of the noise tensors
+        noise_vars = []  # stores the noise tensors that we want to optimize on
+
+        for i in range(18):
+            # dimension of the ith noise tensor
+            res = (batch_size, 1, 2**(i//2+2), 2**(i//2+2))
+
+            if(noise_type == 'zero' or i in [int(layer) for layer in bad_noise_layers.split('.')]):
+                new_noise = torch.zeros(res, dtype=torch.float, device='cuda')
+                new_noise.requires_grad = False
+            elif(noise_type == 'fixed'):
+                new_noise = torch.randn(res, dtype=torch.float, device='cuda')
+                new_noise.requires_grad = False
+            elif (noise_type == 'trainable'):
+                new_noise = torch.randn(res, dtype=torch.float, device='cuda')
+                if (i < num_trainable_noise_layers):
+                    new_noise.requires_grad = True
+                    noise_vars.append(new_noise)
+                else:
+                    new_noise.requires_grad = False
+            else:
+                raise Exception("unknown noise type")
+
+            noise.append(new_noise)
+
+        var_list = [latent]+noise_vars
+
+        opt_dict = {
+            'sgd': torch.optim.SGD,
+            'adam': torch.optim.Adam,
+            'sgdm': partial(torch.optim.SGD, momentum=0.9),
+            'adamax': torch.optim.Adamax
+        }
+        opt_func = opt_dict[opt_name]
+        opt = SphericalOptimizer(opt_func, var_list, lr=learning_rate)
+
+        schedule_dict = {
+            'fixed': lambda x: 1,
+            'linear1cycle': lambda x: (9*(1-np.abs(x/steps-1/2)*2)+1)/10,
+            'linear1cycledrop': lambda x: (9*(1-np.abs(x/(0.9*steps)-1/2)*2)+1)/10 if x < 0.9*steps else 1/10 + (x-0.9*steps)/(0.1*steps)*(1/1000-1/10),
+        }
+        schedule_func = schedule_dict[lr_schedule]
+        scheduler = torch.optim.lr_scheduler.LambdaLR(opt.opt, schedule_func)
+        
+        loss_builder = LossBuilder(ref_im, loss_str, eps).cuda()
+
+        min_loss = np.inf
+        min_l2 = np.inf
+        best_summary = ""
+        start_t = time.time()
+        gen_im = None
+
+
+        if self.verbose: print("Optimizing")
+        for j in range(steps):
+            opt.opt.zero_grad()
+
+            # Duplicate latent in case tile_latent = True
+            if (tile_latent):
+                latent_in = latent.expand(-1, 18, -1)
+            else:
+                latent_in = latent
+
+            # Apply learned linear mapping to match latent distribution to that of the mapping network
+            latent_in = self.lrelu(latent_in*self.gaussian_fit["std"] + self.gaussian_fit["mean"])
+
+            # Normalize image to [0,1] instead of [-1,1]
+            gen_im = (self.synthesis(latent_in, noise)+1)/2
+
+            # Calculate Losses
+            loss, loss_dict = loss_builder(latent_in, gen_im)
+            loss_dict['TOTAL'] = loss
+
+            # Save best summary for log
+            if(loss < min_loss):
+                min_loss = loss
+                best_summary = f'BEST ({j+1}) | '+' | '.join(
+                [f'{x}: {y:.4f}' for x, y in loss_dict.items()])
+                best_im = gen_im.clone()
+
+            loss_l2 = loss_dict['L2']
+
+            if(loss_l2 < min_l2):
+                min_l2 = loss_l2
+
+            # Save intermediate HR and LR images
+            if(save_intermediate):
+                yield (best_im.cpu().detach().clamp(0, 1),loss_builder.D(best_im).cpu().detach().clamp(0, 1))
+
+            loss.backward()
+            opt.step()
+            scheduler.step()
+
+        total_t = time.time()-start_t
+        current_info = f' | time: {total_t:.1f} | it/s: {(j+1)/total_t:.2f} | batchsize: {batch_size}'
+        if self.verbose: print(best_summary+current_info)
+        if(min_l2 <= eps):
+            yield (gen_im.clone().cpu().detach().clamp(0, 1),loss_builder.D(best_im).cpu().detach().clamp(0, 1))
+        else:
+            print("Could not find a face that downscales correctly within epsilon")

+ 74 - 0
README.md

@@ -0,0 +1,74 @@
+# PULSE: Self-Supervised Photo Upsampling via Latent Space Exploration of Generative Models
+Code accompanying CVPR'20 paper of the same title. Paper link: https://arxiv.org/pdf/2003.03808.pdf
+
+## NOTE
+
+We have noticed a lot of concern that PULSE will be used to identify individuals whose faces have been blurred out. We want to emphasize that this is impossible - **PULSE makes imaginary faces of people who do not exist, which should not be confused for real people.** It will **not** help identify or reconstruct the original image.
+
+We also want to address concerns of bias in PULSE. **We have now included a new section in the [paper](https://arxiv.org/pdf/2003.03808.pdf) and an accompanying model card directly addressing this bias.**
+
+---
+
+![Transformation Preview](./readme_resources/014.jpeg)
+![Transformation Preview](./readme_resources/034.jpeg)
+![Transformation Preview](./readme_resources/094.jpeg)
+
+Table of Contents
+=================
+- [PULSE: Self-Supervised Photo Upsampling via Latent Space Exploration of Generative Models](#pulse-self-supervised-photo-upsampling-via-latent-space-exploration-of-generative-models)
+- [Table of Contents](#table-of-contents)
+  - [What does it do?](#what-does-it-do)
+  - [Usage](#usage)
+    - [Prereqs](#prereqs)
+    - [Data](#data)
+    - [Applying PULSE](#applying-pulse)
+## What does it do? 
+Given a low-resolution input image, PULSE searches the outputs of a generative model (here, [StyleGAN](https://github.com/NVlabs/stylegan)) for high-resolution images that are perceptually realistic and downscale correctly.
+
+![Transformation Preview](./readme_resources/transformation.gif)
+
+## Usage
+
+The main file of interest for applying PULSE is `run.py`. A full list of arguments with descriptions can be found in that file; here we describe those relevant to getting started.
+
+### Prereqs
+
+You will need to install cmake first (required for dlib, which is used for face alignment). Currently the code only works with CUDA installed (and therefore requires an appropriate GPU) and has been tested on Linux and Windows. For the full set of required Python packages, create a Conda environment from the provided YAML, e.g.
+
+```
+conda env create -f pulse.yml
+```
+or (Anaconda on Windows):
+```
+conda env create -n pulse -f pulse.yml
+conda activate pulse
+```
+
+In some environments (e.g. on Windows), you may have to edit `pulse.yml` to remove the version-specific hash on each dependency, and to remove any dependency that still throws an error after running ```conda env create...``` (such as `readline`). That is, change
+```
+dependencies:
+  - blas=1.0=mkl
+  ...
+```
+to
+```
+dependencies:
+  - blas=1.0
+  ...
+```
+
+Finally, you will need an internet connection the first time you run the code, as it will automatically download the relevant pretrained models from Google Drive (if they have already been downloaded, the local copies will be used). In the event that the public Google Drive is out of capacity, add the files to your own Google Drive instead; get the share URL and replace the ID in the https://drive.google.com/uc?id=ID links in ```align_face.py``` and ```PULSE.py``` with the file IDs from the share URLs of your own Drive files.
+ 
+
+### Data
+
+By default, input data for `run.py` should be placed in `./input/` (though this can be modified). However, this assumes faces have already been aligned and downscaled. If you have data that is not already in this form, place it in `realpics` and run `align_face.py`, which will automatically do this for you. (Again, all directories can be changed via command-line arguments if more convenient.) At this stage you will pick a downscaling factor.
+
+Note that if your data already begins at a low resolution, downscaling it further will retain very little information. In this case, you may wish to bicubically upsample (usually to 1024x1024) and allow `align_face.py` to downscale for you.
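+
+For example, to align the faces in `realpics` and write 32x32 aligned inputs to `input` (a minimal sketch using the flags defined in `align_face.py`):
+```
+python align_face.py -input_dir realpics -output_dir input -output_size 32
+```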
+
+### Applying PULSE
+Once your data is appropriately formatted, all you need to do is
+```
+python run.py
+```
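+
+Additional flags are defined in `run.py`; for example, a sketch that produces 5 HR candidates per input image and saves intermediate outputs:
+```
+python run.py -input_dir input -output_dir runs -duplicates 5 -save_intermediate
+```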
+Enjoy!

+ 26 - 0
SphericalOptimizer.py

@@ -0,0 +1,26 @@
+import math
+import torch
+from torch.optim import Optimizer
+
+# Spherical Optimizer Class
+# Uses the first two dimensions as batch information
+# Optimizes over the surface of a sphere using the initial radius throughout
+#
+# Example Usage:
+# opt = SphericalOptimizer(torch.optim.SGD, [x], lr=0.01)
+
+class SphericalOptimizer(Optimizer):
+    def __init__(self, optimizer, params, **kwargs):
+        self.opt = optimizer(params, **kwargs)
+        self.params = params
+        with torch.no_grad():
+            self.radii = {param: (param.pow(2).sum(tuple(range(2,param.ndim)),keepdim=True)+1e-9).sqrt() for param in params}
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        loss = self.opt.step(closure)
+        for param in self.params:
+            param.data.div_((param.pow(2).sum(tuple(range(2,param.ndim)),keepdim=True)+1e-9).sqrt())
+            param.mul_(self.radii[param])
+
+        return loss

+ 49 - 0
align_face.py

@@ -0,0 +1,49 @@
+import numpy as np
+import PIL
+import PIL.Image
+import sys
+import os
+import glob
+import scipy
+import scipy.ndimage
+import dlib
+from drive import open_url
+from pathlib import Path
+import argparse
+from bicubic import BicubicDownSample
+import torchvision
+from shape_predictor import align_face
+
+parser = argparse.ArgumentParser(description='PULSE')
+
+parser.add_argument('-input_dir', type=str, default='realpics', help='directory with unprocessed images')
+parser.add_argument('-output_dir', type=str, default='input', help='output directory')
+parser.add_argument('-output_size', type=int, default=32, help='size to downscale the input images to, must be power of 2')
+parser.add_argument('-seed', type=int, help='manual seed to use')
+parser.add_argument('-cache_dir', type=str, default='cache', help='cache directory for model weights')
+
+args = parser.parse_args()
+
+cache_dir = Path(args.cache_dir)
+cache_dir.mkdir(parents=True, exist_ok=True)
+
+output_dir = Path(args.output_dir)
+output_dir.mkdir(parents=True,exist_ok=True)
+
+print("Downloading Shape Predictor")
+f=open_url("https://drive.google.com/uc?id=1huhv8PYpNNKbGCLOaYUjOgR1pY5pmbJx", cache_dir=cache_dir, return_path=True)
+predictor = dlib.shape_predictor(f)
+
+for im in Path(args.input_dir).glob("*.*"):
+    faces = align_face(str(im),predictor)
+
+    for i,face in enumerate(faces):
+        if(args.output_size):
+            factor = 1024//args.output_size
+            assert args.output_size*factor == 1024
+            D = BicubicDownSample(factor=factor)
+            face_tensor = torchvision.transforms.ToTensor()(face).unsqueeze(0).cuda()
+            face_tensor_lr = D(face_tensor)[0].cpu().detach().clamp(0, 1)
+            face = torchvision.transforms.ToPILImage()(face_tensor_lr)
+
+        face.save(Path(args.output_dir) / (im.stem+f"_{i}.png"))

+ 75 - 0
bicubic.py

@@ -0,0 +1,75 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class BicubicDownSample(nn.Module):
+    def bicubic_kernel(self, x, a=-0.50):
+        """
+        This equation is exactly copied from the website below:
+        https://clouard.users.greyc.fr/Pantheon/experiments/rescaling/index-en.html#bicubic
+        """
+        abs_x = torch.abs(x)
+        if abs_x <= 1.:
+            return (a + 2.) * torch.pow(abs_x, 3.) - (a + 3.) * torch.pow(abs_x, 2.) + 1
+        elif 1. < abs_x < 2.:
+            return a * torch.pow(abs_x, 3) - 5. * a * torch.pow(abs_x, 2.) + 8. * a * abs_x - 4. * a
+        else:
+            return 0.0
+
+    def __init__(self, factor=4, cuda=True, padding='reflect'):
+        super().__init__()
+        self.factor = factor
+        size = factor * 4
+        k = torch.tensor([self.bicubic_kernel((i - torch.floor(torch.tensor(size / 2)) + 0.5) / factor)
+                          for i in range(size)], dtype=torch.float32)
+        k = k / torch.sum(k)
+        # k = torch.einsum('i,j->ij', (k, k))
+        k1 = torch.reshape(k, shape=(1, 1, size, 1))
+        self.k1 = torch.cat([k1, k1, k1], dim=0)
+        k2 = torch.reshape(k, shape=(1, 1, 1, size))
+        self.k2 = torch.cat([k2, k2, k2], dim=0)
+        self.cuda = '.cuda' if cuda else ''
+        self.padding = padding
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, x, nhwc=False, clip_round=False, byte_output=False):
+        # x = torch.from_numpy(x).type('torch.FloatTensor')
+        filter_height = self.factor * 4
+        filter_width = self.factor * 4
+        stride = self.factor
+
+        pad_along_height = max(filter_height - stride, 0)
+        pad_along_width = max(filter_width - stride, 0)
+        filters1 = self.k1.type('torch{}.FloatTensor'.format(self.cuda))
+        filters2 = self.k2.type('torch{}.FloatTensor'.format(self.cuda))
+
+        # compute actual padding values for each side
+        pad_top = pad_along_height // 2
+        pad_bottom = pad_along_height - pad_top
+        pad_left = pad_along_width // 2
+        pad_right = pad_along_width - pad_left
+
+        # apply mirror padding
+        if nhwc:
+            x = torch.transpose(torch.transpose(
+                x, 2, 3), 1, 2)   # NHWC to NCHW
+
+        # downscaling performed by 1-d convolution
+        x = F.pad(x, (0, 0, pad_top, pad_bottom), self.padding)
+        x = F.conv2d(input=x, weight=filters1, stride=(stride, 1), groups=3)
+        if clip_round:
+            x = torch.clamp(torch.round(x), 0.0, 255.)
+
+        x = F.pad(x, (pad_left, pad_right, 0, 0), self.padding)
+        x = F.conv2d(input=x, weight=filters2, stride=(1, stride), groups=3)
+        if clip_round:
+            x = torch.clamp(torch.round(x), 0.0, 255.)
+
+        if nhwc:
+            x = torch.transpose(torch.transpose(x, 1, 3), 1, 2)
+        if byte_output:
+            return x.type('torch{}.ByteTensor'.format(self.cuda))
+        else:
+            return x
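+
+
+# Minimal usage sketch (assumes `hr` is an NCHW float tensor with 3 channels on the GPU):
+#   D = BicubicDownSample(factor=32)
+#   lr = D(hr)  # e.g. (N, 3, 1024, 1024) -> (N, 3, 32, 32)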

+ 94 - 0
drive.py

@@ -0,0 +1,94 @@
+# URL helpers, see https://github.com/NVlabs/stylegan
+# ------------------------------------------------------------------------------------------
+
+import requests
+import html
+import hashlib
+import glob
+import os
+import io
+from typing import Any
+import re
+import uuid
+
+def is_url(obj: Any) -> bool:
+    """Determine whether the given object is a valid URL string."""
+    if not isinstance(obj, str) or not "://" in obj:
+        return False
+    try:
+        res = requests.compat.urlparse(obj)
+        if not res.scheme or not res.netloc or not "." in res.netloc:
+            return False
+        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
+        if not res.scheme or not res.netloc or not "." in res.netloc:
+            return False
+    except:
+        return False
+    return True
+
+
+def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_path: bool = False) -> Any:
+    """Download the given URL and return a binary-mode file object to access the data."""
+    assert is_url(url)
+    assert num_attempts >= 1
+
+    # Lookup from cache.
+    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
+    if cache_dir is not None:
+        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
+        if len(cache_files) == 1:
+            if(return_path):
+                return cache_files[0]
+            else:
+                return open(cache_files[0], "rb")
+
+    # Download.
+    url_name = None
+    url_data = None
+    with requests.Session() as session:
+        if verbose:
+            print("Downloading %s ..." % url, end="", flush=True)
+        for attempts_left in reversed(range(num_attempts)):
+            try:
+                with session.get(url) as res:
+                    res.raise_for_status()
+                    if len(res.content) == 0:
+                        raise IOError("No data received")
+
+                    if len(res.content) < 8192:
+                        content_str = res.content.decode("utf-8")
+                        if "download_warning" in res.headers.get("Set-Cookie", ""):
+                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
+                            if len(links) == 1:
+                                url = requests.compat.urljoin(url, links[0])
+                                raise IOError("Google Drive virus checker nag")
+                        if "Google Drive - Quota exceeded" in content_str:
+                            raise IOError("Google Drive quota exceeded")
+
+                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
+                    url_name = match[1] if match else url
+                    url_data = res.content
+                    if verbose:
+                        print(" done")
+                    break
+            except:
+                if not attempts_left:
+                    if verbose:
+                        print(" failed")
+                    raise
+                if verbose:
+                    print(".", end="", flush=True)
+
+    # Save to cache.
+    if cache_dir is not None:
+        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
+        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
+        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
+        os.makedirs(cache_dir, exist_ok=True)
+        with open(temp_file, "wb") as f:
+            f.write(url_data)
+        os.replace(temp_file, cache_file) # atomic
+        if(return_path): return cache_file
+
+    # Return data as file object.
+    return io.BytesIO(url_data)
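+
+
+# Minimal usage sketch (mirrors the call in PULSE.py):
+#   with open_url("https://drive.google.com/uc?id=...", cache_dir="cache", verbose=True) as f:
+#       state_dict = torch.load(f)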

BIN
gaussian_fit.pt


+ 57 - 0
loss.py

@@ -0,0 +1,57 @@
+import torch
+from bicubic import BicubicDownSample
+
+class LossBuilder(torch.nn.Module):
+    def __init__(self, ref_im, loss_str, eps):
+        super(LossBuilder, self).__init__()
+        assert ref_im.shape[2]==ref_im.shape[3]
+        im_size = ref_im.shape[2]
+        factor=1024//im_size
+        assert im_size*factor==1024
+        self.D = BicubicDownSample(factor=factor)
+        self.ref_im = ref_im
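+        # Parse a loss string such as the default "100*L2+0.05*GEOCROSS" into
+        # weight/name pairs, e.g. [['100', 'L2'], ['0.05', 'GEOCROSS']]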
+        self.parsed_loss = [loss_term.split('*') for loss_term in loss_str.split('+')]
+        self.eps = eps
+
+    # Takes a list of tensors, flattens them, and concatenates them into a vector
+    # Used to calculate euclidian distance between lists of tensors
+    def flatcat(self, l):
+        l = l if(isinstance(l, list)) else [l]
+        return torch.cat([x.flatten() for x in l], dim=0)
+
+    def _loss_l2(self, gen_im_lr, ref_im, **kwargs):
+        return ((gen_im_lr - ref_im).pow(2).mean((1, 2, 3)).clamp(min=self.eps).sum())
+
+    def _loss_l1(self, gen_im_lr, ref_im, **kwargs):
+        return 10*((gen_im_lr - ref_im).abs().mean((1, 2, 3)).clamp(min=self.eps).sum())
+
+    # Uses geodesic distance on sphere to sum pairwise distances of the 18 vectors
+    def _loss_geocross(self, latent, **kwargs):
+        if(latent.shape[1] == 1):
+            return 0
+        else:
+            X = latent.view(-1, 1, 18, 512)
+            Y = latent.view(-1, 18, 1, 512)
+            A = ((X-Y).pow(2).sum(-1)+1e-9).sqrt()
+            B = ((X+Y).pow(2).sum(-1)+1e-9).sqrt()
+            D = 2*torch.atan2(A, B)
+            D = ((D.pow(2)*512).mean((1, 2))/8.).sum()
+            return D
+
+    def forward(self, latent, gen_im):
+        var_dict = {'latent': latent,
+                    'gen_im_lr': self.D(gen_im),
+                    'ref_im': self.ref_im,
+                    }
+        loss = 0
+        loss_fun_dict = {
+            'L2': self._loss_l2,
+            'L1': self._loss_l1,
+            'GEOCROSS': self._loss_geocross,
+        }
+        losses = {}
+        for weight, loss_type in self.parsed_loss:
+            tmp_loss = loss_fun_dict[loss_type](**var_dict)
+            losses[loss_type] = tmp_loss
+            loss += float(weight)*tmp_loss
+        return loss, losses

+ 63 - 0
pulse.yml

@@ -0,0 +1,63 @@
+name: pulse
+channels:
+  - pytorch
+  - defaults
+dependencies:
+  - blas=1.0=mkl
+  - ca-certificates=2020.1.1=0
+  - certifi=2020.4.5.1=py38_0
+  - cffi=1.14.0=py38hc512035_1
+  - chardet=3.0.4=py38_1003
+  - cryptography=2.9.2=py38ha12b0ac_0
+  - cycler=0.10.0=py38_0
+  - freetype=2.9.1=hb4e5f40_0
+  - idna=2.9=py_1
+  - intel-openmp=2019.4=233
+  - jpeg=9b=he5867d9_2
+  - kiwisolver=1.2.0=py38h04f5b5a_0
+  - libcxx=10.0.0=1
+  - libedit=3.1.20181209=hb402a30_0
+  - libffi=3.3=h0a44026_1
+  - libgfortran=3.0.1=h93005f0_2
+  - libpng=1.6.37=ha441bb4_0
+  - libtiff=4.1.0=hcb84e12_0
+  - matplotlib=3.1.3=py38_0
+  - matplotlib-base=3.1.3=py38h9aa3819_0
+  - mkl=2019.4=233
+  - mkl-service=2.3.0=py38hfbe908c_0
+  - mkl_fft=1.0.15=py38h5e564d8_0
+  - mkl_random=1.1.0=py38h6440ff4_0
+  - ncurses=6.2=h0a44026_1
+  - ninja=1.9.0=py38h04f5b5a_0
+  - numpy=1.18.1=py38h7241aed_0
+  - numpy-base=1.18.1=py38h6575580_1
+  - olefile=0.46=py_0
+  - openssl=1.1.1g=h1de35cc_0
+  - pandas=1.0.3=py38h6c726b0_0
+  - pillow=7.1.2=py38h4655f20_0
+  - pip=20.0.2=py38_3
+  - pycparser=2.20=py_0
+  - pyopenssl=19.1.0=py38_0
+  - pyparsing=2.4.7=py_0
+  - pysocks=1.7.1=py38_0
+  - python=3.8.2=hf48f09d_13
+  - python-dateutil=2.8.1=py_0
+  - pytorch=1.5.0=py3.8_0
+  - pytz=2020.1=py_0
+  - readline=8.0=h1de35cc_0
+  - requests=2.23.0=py38_0
+  - scipy=1.4.1=py38h44e99c9_0
+  - setuptools=46.2.0=py38_0
+  - six=1.14.0=py38_0
+  - sqlite=3.31.1=h5c1f38d_1
+  - tk=8.6.8=ha441bb4_0
+  - torchvision=0.6.0=py38_cpu
+  - tornado=6.0.4=py38h1de35cc_1
+  - urllib3=1.25.8=py38_0
+  - wheel=0.34.2=py38_0
+  - xz=5.2.5=h1de35cc_0
+  - zlib=1.2.11=h1de35cc_3
+  - zstd=1.3.7=h5bba6e5_0
+  - pip:
+    - dlib==19.19.0
+prefix: /Users/sachit/opt/miniconda3/envs/pulse

+ 82 - 0
run.py

@@ -0,0 +1,82 @@
+from PULSE import PULSE
+from torch.utils.data import Dataset, DataLoader
+from torch.nn import DataParallel
+from pathlib import Path
+from PIL import Image
+import torchvision
+from math import log10, ceil
+import argparse
+
+class Images(Dataset):
+    def __init__(self, root_dir, duplicates):
+        self.root_path = Path(root_dir)
+        self.image_list = list(self.root_path.glob("*.png"))
+        self.duplicates = duplicates # Number of times to duplicate the image in the dataset to produce multiple HR images
+
+    def __len__(self):
+        return self.duplicates*len(self.image_list)
+
+    def __getitem__(self, idx):
+        img_path = self.image_list[idx//self.duplicates]
+        image = torchvision.transforms.ToTensor()(Image.open(img_path))
+        if(self.duplicates == 1):
+            return image,img_path.stem
+        else:
+            return image,img_path.stem+f"_{(idx % self.duplicates)+1}"
+
+parser = argparse.ArgumentParser(description='PULSE')
+
+#I/O arguments
+parser.add_argument('-input_dir', type=str, default='input', help='input data directory')
+parser.add_argument('-output_dir', type=str, default='runs', help='output data directory')
+parser.add_argument('-cache_dir', type=str, default='cache', help='cache directory for model weights')
+parser.add_argument('-duplicates', type=int, default=1, help='How many HR images to produce for every image in the input directory')
+parser.add_argument('-batch_size', type=int, default=1, help='Batch size to use during optimization')
+
+#PULSE arguments
+parser.add_argument('-seed', type=int, help='manual seed to use')
+parser.add_argument('-loss_str', type=str, default="100*L2+0.05*GEOCROSS", help='Loss function to use')
+parser.add_argument('-eps', type=float, default=2e-3, help='Target for downscaling loss (L2)')
+parser.add_argument('-noise_type', type=str, default='trainable', help='zero, fixed, or trainable')
+parser.add_argument('-num_trainable_noise_layers', type=int, default=5, help='Number of noise layers to optimize')
+parser.add_argument('-tile_latent', action='store_true', help='Whether to forcibly tile the same latent 18 times')
+parser.add_argument('-bad_noise_layers', type=str, default="17", help='List of noise layers to zero out to improve image quality')
+parser.add_argument('-opt_name', type=str, default='adam', help='Optimizer to use in projected gradient descent')
+parser.add_argument('-learning_rate', type=float, default=0.4, help='Learning rate to use during optimization')
+parser.add_argument('-steps', type=int, default=100, help='Number of optimization steps')
+parser.add_argument('-lr_schedule', type=str, default='linear1cycledrop', help='fixed, linear1cycledrop, linear1cycle')
+parser.add_argument('-save_intermediate', action='store_true', help='Whether to store and save intermediate HR and LR images during optimization')
+
+kwargs = vars(parser.parse_args())
+
+dataset = Images(kwargs["input_dir"], duplicates=kwargs["duplicates"])
+out_path = Path(kwargs["output_dir"])
+out_path.mkdir(parents=True, exist_ok=True)
+
+dataloader = DataLoader(dataset, batch_size=kwargs["batch_size"])
+
+model = PULSE(cache_dir=kwargs["cache_dir"])
+model = DataParallel(model)
+
+toPIL = torchvision.transforms.ToPILImage()
+
+for ref_im, ref_im_name in dataloader:
+    if(kwargs["save_intermediate"]):
+        padding = ceil(log10(100))
+        for i in range(kwargs["batch_size"]):
+            int_path_HR = Path(out_path / ref_im_name[i] / "HR")
+            int_path_LR = Path(out_path / ref_im_name[i] / "LR")
+            int_path_HR.mkdir(parents=True, exist_ok=True)
+            int_path_LR.mkdir(parents=True, exist_ok=True)
+        for j,(HR,LR) in enumerate(model(ref_im,**kwargs)):
+            for i in range(kwargs["batch_size"]):
+                toPIL(HR[i].cpu().detach().clamp(0, 1)).save(
+                    int_path_HR / f"{ref_im_name[i]}_{j:0{padding}}.png")
+                toPIL(LR[i].cpu().detach().clamp(0, 1)).save(
+                    int_path_LR / f"{ref_im_name[i]}_{j:0{padding}}.png")
+    else:
+        #out_im = model(ref_im,**kwargs)
+        for j,(HR,LR) in enumerate(model(ref_im,**kwargs)):
+            for i in range(kwargs["batch_size"]):
+                toPIL(HR[i].cpu().detach().clamp(0, 1)).save(
+                    out_path / f"{ref_im_name[i]}.png")

+ 138 - 0
shape_predictor.py

@@ -0,0 +1,138 @@
+import numpy as np
+import PIL
+import PIL.Image
+import sys
+import os
+import glob
+import scipy
+import scipy.ndimage
+import dlib
+from drive import open_url
+from pathlib import Path
+import argparse
+from bicubic import BicubicDownSample
+import torchvision
+
+"""
+brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)
+author: lzhbrian (https://lzhbrian.me)
+date: 2020.1.5
+note: code is heavily borrowed from
+    https://github.com/NVlabs/ffhq-dataset
+    http://dlib.net/face_landmark_detection.py.html
+
+requirements:
+    apt install cmake
+    conda install Pillow numpy scipy
+    pip install dlib
+    # download face landmark model from:
+    # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
+"""
+
+def get_landmark(filepath,predictor):
+    """get landmark with dlib
+    :return: np.array shape=(68, 2)
+    """
+    detector = dlib.get_frontal_face_detector()
+
+    img = dlib.load_rgb_image(filepath)
+    dets = detector(img, 1)
+    filepath = Path(filepath)
+    print(f"{filepath.name}: Number of faces detected: {len(dets)}")
+    shapes = [predictor(img, d) for k, d in enumerate(dets)]
+
+    lms = [np.array([[tt.x, tt.y] for tt in shape.parts()]) for shape in shapes]
+
+    return lms
+
+
+def align_face(filepath,predictor):
+    """
+    :param filepath: str
+    :return: list of PIL Images
+    """
+
+    lms = get_landmark(filepath,predictor)
+    imgs = []
+    for lm in lms:
+        lm_chin = lm[0: 17]  # left-right
+        lm_eyebrow_left = lm[17: 22]  # left-right
+        lm_eyebrow_right = lm[22: 27]  # left-right
+        lm_nose = lm[27: 31]  # top-down
+        lm_nostrils = lm[31: 36]  # top-down
+        lm_eye_left = lm[36: 42]  # left-clockwise
+        lm_eye_right = lm[42: 48]  # left-clockwise
+        lm_mouth_outer = lm[48: 60]  # left-clockwise
+        lm_mouth_inner = lm[60: 68]  # left-clockwise
+
+        # Calculate auxiliary vectors.
+        eye_left = np.mean(lm_eye_left, axis=0)
+        eye_right = np.mean(lm_eye_right, axis=0)
+        eye_avg = (eye_left + eye_right) * 0.5
+        eye_to_eye = eye_right - eye_left
+        mouth_left = lm_mouth_outer[0]
+        mouth_right = lm_mouth_outer[6]
+        mouth_avg = (mouth_left + mouth_right) * 0.5
+        eye_to_mouth = mouth_avg - eye_avg
+
+        # Choose oriented crop rectangle.
+        x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+        x /= np.hypot(*x)
+        x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
+        y = np.flipud(x) * [-1, 1]
+        c = eye_avg + eye_to_mouth * 0.1
+        quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+        qsize = np.hypot(*x) * 2
+
+        # read image
+        img = PIL.Image.open(filepath)
+
+        output_size = 1024
+        transform_size = 4096
+        enable_padding = True
+
+        # Shrink.
+        shrink = int(np.floor(qsize / output_size * 0.5))
+        if shrink > 1:
+            rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
+            img = img.resize(rsize, PIL.Image.ANTIALIAS)
+            quad /= shrink
+            qsize /= shrink
+
+        # Crop.
+        border = max(int(np.rint(qsize * 0.1)), 3)
+        crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+                int(np.ceil(max(quad[:, 1]))))
+        crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
+                min(crop[3] + border, img.size[1]))
+        if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
+            img = img.crop(crop)
+            quad -= crop[0:2]
+
+        # Pad.
+        pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+               int(np.ceil(max(quad[:, 1]))))
+        pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
+               max(pad[3] - img.size[1] + border, 0))
+        if enable_padding and max(pad) > border - 4:
+            pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+            img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+            h, w, _ = img.shape
+            y, x, _ = np.ogrid[:h, :w, :1]
+            mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
+                              1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
+            blur = qsize * 0.02
+            img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+            img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
+            img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
+            quad += pad[:2]
+
+        # Transform.
+        img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(),
+                            PIL.Image.BILINEAR)
+        if output_size < transform_size:
+            img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
+
+        # Save aligned image.
+        imgs.append(img)
+    return imgs
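+
+
+# Minimal usage sketch (see align_face.py for the full pipeline; the predictor path is illustrative):
+#   predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
+#   faces = align_face("photo.jpg", predictor)  # list of aligned 1024x1024 PIL Images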

+ 410 - 0
stylegan.py

@@ -0,0 +1,410 @@
+#Modified from https://github.com/lernapparat/lernapparat/
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from collections import OrderedDict
+import pickle
+
+import numpy as np
+
+
+class MyLinear(nn.Module):
+    """Linear layer with equalized learning rate and custom learning rate multiplier."""
+    
+    def __init__(self, input_size, output_size, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True):
+        super().__init__()
+        he_std = gain * input_size**(-0.5)  # He init
+        # Equalized learning rate and custom learning rate multiplier.
+        if use_wscale:
+            init_std = 1.0 / lrmul
+            self.w_mul = he_std * lrmul
+        else:
+            init_std = he_std / lrmul
+            self.w_mul = lrmul
+        self.weight = torch.nn.Parameter(
+            torch.randn(output_size, input_size) * init_std)
+        if bias:
+            self.bias = torch.nn.Parameter(torch.zeros(output_size))
+            self.b_mul = lrmul
+        else:
+            self.bias = None
+
+    def forward(self, x):
+        bias = self.bias
+        if bias is not None:
+            bias = bias * self.b_mul
+        return F.linear(x, self.weight * self.w_mul, bias)
+
+
+class MyConv2d(nn.Module):
+    """Conv layer with equalized learning rate and custom learning rate multiplier."""
+
+    def __init__(self, input_channels, output_channels, kernel_size, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True,
+                 intermediate=None, upscale=False):
+        super().__init__()
+        if upscale:
+            self.upscale = Upscale2d()
+        else:
+            self.upscale = None
+        he_std = gain * (input_channels * kernel_size **
+                         2) ** (-0.5)  # He init
+        self.kernel_size = kernel_size
+        if use_wscale:
+            init_std = 1.0 / lrmul
+            self.w_mul = he_std * lrmul
+        else:
+            init_std = he_std / lrmul
+            self.w_mul = lrmul
+        self.weight = torch.nn.Parameter(torch.randn(
+            output_channels, input_channels, kernel_size, kernel_size) * init_std)
+        if bias:
+            self.bias = torch.nn.Parameter(torch.zeros(output_channels))
+            self.b_mul = lrmul
+        else:
+            self.bias = None
+        self.intermediate = intermediate
+
+    def forward(self, x):
+        bias = self.bias
+        if bias is not None:
+            bias = bias * self.b_mul
+
+        have_convolution = False
+        if self.upscale is not None and min(x.shape[2:]) * 2 >= 128:
+            # this is the fused upscale + conv from StyleGAN, sadly this seems incompatible with the non-fused way
+            # this really needs to be cleaned up and go into the conv...
+            w = self.weight * self.w_mul
+            w = w.permute(1, 0, 2, 3)
+            # probably applying a conv on w would be more efficient. also this quadruples the weight (average)?!
+            w = F.pad(w, (1, 1, 1, 1))
+            w = w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + \
+                w[:, :, 1:, :-1] + w[:, :, :-1, :-1]
+            x = F.conv_transpose2d(
+                x, w, stride=2, padding=int((w.size(-1)-1)//2))
+            have_convolution = True
+        elif self.upscale is not None:
+            x = self.upscale(x)
+
+        if not have_convolution and self.intermediate is None:
+            return F.conv2d(x, self.weight * self.w_mul, bias, padding=int(self.kernel_size//2))
+        elif not have_convolution:
+            x = F.conv2d(x, self.weight * self.w_mul, None,
+                         padding=int(self.kernel_size//2))
+
+        if self.intermediate is not None:
+            x = self.intermediate(x)
+        if bias is not None:
+            x = x + bias.view(1, -1, 1, 1)
+        return x
+
+
+class NoiseLayer(nn.Module):
+    """adds noise. noise is per pixel (constant over channels) with per-channel weight"""
+
+    def __init__(self, channels):
+        super().__init__()
+        self.weight = nn.Parameter(torch.zeros(channels))
+        self.noise = None
+
+    def forward(self, x, noise=None):
+        if noise is None and self.noise is None:
+            noise = torch.randn(x.size(0), 1, x.size(
+                2), x.size(3), device=x.device, dtype=x.dtype)
+        elif noise is None:
+            # here is a little trick: if you get all the noiselayers and set each
+            # modules .noise attribute, you can have pre-defined noise.
+            # Very useful for analysis
+            noise = self.noise
+        x = x + self.weight.view(1, -1, 1, 1) * noise
+        return x
+
+
+class StyleMod(nn.Module):
+    def __init__(self, latent_size, channels, use_wscale):
+        super(StyleMod, self).__init__()
+        self.lin = MyLinear(latent_size,
+                            channels * 2,
+                            gain=1.0, use_wscale=use_wscale)
+
+    def forward(self, x, latent):
+        style = self.lin(latent)  # style => [batch_size, n_channels*2]
+        shape = [-1, 2, x.size(1)] + (x.dim() - 2) * [1]
+        style = style.view(shape)  # [batch_size, 2, n_channels, ...]
+        x = x * (style[:, 0] + 1.) + style[:, 1]
+        return x
+
+
+class PixelNormLayer(nn.Module):
+    def __init__(self, epsilon=1e-8):
+        super().__init__()
+        self.epsilon = epsilon
+
+    def forward(self, x):
+        return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + self.epsilon)
+
+
+class BlurLayer(nn.Module):
+    def __init__(self, kernel=[1, 2, 1], normalize=True, flip=False, stride=1):
+        super(BlurLayer, self).__init__()
+        kernel = [1, 2, 1]
+        kernel = torch.tensor(kernel, dtype=torch.float32)
+        kernel = kernel[:, None] * kernel[None, :]
+        kernel = kernel[None, None]
+        if normalize:
+            kernel = kernel / kernel.sum()
+        if flip:
+            kernel = kernel[:, :, ::-1, ::-1]
+        self.register_buffer('kernel', kernel)
+        self.stride = stride
+
+    def forward(self, x):
+        # expand kernel channels
+        kernel = self.kernel.expand(x.size(1), -1, -1, -1)
+        x = F.conv2d(
+            x,
+            kernel,
+            stride=self.stride,
+            padding=int((self.kernel.size(2)-1)/2),
+            groups=x.size(1)
+        )
+        return x
+
+
+def upscale2d(x, factor=2, gain=1):
+    assert x.dim() == 4
+    if gain != 1:
+        x = x * gain
+    if factor != 1:
+        shape = x.shape
+        x = x.view(shape[0], shape[1], shape[2], 1, shape[3],
+                   1).expand(-1, -1, -1, factor, -1, factor)
+        x = x.contiguous().view(
+            shape[0], shape[1], factor * shape[2], factor * shape[3])
+    return x
+
+
+class Upscale2d(nn.Module):
+    def __init__(self, factor=2, gain=1):
+        super().__init__()
+        assert isinstance(factor, int) and factor >= 1
+        self.gain = gain
+        self.factor = factor
+
+    def forward(self, x):
+        return upscale2d(x, factor=self.factor, gain=self.gain)
+
+
+class G_mapping(nn.Sequential):
+    def __init__(self, nonlinearity='lrelu', use_wscale=True):
+        act, gain = {'relu': (torch.relu, np.sqrt(2)),
+                     'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
+        layers = [
+            ('pixel_norm', PixelNormLayer()),
+            ('dense0', MyLinear(512, 512, gain=gain,
+                                lrmul=0.01, use_wscale=use_wscale)),
+            ('dense0_act', act),
+            ('dense1', MyLinear(512, 512, gain=gain,
+                                lrmul=0.01, use_wscale=use_wscale)),
+            ('dense1_act', act),
+            ('dense2', MyLinear(512, 512, gain=gain,
+                                lrmul=0.01, use_wscale=use_wscale)),
+            ('dense2_act', act),
+            ('dense3', MyLinear(512, 512, gain=gain,
+                                lrmul=0.01, use_wscale=use_wscale)),
+            ('dense3_act', act),
+            ('dense4', MyLinear(512, 512, gain=gain,
+                                lrmul=0.01, use_wscale=use_wscale)),
+            ('dense4_act', act),
+            ('dense5', MyLinear(512, 512, gain=gain,
+                                lrmul=0.01, use_wscale=use_wscale)),
+            ('dense5_act', act),
+            ('dense6', MyLinear(512, 512, gain=gain,
+                                lrmul=0.01, use_wscale=use_wscale)),
+            ('dense6_act', act),
+            ('dense7', MyLinear(512, 512, gain=gain,
+                                lrmul=0.01, use_wscale=use_wscale)),
+            ('dense7_act', act)
+        ]
+        super().__init__(OrderedDict(layers))
+
+    def forward(self, x):
+        x = super().forward(x)
+        return x
+
+
+class Truncation(nn.Module):
+    def __init__(self, avg_latent, max_layer=8, threshold=0.7):
+        super().__init__()
+        self.max_layer = max_layer
+        self.threshold = threshold
+        self.register_buffer('avg_latent', avg_latent)
+
+    def forward(self, x):
+        assert x.dim() == 3
+        interp = torch.lerp(self.avg_latent, x, self.threshold)
+        do_trunc = (torch.arange(x.size(1)) < self.max_layer).view(1, -1, 1)
+        return torch.where(do_trunc, interp, x)
+
+
+class LayerEpilogue(nn.Module):
+    """Things to do at the end of each layer."""
+
+    def __init__(self, channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
+        super().__init__()
+        layers = []
+        if use_noise:
+            self.noise = NoiseLayer(channels)
+        else:
+            self.noise = None
+        layers.append(('activation', activation_layer))
+        if use_pixel_norm:
+            layers.append(('pixel_norm', PixelNormLayer()))
+        if use_instance_norm:
+            layers.append(('instance_norm', nn.InstanceNorm2d(channels)))
+
+        self.top_epi = nn.Sequential(OrderedDict(layers))
+        if use_styles:
+            self.style_mod = StyleMod(
+                dlatent_size, channels, use_wscale=use_wscale)
+        else:
+            self.style_mod = None
+
+    def forward(self, x, dlatents_in_slice=None, noise_in_slice=None):
+        if(self.noise is not None):
+            x = self.noise(x, noise=noise_in_slice)
+        x = self.top_epi(x)
+        if self.style_mod is not None:
+            x = self.style_mod(x, dlatents_in_slice)
+        else:
+            assert dlatents_in_slice is None
+        return x
+
+
+class InputBlock(nn.Module):
+    def __init__(self, nf, dlatent_size, const_input_layer, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
+        super().__init__()
+        self.const_input_layer = const_input_layer
+        self.nf = nf
+        if self.const_input_layer:
+            # called 'const' in tf
+            self.const = nn.Parameter(torch.ones(1, nf, 4, 4))
+            self.bias = nn.Parameter(torch.ones(nf))
+        else:
+            # tweak gain to match the official implementation of Progressive GAN
+            self.dense = MyLinear(dlatent_size, nf*16,
+                                  gain=gain/4, use_wscale=use_wscale)
+        self.epi1 = LayerEpilogue(nf, dlatent_size, use_wscale, use_noise,
+                                  use_pixel_norm, use_instance_norm, use_styles, activation_layer)
+        self.conv = MyConv2d(nf, nf, 3, gain=gain, use_wscale=use_wscale)
+        self.epi2 = LayerEpilogue(nf, dlatent_size, use_wscale, use_noise,
+                                  use_pixel_norm, use_instance_norm, use_styles, activation_layer)
+
+    def forward(self, dlatents_in_range, noise_in_range):
+        batch_size = dlatents_in_range.size(0)
+        if self.const_input_layer:
+            x = self.const.expand(batch_size, -1, -1, -1)
+            x = x + self.bias.view(1, -1, 1, 1)
+        else:
+            x = self.dense(dlatents_in_range[:, 0]).view(
+                batch_size, self.nf, 4, 4)
+        x = self.epi1(x, dlatents_in_range[:, 0], noise_in_range[0])
+        x = self.conv(x)
+        x = self.epi2(x, dlatents_in_range[:, 1], noise_in_range[1])
+        return x
+
+
+class GSynthesisBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, blur_filter, dlatent_size, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
+        # 2**res x 2**res # res = 3..resolution_log2
+        super().__init__()
+        if blur_filter:
+            blur = BlurLayer(blur_filter)
+        else:
+            blur = None
+        self.conv0_up = MyConv2d(in_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale,
+                                 intermediate=blur, upscale=True)
+        self.epi1 = LayerEpilogue(out_channels, dlatent_size, use_wscale, use_noise,
+                                  use_pixel_norm, use_instance_norm, use_styles, activation_layer)
+        self.conv1 = MyConv2d(out_channels, out_channels,
+                              kernel_size=3, gain=gain, use_wscale=use_wscale)
+        self.epi2 = LayerEpilogue(out_channels, dlatent_size, use_wscale, use_noise,
+                                  use_pixel_norm, use_instance_norm, use_styles, activation_layer)
+
+    def forward(self, x, dlatents_in_range, noise_in_range):
+        x = self.conv0_up(x)
+        x = self.epi1(x, dlatents_in_range[:, 0], noise_in_range[0])
+        x = self.conv1(x)
+        x = self.epi2(x, dlatents_in_range[:, 1], noise_in_range[1])
+        return x
+
+
+class G_synthesis(nn.Module):
+    def __init__(self,
+                 # Disentangled latent (W) dimensionality.
+                 dlatent_size=512,
+                 num_channels=3,            # Number of output color channels.
+                 resolution=1024,         # Output resolution.
+                 # Overall multiplier for the number of feature maps.
+                 fmap_base=8192,
+                 # log2 feature map reduction when doubling the resolution.
+                 fmap_decay=1.0,
+                 # Maximum number of feature maps in any layer.
+                 fmap_max=512,
+                 use_styles=True,         # Enable style inputs?
+                 const_input_layer=True,         # First layer is a learned constant?
+                 use_noise=True,         # Enable noise inputs?
+                 # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables.
+                 randomize_noise=True,
+                 nonlinearity='lrelu',      # Activation function: 'relu', 'lrelu'
+                 use_wscale=True,         # Enable equalized learning rate?
+                 use_pixel_norm=False,        # Enable pixelwise feature vector normalization?
+                 use_instance_norm=True,         # Enable instance normalization?
+                 # Data type to use for activations and outputs.
+                 dtype=torch.float32,
+                 # Low-pass filter to apply when resampling activations. None = no filtering.
+                 blur_filter=[1, 2, 1],
+                 ):
+
+        super().__init__()
+
+        def nf(stage):
+            return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)
+        self.dlatent_size = dlatent_size
+        resolution_log2 = int(np.log2(resolution))
+        assert resolution == 2**resolution_log2 and resolution >= 4
+
+        act, gain = {'relu': (torch.relu, np.sqrt(2)),
+                     'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
+        num_layers = resolution_log2 * 2 - 2
+        num_styles = num_layers if use_styles else 1
+        torgbs = []
+        blocks = []
+        for res in range(2, resolution_log2 + 1):
+            channels = nf(res-1)
+            name = '{s}x{s}'.format(s=2**res)
+            if res == 2:
+                blocks.append((name,
+                               InputBlock(channels, dlatent_size, const_input_layer, gain, use_wscale,
+                                          use_noise, use_pixel_norm, use_instance_norm, use_styles, act)))
+
+            else:
+                blocks.append((name,
+                               GSynthesisBlock(last_channels, channels, blur_filter, dlatent_size, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, act)))
+            last_channels = channels
+        self.torgb = MyConv2d(channels, num_channels, 1,
+                              gain=1, use_wscale=use_wscale)
+        self.blocks = nn.ModuleDict(OrderedDict(blocks))
+
+    def forward(self, dlatents_in, noise_in):
+        # Input: Disentangled latents (W) [minibatch, num_layers, dlatent_size].
+        # lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0), trainable=False), dtype)
+        batch_size = dlatents_in.size(0)
+        for i, m in enumerate(self.blocks.values()):
+            if i == 0:
+                x = m(dlatents_in[:, 2*i:2*i+2], noise_in[2*i:2*i+2])
+            else:
+                x = m(x, dlatents_in[:, 2*i:2*i+2], noise_in[2*i:2*i+2])
+        rgb = self.torgb(x)
+        return rgb