registry.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. package server
  2. import (
  3. "encoding/gob"
  4. "fmt"
  5. "ngrok/cache"
  6. "ngrok/log"
  7. "sync"
  8. "time"
  9. )
  10. const (
  11. cacheSaveInterval time.Duration = 10 * time.Minute
  12. )
  13. type cacheUrl string
  14. func (url cacheUrl) Size() int {
  15. return len(url)
  16. }
  17. // TunnelRegistry maps a tunnel URL to Tunnel structures
  18. type TunnelRegistry struct {
  19. tunnels map[string]*Tunnel
  20. affinity *cache.LRUCache
  21. log.Logger
  22. sync.RWMutex
  23. }
  24. func NewTunnelRegistry(cacheSize uint64, cacheFile string) *TunnelRegistry {
  25. registry := &TunnelRegistry{
  26. tunnels: make(map[string]*Tunnel),
  27. affinity: cache.NewLRUCache(cacheSize),
  28. Logger: log.NewPrefixLogger("registry", "tun"),
  29. }
  30. // LRUCache uses Gob encoding. Unfortunately, Gob is fickle and will fail
  31. // to encode or decode any non-primitive types that haven't been "registered"
  32. // with it. Since we store cacheUrl objects, we need to register them here first
  33. // for the encoding/decoding to work
  34. var urlobj cacheUrl
  35. gob.Register(urlobj)
  36. // try to load and then periodically save the affinity cache to file, if specified
  37. if cacheFile != "" {
  38. err := registry.affinity.LoadItemsFromFile(cacheFile)
  39. if err != nil {
  40. registry.Error("Failed to load affinity cache %s: %v", cacheFile, err)
  41. }
  42. registry.SaveCacheThread(cacheFile, cacheSaveInterval)
  43. } else {
  44. registry.Info("No affinity cache specified")
  45. }
  46. return registry
  47. }
  48. // Spawns a goroutine the periodically saves the cache to a file.
  49. func (r *TunnelRegistry) SaveCacheThread(path string, interval time.Duration) {
  50. go func() {
  51. r.Info("Saving affinity cache to %s every %s", path, interval.String())
  52. for {
  53. time.Sleep(interval)
  54. r.Debug("Saving affinity cache")
  55. err := r.affinity.SaveItemsToFile(path)
  56. if err != nil {
  57. r.Error("Failed to save affinity cache: %v", err)
  58. } else {
  59. r.Info("Saved affinity cache")
  60. }
  61. }
  62. }()
  63. }
  64. // Register a tunnel with a specific url, returns an error
  65. // if a tunnel is already registered at that url
  66. func (r *TunnelRegistry) Register(url string, t *Tunnel) error {
  67. r.Lock()
  68. defer r.Unlock()
  69. if r.tunnels[url] != nil {
  70. return fmt.Errorf("The tunnel %s is already registered.", url)
  71. }
  72. r.tunnels[url] = t
  73. return nil
  74. }
  75. func (r *TunnelRegistry) cacheKeys(t *Tunnel) (string) {
  76. return fmt.Sprintf("client-id-%s:%s", t.req.Protocol, t.ctl.id)
  77. }
  78. func (r *TunnelRegistry) GetCachedRegistration(t *Tunnel) (url string) {
  79. idCacheKey := r.cacheKeys(t)
  80. // check cache for ID first, because we prefer that over IP which might
  81. // not be specific to a user because of NATs
  82. if v, ok := r.affinity.Get(idCacheKey); ok {
  83. url = string(v.(cacheUrl))
  84. t.Debug("Found registry affinity %s for %s", url, idCacheKey)
  85. }
  86. return
  87. }
  88. func (r *TunnelRegistry) RegisterAndCache(url string, t *Tunnel) (err error) {
  89. if err = r.Register(url, t); err == nil {
  90. // we successfully assigned a url, cache it
  91. idCacheKey := r.cacheKeys(t)
  92. r.affinity.Set(idCacheKey, cacheUrl(url))
  93. }
  94. return
  95. }
  96. // Register a tunnel with the following process:
  97. // Consult the affinity cache to try to assign a previously used tunnel url if possible
  98. // Generate new urls repeatedly with the urlFn and register until one is available.
  99. func (r *TunnelRegistry) RegisterRepeat(urlFn func() string, t *Tunnel) (string, error) {
  100. url := r.GetCachedRegistration(t)
  101. if url == "" {
  102. url = urlFn()
  103. }
  104. maxAttempts := 5
  105. for i := 0; i < maxAttempts; i++ {
  106. if err := r.RegisterAndCache(url, t); err != nil {
  107. // pick a new url and try again
  108. url = urlFn()
  109. } else {
  110. // we successfully assigned a url, we're done
  111. return url, nil
  112. }
  113. }
  114. return "", fmt.Errorf("Failed to assign a URL after %d attempts!", maxAttempts)
  115. }
  116. func (r *TunnelRegistry) Del(url string) {
  117. r.Lock()
  118. defer r.Unlock()
  119. delete(r.tunnels, url)
  120. }
  121. func (r *TunnelRegistry) Get(url string) *Tunnel {
  122. r.RLock()
  123. defer r.RUnlock()
  124. return r.tunnels[url]
  125. }
  126. // ControlRegistry maps a client ID to Control structures
  127. type ControlRegistry struct {
  128. controls map[string]*Control
  129. log.Logger
  130. sync.RWMutex
  131. }
  132. func NewControlRegistry() *ControlRegistry {
  133. return &ControlRegistry{
  134. controls: make(map[string]*Control),
  135. Logger: log.NewPrefixLogger("registry", "ctl"),
  136. }
  137. }
  138. func (r *ControlRegistry) Get(clientId string) *Control {
  139. r.RLock()
  140. defer r.RUnlock()
  141. return r.controls[clientId]
  142. }
  143. func (r *ControlRegistry) Add(clientId string, ctl *Control) (oldCtl *Control) {
  144. r.Lock()
  145. defer r.Unlock()
  146. oldCtl = r.controls[clientId]
  147. if oldCtl != nil {
  148. oldCtl.Replaced(ctl)
  149. }
  150. r.controls[clientId] = ctl
  151. r.Info("Registered control with id %s", clientId)
  152. return
  153. }
  154. func (r *ControlRegistry) Del(clientId string) error {
  155. r.Lock()
  156. defer r.Unlock()
  157. if r.controls[clientId] == nil {
  158. return fmt.Errorf("No control found for client id: %s", clientId)
  159. } else {
  160. r.Info("Removed control registry id %s", clientId)
  161. delete(r.controls, clientId)
  162. return nil
  163. }
  164. }