Browse Source

refactor tunnel registry. use an LRU size-limited cache instead a time-based cache for tunnel urls. remove nrsc script that is no longer used

Alan Shreve 12 years ago
parent
commit
885e29abde
6 changed files with 413 additions and 181 deletions
  1. 0 46
      nrsc
  2. 250 0
      src/ngrok/cache/lru.go
  3. 9 9
      src/ngrok/server/main.go
  4. 0 120
      src/ngrok/server/manager.go
  5. 118 0
      src/ngrok/server/registry.go
  6. 36 6
      src/ngrok/server/tunnel.go

+ 0 - 46
nrsc

@@ -1,46 +0,0 @@
-#!/bin/bash
-# Pack assets as zip payload in go executable
-
-# Idea from Carlos Castillo (http://bit.ly/SmYXXm)
-
-case "$1" in
-    -h | --help )
-        echo "usage: $(basename $0) EXECTABLE RESOURCE_DIR [ZIP OPTIONS]";
-        exit;;
-    --version )
-        echo "nrsc version 0.3.1"; exit;;
-esac
-
-if [ $# -lt 2 ]; then
-    $0 -h
-    exit 1
-fi
-
-exe=$1
-shift
-root=$1
-shift
-
-if [ ! -f "${exe}" ]; then
-    echo "error: can't find $exe"
-    exit 1
-fi
-
-if [ ! -d "${root}" ]; then
-    echo "error: ${root} is not a directory"
-    exit 1
-fi
-
-# Exit on 1'st error
-set -e
-
-tmp="/tmp/nrsc-$(date +%s).zip"
-trap "rm -f ${tmp}" EXIT
-
-# Create zip file
-(zip -r "${tmp}" ${root} $@)
-
-# Append zip to executable
-cat "${tmp}" >> "${exe}"
-# Fix zip offset in file
-zip -q -A "${exe}"

+ 250 - 0
src/ngrok/cache/lru.go

@@ -0,0 +1,250 @@
+// Copyright 2012, Google Inc. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The implementation borrows heavily from SmallLRUCache (originally by Nathan
+// Schrenk). The object maintains a doubly-linked list of elements in the
+// When an element is accessed it is promoted to the head of the list, and when
+// space is needed the element at the tail of the list (the least recently used
+// element) is evicted.
+package cache
+
+import (
+	"container/list"
+	"encoding/gob"
+	"fmt"
+	"io"
+	"os"
+	"sync"
+	"time"
+)
+
+type LRUCache struct {
+	mu sync.Mutex
+
+	// list & table of *entry objects
+	list  *list.List
+	table map[string]*list.Element
+
+	// Our current size, in bytes. Obviously a gross simplification and low-grade
+	// approximation.
+	size uint64
+
+	// How many bytes we are limiting the cache to.
+	capacity uint64
+}
+
+// Values that go into LRUCache need to satisfy this interface.
+type Value interface {
+	Size() int
+}
+
+type Item struct {
+	Key   string
+	Value Value
+}
+
+type entry struct {
+	key           string
+	value         Value
+	size          int
+	time_accessed time.Time
+}
+
+func NewLRUCache(capacity uint64) *LRUCache {
+	return &LRUCache{
+		list:     list.New(),
+		table:    make(map[string]*list.Element),
+		capacity: capacity,
+	}
+}
+
+func (lru *LRUCache) Get(key string) (v Value, ok bool) {
+	lru.mu.Lock()
+	defer lru.mu.Unlock()
+
+	element := lru.table[key]
+	if element == nil {
+		return nil, false
+	}
+	lru.moveToFront(element)
+	return element.Value.(*entry).value, true
+}
+
+func (lru *LRUCache) Set(key string, value Value) {
+	lru.mu.Lock()
+	defer lru.mu.Unlock()
+
+	if element := lru.table[key]; element != nil {
+		lru.updateInplace(element, value)
+	} else {
+		lru.addNew(key, value)
+	}
+}
+
+func (lru *LRUCache) SetIfAbsent(key string, value Value) {
+	lru.mu.Lock()
+	defer lru.mu.Unlock()
+
+	if element := lru.table[key]; element != nil {
+		lru.moveToFront(element)
+	} else {
+		lru.addNew(key, value)
+	}
+}
+
+func (lru *LRUCache) Delete(key string) bool {
+	lru.mu.Lock()
+	defer lru.mu.Unlock()
+
+	element := lru.table[key]
+	if element == nil {
+		return false
+	}
+
+	lru.list.Remove(element)
+	delete(lru.table, key)
+	lru.size -= uint64(element.Value.(*entry).size)
+	return true
+}
+
+func (lru *LRUCache) Clear() {
+	lru.mu.Lock()
+	defer lru.mu.Unlock()
+
+	lru.list.Init()
+	lru.table = make(map[string]*list.Element)
+	lru.size = 0
+}
+
+func (lru *LRUCache) SetCapacity(capacity uint64) {
+	lru.mu.Lock()
+	defer lru.mu.Unlock()
+
+	lru.capacity = capacity
+	lru.checkCapacity()
+}
+
+func (lru *LRUCache) Stats() (length, size, capacity uint64, oldest time.Time) {
+	lru.mu.Lock()
+	defer lru.mu.Unlock()
+	if lastElem := lru.list.Back(); lastElem != nil {
+		oldest = lastElem.Value.(*entry).time_accessed
+	}
+	return uint64(lru.list.Len()), lru.size, lru.capacity, oldest
+}
+
+func (lru *LRUCache) StatsJSON() string {
+	if lru == nil {
+		return "{}"
+	}
+	l, s, c, o := lru.Stats()
+	return fmt.Sprintf("{\"Length\": %v, \"Size\": %v, \"Capacity\": %v, \"OldestAccess\": \"%v\"}", l, s, c, o)
+}
+
+func (lru *LRUCache) Keys() []string {
+	lru.mu.Lock()
+	defer lru.mu.Unlock()
+
+	keys := make([]string, 0, lru.list.Len())
+	for e := lru.list.Front(); e != nil; e = e.Next() {
+		keys = append(keys, e.Value.(*entry).key)
+	}
+	return keys
+}
+
+func (lru *LRUCache) Items() []Item {
+	lru.mu.Lock()
+	defer lru.mu.Unlock()
+
+	items := make([]Item, 0, lru.list.Len())
+	for e := lru.list.Front(); e != nil; e = e.Next() {
+		v := e.Value.(*entry)
+		items = append(items, Item{Key: v.key, Value: v.value})
+	}
+	return items
+}
+
+func (lru *LRUCache) SaveItems(w io.Writer) error {
+	items := lru.Items()
+
+	for _, v := range items {
+		gob.Register(v)
+	}
+
+	encoder := gob.NewEncoder(w)
+	return encoder.Encode(items)
+}
+
+func (lru *LRUCache) SaveItemsToFile(path string) error {
+	if wr, err := os.Open(path); err != nil {
+		return err
+	} else {
+		defer wr.Close()
+		return lru.SaveItems(wr)
+	}
+}
+
+func (lru *LRUCache) LoadItems(r io.Reader) error {
+	items := make([]Item, 0)
+	decoder := gob.NewDecoder(r)
+	if err := decoder.Decode(items); err != nil {
+		return err
+	}
+
+	lru.mu.Lock()
+	lru.mu.Unlock()
+	for _, item := range items {
+		// XXX: copied from Set()
+		if element := lru.table[item.Key]; element != nil {
+			lru.updateInplace(element, item.Value)
+		} else {
+			lru.addNew(item.Key, item.Value)
+		}
+	}
+
+	return nil
+}
+
+func (lru *LRUCache) LoadItemsFromFile(path string) error {
+	if rd, err := os.Open(path); err != nil {
+		return err
+	} else {
+		defer rd.Close()
+		return lru.LoadItems(rd)
+	}
+}
+
+func (lru *LRUCache) updateInplace(element *list.Element, value Value) {
+	valueSize := value.Size()
+	sizeDiff := valueSize - element.Value.(*entry).size
+	element.Value.(*entry).value = value
+	element.Value.(*entry).size = valueSize
+	lru.size += uint64(sizeDiff)
+	lru.moveToFront(element)
+	lru.checkCapacity()
+}
+
+func (lru *LRUCache) moveToFront(element *list.Element) {
+	lru.list.MoveToFront(element)
+	element.Value.(*entry).time_accessed = time.Now()
+}
+
+func (lru *LRUCache) addNew(key string, value Value) {
+	newEntry := &entry{key, value, value.Size(), time.Now()}
+	element := lru.list.PushFront(newEntry)
+	lru.table[key] = element
+	lru.size += uint64(newEntry.size)
+	lru.checkCapacity()
+}
+
+func (lru *LRUCache) checkCapacity() {
+	// Partially duplicated from Delete
+	for lru.size > lru.capacity {
+		delElem := lru.list.Back()
+		delValue := delElem.Value.(*entry)
+		lru.list.Remove(delElem)
+		delete(lru.table, delValue.key)
+		lru.size -= uint64(delValue.size)
+	}
+}

+ 9 - 9
src/ngrok/server/main.go

@@ -7,7 +7,7 @@ import (
 	"ngrok/conn"
 	"ngrok/conn"
 	log "ngrok/log"
 	log "ngrok/log"
 	"ngrok/msg"
 	"ngrok/msg"
-	"regexp"
+	"os"
 )
 )
 
 
 type Options struct {
 type Options struct {
@@ -20,15 +20,12 @@ type Options struct {
 
 
 /* GLOBALS */
 /* GLOBALS */
 var (
 var (
-	hostRegex *regexp.Regexp
-	proxyAddr string
-	tunnels   *TunnelManager
+	proxyAddr         string
+	tunnels           *TunnelRegistry
+	registryCacheSize uint64 = 1024 * 1024 // 1 MB
+	domain            string
 )
 )
 
 
-func init() {
-	hostRegex = regexp.MustCompile("[H|h]ost: ([^\\(\\);:,<>]+)\n")
-}
-
 func parseArgs() *Options {
 func parseArgs() *Options {
 	publicPort := flag.Int("publicport", 80, "Public port")
 	publicPort := flag.Int("publicport", 80, "Public port")
 	tunnelPort := flag.Int("tunnelport", 4443, "Tunnel port")
 	tunnelPort := flag.Int("tunnelport", 4443, "Tunnel port")
@@ -110,11 +107,14 @@ func proxyListener(addr *net.TCPAddr, domain string) {
 func Main() {
 func Main() {
 	// parse options
 	// parse options
 	opts := parseArgs()
 	opts := parseArgs()
+	domain = opts.domain
 
 
 	// init logging
 	// init logging
 	log.LogTo(opts.logto)
 	log.LogTo(opts.logto)
 
 
-	tunnels = NewTunnelManager(opts.domain)
+	// init tunnel registry
+	registryCacheFile := os.Getenv("REGISTRY_CACHE_FILE")
+	tunnels = NewTunnelRegistry(registryCacheSize, registryCacheFile)
 
 
 	go proxyListener(&net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: opts.proxyPort}, opts.domain)
 	go proxyListener(&net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: opts.proxyPort}, opts.domain)
 	go controlListener(&net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: opts.tunnelPort}, opts.domain)
 	go controlListener(&net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: opts.tunnelPort}, opts.domain)

+ 0 - 120
src/ngrok/server/manager.go

@@ -1,120 +0,0 @@
-package server
-
-import (
-	"fmt"
-	cache "github.com/pmylund/go-cache"
-	"math/rand"
-	"net"
-	"strings"
-	"sync"
-	"time"
-)
-
-const (
-	cacheDuration        time.Duration = 24 * time.Hour
-	cacheCleanupInterval time.Duration = time.Minute
-)
-
-/**
- * TunnelManager: Manages a set of tunnels
- */
-type TunnelManager struct {
-	domain           string
-	tunnels          map[string]*Tunnel
-	idDomainAffinity *cache.Cache
-	ipDomainAffinity *cache.Cache
-	sync.RWMutex
-}
-
-func NewTunnelManager(domain string) *TunnelManager {
-	return &TunnelManager{
-		domain:           domain,
-		tunnels:          make(map[string]*Tunnel),
-		idDomainAffinity: cache.New(cacheDuration, cacheCleanupInterval),
-		ipDomainAffinity: cache.New(cacheDuration, cacheCleanupInterval),
-	}
-}
-
-func (m *TunnelManager) Add(t *Tunnel) error {
-	assignTunnel := func(url string) bool {
-		m.Lock()
-		defer m.Unlock()
-
-		if m.tunnels[url] == nil {
-			m.tunnels[url] = t
-			return true
-		}
-
-		return false
-	}
-
-	url := ""
-	switch t.regMsg.Protocol {
-	case "tcp":
-		addr := t.listener.Addr().(*net.TCPAddr)
-		url = fmt.Sprintf("tcp://%s:%d", m.domain, addr.Port)
-		if !assignTunnel(url) {
-			return t.Error("TCP at %s already registered!", url)
-		}
-
-	case "http":
-		if strings.TrimSpace(t.regMsg.Hostname) != "" {
-			url = fmt.Sprintf("http://%s", t.regMsg.Hostname)
-		} else if strings.TrimSpace(t.regMsg.Subdomain) != "" {
-			url = fmt.Sprintf("http://%s.%s", t.regMsg.Subdomain, m.domain)
-		}
-
-		if url != "" {
-			if !assignTunnel(url) {
-				return t.Warn("The tunnel address %s is already registered!", url)
-			}
-		} else {
-			clientIp := t.ctl.conn.RemoteAddr().(*net.TCPAddr).IP.String()
-			clientId := t.regMsg.ClientId
-
-			// try to give the same subdomain back if it's available
-			subdomain := fmt.Sprintf("%x", rand.Int31())
-			if lastDomain, ok := m.idDomainAffinity.Get(clientId); ok {
-				t.Debug("Found affinity for subdomain %s with client id %s", subdomain, clientId)
-				subdomain = lastDomain.(string)
-			} else if lastDomain, ok = m.ipDomainAffinity.Get(clientIp); ok {
-				t.Debug("Found affinity for subdomain %s with client ip %s", subdomain, clientIp)
-				subdomain = lastDomain.(string)
-			}
-
-			// pick one randomly
-			for {
-				url = fmt.Sprintf("http://%s.%s", subdomain, m.domain)
-				if assignTunnel(url) {
-					break
-				} else {
-					subdomain = fmt.Sprintf("%x", rand.Int31())
-				}
-			}
-
-			// save our choice so we can try to give clients back the same
-			// tunnel later
-			m.idDomainAffinity.Set(clientId, subdomain, 0)
-			m.ipDomainAffinity.Set(clientIp, subdomain, 0)
-		}
-
-	default:
-		return t.Error("Unrecognized protocol type %s", t.regMsg.Protocol)
-	}
-
-	t.url = url
-
-	return nil
-}
-
-func (m *TunnelManager) Del(url string) {
-	m.Lock()
-	defer m.Unlock()
-	delete(m.tunnels, url)
-}
-
-func (m *TunnelManager) Get(url string) *Tunnel {
-	m.RLock()
-	defer m.RUnlock()
-	return m.tunnels[url]
-}

+ 118 - 0
src/ngrok/server/registry.go

@@ -0,0 +1,118 @@
+package server
+
+import (
+	"fmt"
+	"net"
+	"ngrok/cache"
+	"sync"
+	"time"
+)
+
+const (
+	cacheSaveInterval time.Duration = 10 * time.Minute
+)
+
+type cacheUrl string
+
+func (url cacheUrl) Size() int {
+	return len(url)
+}
+
+// TunnelRegistry maps a tunnel URL to Tunnel structures
+type TunnelRegistry struct {
+	tunnels  map[string]*Tunnel
+	affinity *cache.LRUCache
+	sync.RWMutex
+}
+
+func NewTunnelRegistry(cacheSize uint64, cacheFile string) *TunnelRegistry {
+	manager := &TunnelRegistry{
+		tunnels:  make(map[string]*Tunnel),
+		affinity: cache.NewLRUCache(cacheSize),
+	}
+
+	if cacheFile != "" {
+		// load cache entries from file
+		manager.affinity.LoadItemsFromFile(cacheFile)
+
+		// save cache periodically to file
+		manager.SaveCacheThread(cacheFile, cacheSaveInterval)
+	}
+
+	return manager
+}
+
+// Spawns a goroutine the periodically saves the cache to a file.
+func (r *TunnelRegistry) SaveCacheThread(path string, interval time.Duration) {
+	go func() {
+		for {
+			time.Sleep(interval)
+			r.affinity.SaveItemsToFile(path)
+		}
+	}()
+}
+
+// Register a tunnel with a specific url, returns an error
+// if a tunnel is already registered at that url
+func (r *TunnelRegistry) Register(url string, t *Tunnel) error {
+	r.Lock()
+	defer r.Unlock()
+
+	if r.tunnels[url] != nil {
+		return fmt.Errorf("The tunnel %s is already registered.", url)
+	}
+
+	return nil
+}
+
+// Register a tunnel with the following process:
+// Consult the affinity cache to try to assign a previously used tunnel url if possible
+// Generate new urls repeatedly with the urlFn and register until one is available.
+func (r *TunnelRegistry) RegisterRepeat(urlFn func() string, t *Tunnel) string {
+	var url string
+
+	clientIp := t.ctl.conn.RemoteAddr().(*net.TCPAddr).IP.String()
+	clientId := t.regMsg.ClientId
+
+	ipCacheKey := fmt.Sprintf("client-ip:%s", clientIp)
+	idCacheKey := fmt.Sprintf("client-id:%s", clientId)
+
+	// check cache for ID first, because we prefer that over IP which might
+	// not be specific to a user because of NATs
+	if v, ok := r.affinity.Get(idCacheKey); ok {
+		url = string(v.(cacheUrl))
+		t.Debug("Found registry affinity %s for %s", url, idCacheKey)
+	} else if v, ok := r.affinity.Get(ipCacheKey); ok {
+		url = string(v.(cacheUrl))
+		t.Debug("Found registry affinity %s for %s", url, ipCacheKey)
+	} else {
+		url = urlFn()
+	}
+
+	for {
+		if err := r.Register(url, t); err != nil {
+			// pick a new url and try again
+			url = urlFn()
+		} else {
+			// we successfully assigned a url, we're done
+
+			// save our choice in the cache
+			r.affinity.Set(ipCacheKey, cacheUrl(url))
+			r.affinity.Set(idCacheKey, cacheUrl(url))
+
+			return url
+		}
+	}
+}
+
+func (r *TunnelRegistry) Del(url string) {
+	r.Lock()
+	defer r.Unlock()
+	delete(r.tunnels, url)
+}
+
+func (r *TunnelRegistry) Get(url string) *Tunnel {
+	r.RLock()
+	defer r.RUnlock()
+	return r.tunnels[url]
+}

+ 36 - 6
src/ngrok/server/tunnel.go

@@ -3,11 +3,13 @@ package server
 import (
 import (
 	"encoding/base64"
 	"encoding/base64"
 	"fmt"
 	"fmt"
+	"math/rand"
 	"net"
 	"net"
 	"ngrok/conn"
 	"ngrok/conn"
 	"ngrok/log"
 	"ngrok/log"
 	"ngrok/msg"
 	"ngrok/msg"
 	"ngrok/version"
 	"ngrok/version"
+	"strings"
 	"sync/atomic"
 	"sync/atomic"
 	"time"
 	"time"
 )
 )
@@ -50,6 +52,10 @@ func newTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel) {
 		Logger:  log.NewPrefixLogger(),
 		Logger:  log.NewPrefixLogger(),
 	}
 	}
 
 
+	failReg := func(err error) {
+		t.ctl.stop <- &msg.RegAckMsg{Error: err.Error()}
+	}
+
 	switch t.regMsg.Protocol {
 	switch t.regMsg.Protocol {
 	case "tcp":
 	case "tcp":
 		var err error
 		var err error
@@ -61,18 +67,42 @@ func newTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel) {
 			t.ctl.stop <- &msg.RegAckMsg{Error: "Internal server error"}
 			t.ctl.stop <- &msg.RegAckMsg{Error: "Internal server error"}
 		}
 		}
 
 
+		addr := t.listener.Addr().(*net.TCPAddr)
+		t.url = fmt.Sprintf("tcp://%s:%d", domain, addr.Port)
+
+		if err = tunnels.Register(t.url, t); err != nil {
+			// This should never be possible because the OS will only assign
+			// available ports to us.
+			t.Error("TCP listener bound, but failed to register: %s", err.Error())
+			t.listener.Close()
+			failReg(err)
+			return
+		}
+
 		go t.listenTcp(t.listener)
 		go t.listenTcp(t.listener)
 
 
-	default:
-	}
+	case "http":
+		if strings.TrimSpace(t.regMsg.Hostname) != "" {
+			t.url = fmt.Sprintf("http://%s", t.regMsg.Hostname)
+		} else if strings.TrimSpace(t.regMsg.Subdomain) != "" {
+			t.url = fmt.Sprintf("http://%s.%s", t.regMsg.Subdomain, domain)
+		}
 
 
-	if err := tunnels.Add(t); err != nil {
-		t.ctl.stop <- &msg.RegAckMsg{Error: fmt.Sprint(err)}
-		return
+		if t.url != "" {
+			if err := tunnels.Register(t.url, t); err != nil {
+				failReg(err)
+				return
+			}
+		} else {
+			t.url = tunnels.RegisterRepeat(func() string {
+				return fmt.Sprintf("http://%x.%s", rand.Int31(), domain)
+			}, t)
+		}
 	}
 	}
 
 
 	if m.Version != version.Proto {
 	if m.Version != version.Proto {
-		t.ctl.stop <- &msg.RegAckMsg{Error: fmt.Sprintf("Incompatible versions. Server %s, client %s. Download a new version at http://ngrok.com", version.MajorMinor(), m.Version)}
+		failReg(fmt.Errorf("Incompatible versions. Server %s, client %s. Download a new version at http://ngrok.com", version.MajorMinor(), m.Version))
+		return
 	}
 	}
 
 
 	// pre-encode the http basic auth for fast comparisons later
 	// pre-encode the http basic auth for fast comparisons later