SphericalOptimizer.py 909 B

1234567891011121314151617181920212223242526
  1. import math
  2. import torch
  3. from torch.optim import Optimizer
  4. # Spherical Optimizer Class
  5. # Uses the first two dimensions as batch information
  6. # Optimizes over the surface of a sphere using the initial radius throughout
  7. #
  8. # Example Usage:
  9. # opt = SphericalOptimizer(torch.optim.SGD, [x], lr=0.01)
  10. class SphericalOptimizer(Optimizer):
  11. def __init__(self, optimizer, params, **kwargs):
  12. self.opt = optimizer(params, **kwargs)
  13. self.params = params
  14. with torch.no_grad():
  15. self.radii = {param: (param.pow(2).sum(tuple(range(2,param.ndim)),keepdim=True)+1e-9).sqrt() for param in params}
  16. @torch.no_grad()
  17. def step(self, closure=None):
  18. loss = self.opt.step(closure)
  19. for param in self.params:
  20. param.data.div_((param.pow(2).sum(tuple(range(2,param.ndim)),keepdim=True)+1e-9).sqrt())
  21. param.mul_(self.radii[param])
  22. return loss