
incremental progress on supporting multiple active tunnels in the ngrok client

Alan Shreve 12 years ago
parent
commit ea1b1ed632

+ 150 - 0
src/ngrok/client/controller/controller.go

@@ -0,0 +1,150 @@
+package controller
+
+import (
+	"fmt"
+	"math/rand"
+	"ngrok/client/mvc"
+	"ngrok/client/views/term"
+	"ngrok/client/views/web"
+	"ngrok/log"
+	"ngrok/msg"
+	"ngrok/util"
+	"sync"
+)
+
+type command interface{}
+
+type cmdQuit struct {
+	// display this message after quit
+	message string
+}
+
+type cmdPlayRequest struct {
+	// the tunnel to play this request over
+	tunnel *mvc.Tunnel
+
+	// the bytes of the request to issue
+	payload []byte
+}
+
+// The MVC Controller
+type Controller struct {
+	// the model sends updates through this broadcast channel
+	updates *util.Broadcast
+
+	// the model
+	model mvc.Model
+
+	// the views
+	views []mvc.View
+
+	// internal channel used to issue commands to the controller
+	cmds chan command
+}
+
+// public interface
+func NewController(model mvc.Model) *Controller {
+	ctl := &Controller{
+		updates: util.NewBroadcast(),
+		model:   model,
+		cmds:    make(chan command),
+		views:   make([]mvc.View, 0),
+	}
+
+	return ctl
+}
+
+func (ctl *Controller) Update(state mvc.State) {
+	ctl.updates.In() <- state
+}
+
+func (ctl *Controller) Shutdown(message string) {
+	ctl.cmds <- cmdQuit{message: message}
+}
+
+func (ctl *Controller) PlayRequest(tunnel *mvc.Tunnel, payload []byte) {
+	ctl.cmds <- cmdPlayRequest{tunnel: tunnel, payload: payload}
+}
+
+
+// private functions
+func (ctl *Controller) doShutdown() {
+	var wg sync.WaitGroup
+
+	// wait for all of the views, plus the model
+	wg.Add(len(ctl.views) + 1)
+
+	for _, v := range ctl.views {
+		go v.Shutdown(&wg)
+	}
+	go ctl.model.Shutdown(&wg)
+
+	wg.Wait()
+}
+
+// Go runs the given function in a goroutine, recovering from any panic it
+// raises so a crashing view or model doesn't take down the whole client.
+func (ctl *Controller) Go(fn func()) {
+	go func() {
+		defer func() {
+			if r := recover(); r != nil {
+				// XXX
+			}
+		}()
+
+		fn()
+	}()
+}
+
+func (ctl *Controller) Run() {
+	// parse options
+	opts := parseArgs()
+
+	// set up logging
+	log.LogTo(opts.logto)
+
+	// seed random number generator
+	seed, err := util.RandomSeed()
+	if err != nil {
+		log.Error("Couldn't securely seed the random number generator!")
+	}
+	rand.Seed(seed)
+
+	// set up auth token
+	if opts.authtoken == "" {
+		opts.authtoken = LoadAuthToken()
+	}
+
+	// init web ui
+	if opts.webport != -1 {
+		ctl.views = append(ctl.views, web.NewWebView(ctl, ctl.model, opts.webport))
+	}
+
+	// init term ui
+	if opts.logto != "stdout" {
+		ctl.views = append(ctl.views, term.New(ctl, ctl.model))
+	}
+
+	ctl.Go(func() { autoUpdate(ctl, opts.authtoken) })
+
+	reg := &msg.RegMsg{
+		Protocol: opts.protocol,
+		Hostname: opts.hostname,
+		Subdomain: opts.subdomain,
+		HttpAuth: opts.httpAuth,
+		User: opts.user,
+		Password: opts.password,
+	}
+
+	ctl.Go(func() { ctl.model.Run(opts.serverAddr, opts.authtoken, ctl, reg) })
+
+	quitMessage := ""
+	defer func() {
+		ctl.doShutdown()
+		fmt.Println(quitMessage)
+	}()
+
+	for {
+		select {
+		case obj := <-ctl.cmds:
+			switch cmd := obj.(type) {
+			case cmdQuit:
+				quitMessage = cmd.message
+				return
+
+			case cmdPlayRequest:
+				ctl.Go(func() { ctl.model.PlayRequest(cmd.tunnel, cmd.payload) })
+			}
+		}
+	}
+}
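
The controller above funnels every view request through a single cmds channel and dispatches on the concrete command type from one goroutine, which is what keeps command handling serialized. A minimal standalone sketch of that pattern, with illustrative names rather than ngrok's API:

package main

import "fmt"

type command interface{}

type cmdQuit struct{ message string }
type cmdPlay struct{ payload []byte }

// run drains the command channel and dispatches on the concrete command type.
func run(cmds chan command) {
	for obj := range cmds {
		switch cmd := obj.(type) {
		case cmdQuit:
			fmt.Println("quit:", cmd.message)
			return
		case cmdPlay:
			fmt.Printf("replaying %d bytes\n", len(cmd.payload))
		}
	}
}

func main() {
	cmds := make(chan command)
	go func() {
		cmds <- cmdPlay{payload: []byte("GET / HTTP/1.0\r\n\r\n")}
		cmds <- cmdQuit{message: "done"}
	}()
	run(cmds)
}

Because only the run loop reads from the channel, handlers for different commands never race with each other.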

+ 3 - 285
src/ngrok/client/main.go

@@ -1,292 +1,10 @@
 package client
 
 import (
-	"fmt"
-	"io/ioutil"
-	"math"
-	"ngrok/client/ui"
-	"ngrok/client/views/term"
-	"ngrok/client/views/web"
-	"ngrok/conn"
-	"ngrok/log"
-	"ngrok/msg"
-	"ngrok/proto"
-	"ngrok/util"
-	"ngrok/version"
-	"runtime"
-	"sync/atomic"
-	"time"
+	"ngrok/client/controller"
 )
 
-const (
-	pingInterval        = 20 * time.Second
-	maxPongLatency      = 15 * time.Second
-	updateCheckInterval = 6 * time.Hour
-	BadGateway          = `<html>
-<body style="background-color: #97a8b9">
-    <div style="margin:auto; width:400px;padding: 20px 60px; background-color: #D3D3D3; border: 5px solid maroon;">
-        <h2>Tunnel %s unavailable</h2>
-        <p>Unable to initiate connection to <strong>%s</strong>. A web server must be running on port <strong>%s</strong> to complete the tunnel.</p>
-`
-)
-
-/**
- * Establishes and manages a tunnel proxy connection with the server
- */
-func proxy(proxyAddr string, url string, s *State, ctl *ui.Controller) {
-	start := time.Now()
-	remoteConn, err := conn.Dial(proxyAddr, "pxy", tlsConfig)
-	if err != nil {
-		log.Error("Failed to establish proxy connection: %v", err)
-		return
-	}
-
-	defer remoteConn.Close()
-	err = msg.WriteMsg(remoteConn, &msg.RegProxyMsg{Url: url, ClientId: s.id})
-	if err != nil {
-		log.Error("Failed to write RegProxyMsg: %v", err)
-		return
-	}
-
-	localConn, err := conn.Dial(s.opts.localaddr, "prv", nil)
-	if err != nil {
-		remoteConn.Warn("Failed to open private leg %s: %v", s.opts.localaddr, err)
-		badGatewayBody := fmt.Sprintf(BadGateway, s.publicUrl, s.opts.localaddr, s.opts.localaddr)
-		remoteConn.Write([]byte(fmt.Sprintf(`HTTP/1.0 502 Bad Gateway
-Content-Type: text/html
-Content-Length: %d
-
-%s`, len(badGatewayBody), badGatewayBody)))
-		return
-	}
-	defer localConn.Close()
-
-	m := s.metrics
-	m.proxySetupTimer.Update(time.Since(start))
-	m.connMeter.Mark(1)
-	ctl.Update(s)
-	m.connTimer.Time(func() {
-		localConn := s.protocol.WrapConn(localConn)
-		bytesIn, bytesOut := conn.Join(localConn, remoteConn)
-		m.bytesIn.Update(bytesIn)
-		m.bytesOut.Update(bytesOut)
-		m.bytesInCount.Inc(bytesIn)
-		m.bytesOutCount.Inc(bytesOut)
-	})
-	ctl.Update(s)
-}
-
-/*
- * Hearbeating to ensure our connection ngrokd is still live
- */
-func heartbeat(lastPongAddr *int64, c conn.Conn) {
-	lastPing := time.Unix(atomic.LoadInt64(lastPongAddr)-1, 0)
-	ping := time.NewTicker(pingInterval)
-	pongCheck := time.NewTicker(time.Second)
-
-	defer func() {
-		c.Close()
-		ping.Stop()
-		pongCheck.Stop()
-	}()
-
-	for {
-		select {
-		case <-pongCheck.C:
-			lastPong := time.Unix(0, atomic.LoadInt64(lastPongAddr))
-			needPong := lastPong.Sub(lastPing) < 0
-			pongLatency := time.Since(lastPing)
-
-			if needPong && pongLatency > maxPongLatency {
-				c.Info("Last ping: %v, Last pong: %v", lastPing, lastPong)
-				c.Info("Connection stale, haven't gotten PongMsg in %d seconds", int(pongLatency.Seconds()))
-				return
-			}
-
-		case <-ping.C:
-			err := msg.WriteMsg(c, &msg.PingMsg{})
-			if err != nil {
-				c.Debug("Got error %v when writing PingMsg", err)
-				return
-			}
-			lastPing = time.Now()
-		}
-	}
-}
-
-func reconnectingControl(s *State, ctl *ui.Controller) {
-	// how long we should wait before we reconnect
-	maxWait := 30 * time.Second
-	wait := 1 * time.Second
-
-	for {
-		control(s, ctl)
-
-		if s.status == "online" {
-			wait = 1 * time.Second
-		}
-
-		log.Info("Waiting %d seconds before reconnecting", int(wait.Seconds()))
-		time.Sleep(wait)
-		// exponentially increase wait time
-		wait = 2 * wait
-		wait = time.Duration(math.Min(float64(wait), float64(maxWait)))
-		s.status = "reconnecting"
-		ctl.Update(s)
-	}
-}
-
-/**
- * Establishes and manages a tunnel control connection with the server
- */
-func control(s *State, ctl *ui.Controller) {
-	defer func() {
-		if r := recover(); r != nil {
-			log.Error("control recovering from failure %v", r)
-		}
-	}()
-
-	// establish control channel
-	conn, err := conn.Dial(s.opts.server, "ctl", tlsConfig)
-	if err != nil {
-		panic(err)
-	}
-	defer conn.Close()
-
-	// register with the server
-	err = msg.WriteMsg(conn, &msg.RegMsg{
-		Protocol:  s.opts.protocol,
-		OS:        runtime.GOOS,
-		HttpAuth:  s.opts.httpAuth,
-		Hostname:  s.opts.hostname,
-		Subdomain: s.opts.subdomain,
-		ClientId:  s.id,
-		Version:   version.Proto,
-		MmVersion: version.MajorMinor(),
-		User:      s.opts.authtoken,
-	})
-
-	if err != nil {
-		panic(err)
-	}
-
-	// wait for the server to ack our register
-	var regAck msg.RegAckMsg
-	if err = msg.ReadMsgInto(conn, &regAck); err != nil {
-		panic(err)
-	}
-
-	if regAck.Error != "" {
-		emsg := fmt.Sprintf("Server failed to allocate tunnel: %s", regAck.Error)
-		ctl.Cmds <- ui.CmdQuit{Message: emsg}
-		return
-	}
-
-	// update UI state
-	s.publicUrl = regAck.Url
-	conn.Info("Tunnel established at %v", s.GetPublicUrl())
-	s.status = "online"
-	s.serverVersion = regAck.MmVersion
-	ctl.Update(s)
-
-	SaveAuthToken(s.opts.authtoken)
-
-	// start the heartbeat
-	lastPong := time.Now().UnixNano()
-	go heartbeat(&lastPong, conn)
-
-	// main control loop
-	for {
-		var rawMsg msg.Message
-		if rawMsg, err = msg.ReadMsg(conn); err != nil {
-			panic(err)
-		}
-
-		switch m := rawMsg.(type) {
-		case *msg.ReqProxyMsg:
-			go proxy(regAck.ProxyAddr, m.Url, s, ctl)
-
-		case *msg.PongMsg:
-			atomic.StoreInt64(&lastPong, time.Now().UnixNano())
-		default:
-			conn.Warn("Ignoring unknown control message %v ", m)
-		}
-	}
-}
-
 func Main() {
-	// parse options
-	opts := parseArgs()
-
-	// set up logging
-	log.LogTo(opts.logto)
-
-	// set up auth token
-	if opts.authtoken == "" {
-		opts.authtoken = LoadAuthToken()
-	}
-
-	// init client state
-	s := &State{
-		status: "connecting",
-
-		// unique client id
-		id: util.RandIdOrPanic(8),
-
-		// command-line options
-		opts: opts,
-
-		// metrics
-		metrics: NewClientMetrics(),
-	}
-
-	switch opts.protocol {
-	case "http":
-		s.protocol = proto.NewHttp()
-	case "tcp":
-		s.protocol = proto.NewTcp()
-	}
-
-	// init ui
-	ctl := ui.NewController()
-	web.NewWebView(ctl, s, opts.webport)
-	if opts.logto != "stdout" {
-		term.New(ctl, s)
-	}
-
-	go reconnectingControl(s, ctl)
-	go autoUpdate(s, ctl, opts.authtoken)
-
-	quitMessage := ""
-	ctl.Wait.Add(1)
-	go func() {
-		defer ctl.Wait.Done()
-		for {
-			select {
-			case obj := <-ctl.Cmds:
-				switch cmd := obj.(type) {
-				case ui.CmdQuit:
-					quitMessage = cmd.Message
-					ctl.DoShutdown()
-					return
-				case ui.CmdRequest:
-					go func() {
-						var localConn conn.Conn
-						localConn, err := conn.Dial(s.opts.localaddr, "prv", nil)
-						if err != nil {
-							log.Warn("Failed to open private leg %s: %v", s.opts.localaddr, err)
-							return
-						}
-						//defer localConn.Close()
-						localConn = s.protocol.WrapConn(localConn)
-						localConn.Write(cmd.Payload)
-						ioutil.ReadAll(localConn)
-					}()
-				}
-			}
-		}
-	}()
-
-	ctl.Wait.Wait()
-	fmt.Println(quitMessage)
+	controller := controller.NewController(newClientModel())
+	controller.Run()
 }

+ 320 - 0
src/ngrok/client/model.go

@@ -0,0 +1,320 @@
+package client
+
+import (
+	"fmt"
+	"io/ioutil"
+	"math"
+	"ngrok/client/mvc"
+	"ngrok/conn"
+	"ngrok/log"
+	"ngrok/msg"
+	"ngrok/proto"
+	"ngrok/util"
+	"ngrok/version"
+	"runtime"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	metrics "github.com/inconshreveable/go-metrics"
+)
+
+const (
+	pingInterval        = 20 * time.Second
+	maxPongLatency      = 15 * time.Second
+	updateCheckInterval = 6 * time.Hour
+	BadGateway          = `<html>
+<body style="background-color: #97a8b9">
+    <div style="margin:auto; width:400px;padding: 20px 60px; background-color: #D3D3D3; border: 5px solid maroon;">
+        <h2>Tunnel %s unavailable</h2>
+        <p>Unable to initiate connection to <strong>%s</strong>. A web server must be running on port <strong>%s</strong> to complete the tunnel.</p>
+`
+)
+
+type ClientModel struct {
+	log.Logger
+
+	id            string
+	tunnels       map[string]*mvc.Tunnel
+	serverVersion string
+	serverAddr    string
+	authToken     string
+	opts          *Options
+	metrics       *ClientMetrics
+	updateStatus  mvc.UpdateStatus
+	connStatus    mvc.ConnStatus
+	protoMap      map[string]proto.Protocol
+	ctl           mvc.Controller
+}
+
+func newClientModel() *ClientModel {
+	protoMap := make(map[string]proto.Protocol)
+	protoMap["http"] = proto.NewHttp()
+	protoMap["https"] = protoMap["http"]
+	protoMap["tcp"] = proto.NewTcp()
+
+	return &ClientModel{
+		Logger: log.NewPrefixLogger("client"),
+
+		// unique client id
+		id: util.RandIdOrPanic(8),
+
+		// connection status
+		connStatus: mvc.ConnConnecting,
+
+		// update status
+		updateStatus: mvc.UpdateNone,
+
+		// registered tunnels, keyed by public URL
+		tunnels: make(map[string]*mvc.Tunnel),
+
+		// metrics
+		metrics: NewClientMetrics(),
+
+		// protocols
+		protoMap: protoMap,
+	}
+}
+
+// mvc.State interface
+func (c ClientModel) GetClientVersion() string    { return version.MajorMinor() }
+func (c ClientModel) GetServerVersion() string    { return c.serverVersion }
+func (c ClientModel) GetConnStatus() mvc.ConnStatus     { return c.connStatus }
+func (c ClientModel) GetUpdateStatus() mvc.UpdateStatus { return c.updateStatus }
+
+func (c ClientModel) GetTunnels() []mvc.Tunnel {
+	tunnels := make([]mvc.Tunnel, 0, len(c.tunnels))
+	for _, t := range c.tunnels {
+		tunnels = append(tunnels, *t)
+	}
+	return tunnels
+}
+
+func (c ClientModel) GetConnectionMetrics() (metrics.Meter, metrics.Timer) {
+	return c.metrics.connMeter, c.metrics.connTimer
+}
+
+func (c ClientModel) GetBytesInMetrics() (metrics.Counter, metrics.Histogram) {
+	return c.metrics.bytesInCount, c.metrics.bytesIn
+}
+
+func (c ClientModel) GetBytesOutMetrics() (metrics.Counter, metrics.Histogram) {
+	return c.metrics.bytesOutCount, c.metrics.bytesOut
+}
+
+// mvc.Model interface
+func (c *ClientModel) PlayRequest(tunnel *mvc.Tunnel, payload []byte) {
+	t := c.tunnels[tunnel.PublicUrl]
+
+	var localConn conn.Conn
+	localConn, err := conn.Dial(t.LocalAddr, "prv", nil)
+	if err != nil {
+		c.Warn("Failed to open private leg to %s: %v", t.LocalAddr, err)
+		return
+	}
+	//defer localConn.Close()
+	localConn = t.Protocol.WrapConn(localConn)
+	localConn.Write(payload)
+	ioutil.ReadAll(localConn)
+}
+
+func (c *ClientModel) Shutdown(wg *sync.WaitGroup) {
+	// there's no clean shutdown needed, do it immediately
+	wg.Done()
+}
+
+func (c *ClientModel) update() {
+	c.ctl.Update(c)
+}
+
+func (c *ClientModel) Run(serverAddr, authToken string, ctl mvc.Controller, reg *msg.RegMsg) {
+	c.serverAddr = serverAddr
+	c.authToken = authToken
+	c.ctl = ctl
+	c.reconnectingControl(reg)
+}
+
+func (c *ClientModel) reconnectingControl(reg *msg.RegMsg) {
+	// how long we should wait before we reconnect
+	maxWait := 30 * time.Second
+	wait := 1 * time.Second
+
+	for {
+		c.control(reg)
+
+		if c.connStatus == mvc.ConnOnline {
+			wait = 1 * time.Second
+		}
+
+		log.Info("Waiting %d seconds before reconnecting", int(wait.Seconds()))
+		time.Sleep(wait)
+		// exponentially increase wait time
+		wait = 2 * wait
+		wait = time.Duration(math.Min(float64(wait), float64(maxWait)))
+		c.connStatus = mvc.ConnReconnecting
+		c.update()
+	}
+}
+
+// Establishes and manages a tunnel control connection with the server
+func (c *ClientModel) control(reg *msg.RegMsg) {
+	defer func() {
+		if r := recover(); r != nil {
+			log.Error("control recovering from failure %v", r)
+		}
+	}()
+
+	// establish control channel
+	conn, err := conn.Dial(c.serverAddr, "ctl", tlsConfig)
+	if err != nil {
+		panic(err)
+	}
+	defer conn.Close()
+
+	// register with the server
+	reg.OS = runtime.GOOS
+	reg.ClientId = c.id
+	reg.Version = version.Proto
+	reg.MmVersion = version.MajorMinor()
+	reg.User = c.authToken
+
+	err = msg.WriteMsg(conn, reg)
+	if err != nil {
+		panic(err)
+	}
+
+	// wait for the server to ack our register
+	var regAck msg.RegAckMsg
+	if err = msg.ReadMsgInto(conn, &regAck); err != nil {
+		panic(err)
+	}
+
+	if regAck.Error != "" {
+		emsg := fmt.Sprintf("Server failed to allocate tunnel: %s", regAck.Error)
+		c.ctl.Shutdown(emsg)
+		return
+	}
+
+	tunnel := &mvc.Tunnel{
+		PublicUrl: regAck.Url,
+		LocalAddr: localaddr,
+		Protocol:  c.protoMap[reg.Protocol],
+	}
+
+	c.tunnels[tunnel.PublicUrl] = tunnel
+
+	// update UI state
+	c.id = regAck.ClientId
+	conn.Info("Tunnel established at %v", tunnel.PublicUrl)
+	c.connStatus = mvc.ConnOnline
+	c.serverVersion = regAck.MmVersion
+	c.update()
+
+	SaveAuthToken(c.authToken)
+
+	// start the heartbeat
+	lastPong := time.Now().UnixNano()
+	c.ctl.Go(func() { c.heartbeat(&lastPong, conn) })
+
+	// main control loop
+	for {
+		var rawMsg msg.Message
+		if rawMsg, err = msg.ReadMsg(conn); err != nil {
+			panic(err)
+		}
+
+		switch m := rawMsg.(type) {
+		case *msg.ReqProxyMsg:
+			c.ctl.Go(c.proxy)
+
+		case *msg.PongMsg:
+			atomic.StoreInt64(&lastPong, time.Now().UnixNano())
+		default:
+			conn.Warn("Ignoring unknown control message %v ", m)
+		}
+	}
+}
+
+// Establishes and manages a tunnel proxy connection with the server
+func (c *ClientModel) proxy() {
+	remoteConn, err := conn.Dial(c.serverAddr, "pxy", tlsConfig)
+	if err != nil {
+		log.Error("Failed to establish proxy connection: %v", err)
+		return
+	}
+
+	defer remoteConn.Close()
+	err = msg.WriteMsg(remoteConn, &msg.RegProxyMsg{ClientId: c.id})
+	if err != nil {
+		log.Error("Failed to write RegProxyMsg: %v", err)
+		return
+	}
+
+	// wait for the server to ack our register
+	var startPxyMsg msg.StartProxyMsg
+	if err = msg.ReadMsgInto(remoteConn, &startPxyMsg); err != nil {
+		log.Error("Server failed to write StartProxyMsg: %v", err)
+		return
+	}
+
+	tunnel := c.tunnels[startPxyMsg.Url]
+	if tunnel == nil {
+		c.Error("Couldn't find tunnel for proxy: %s", startPxyMsg.Url)
+		return
+	}
+
+	// start up the private connection
+	start := time.Now()
+	localConn, err := conn.Dial(tunnel.LocalAddr, "prv", nil)
+	if err != nil {
+		remoteConn.Warn("Failed to open private leg %s: %v", tunnel.LocalAddr, err)
+		badGatewayBody := fmt.Sprintf(BadGateway, tunnel.PublicUrl, tunnel.LocalAddr, tunnel.LocalAddr)
+		remoteConn.Write([]byte(fmt.Sprintf(`HTTP/1.0 502 Bad Gateway
+Content-Type: text/html
+Content-Length: %d
+
+%s`, len(badGatewayBody), badGatewayBody)))
+		return
+	}
+	defer localConn.Close()
+
+	m := c.metrics
+	m.proxySetupTimer.Update(time.Since(start))
+	m.connMeter.Mark(1)
+	c.update()
+	m.connTimer.Time(func() {
+		localConn := tunnel.Protocol.WrapConn(localConn)
+		bytesIn, bytesOut := conn.Join(localConn, remoteConn)
+		m.bytesIn.Update(bytesIn)
+		m.bytesOut.Update(bytesOut)
+		m.bytesInCount.Inc(bytesIn)
+		m.bytesOutCount.Inc(bytesOut)
+	})
+	c.update()
+}
+
+// Heartbeating to ensure our connection to ngrokd is still alive
+func (c *ClientModel) heartbeat(lastPongAddr *int64, conn conn.Conn) {
+	lastPing := time.Unix(atomic.LoadInt64(lastPongAddr)-1, 0)
+	ping := time.NewTicker(pingInterval)
+	pongCheck := time.NewTicker(time.Second)
+
+	defer func() {
+		conn.Close()
+		ping.Stop()
+		pongCheck.Stop()
+	}()
+
+	for {
+		select {
+		case <-pongCheck.C:
+			lastPong := time.Unix(0, atomic.LoadInt64(lastPongAddr))
+			needPong := lastPong.Sub(lastPing) < 0
+			pongLatency := time.Since(lastPing)
+
+			if needPong && pongLatency > maxPongLatency {
+				c.Info("Last ping: %v, Last pong: %v", lastPing, lastPong)
+				c.Info("Connection stale, haven't gotten PongMsg in %d seconds", int(pongLatency.Seconds()))
+				return
+			}
+
+		case <-ping.C:
+			err := msg.WriteMsg(conn, &msg.PingMsg{})
+			if err != nil {
+				conn.Debug("Got error %v when writing PingMsg", err)
+				return
+			}
+			lastPing = time.Now()
+		}
+	}
+}
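
reconnectingControl above retries the control connection with exponential backoff: the wait resets to one second after a session that made it online and otherwise doubles, capped at 30 seconds. A standalone sketch of just that backoff policy, with hypothetical names:

package main

import (
	"fmt"
	"math"
	"time"
)

// nextWait resets after an online session, otherwise doubles the wait up to maxWait.
func nextWait(wait time.Duration, wasOnline bool) time.Duration {
	const maxWait = 30 * time.Second
	if wasOnline {
		return 1 * time.Second
	}
	wait = 2 * wait
	return time.Duration(math.Min(float64(wait), float64(maxWait)))
}

func main() {
	wait := 1 * time.Second
	for i := 0; i < 8; i++ {
		fmt.Println(wait) // 1s 2s 4s 8s 16s 30s 30s 30s
		wait = nextWait(wait, false)
	}
}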

+ 12 - 0
src/ngrok/client/mvc/controller.go

@@ -0,0 +1,12 @@
+package mvc
+
+type Controller interface {
+	// how the model communicates that it has changed state
+	Update(State)
+
+	// instructs the controller to shut the app down
+	Shutdown(message string)
+
+	// PlayRequest instructs the model to play requests
+	PlayRequest(tunnel *Tunnel, payload []byte)
+
+	// Go runs a function in a new goroutine, recovering from any panic
+	Go(fn func())
+}
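
Views now depend on this interface instead of the concrete ui.Controller struct removed elsewhere in this commit, so a view can request shutdown or replay without knowing anything about the model. A rough sketch of that dependency direction, re-declaring a trimmed-down copy of the interface locally so the example stands alone (consoleView and fakeCtl are hypothetical):

package main

import "fmt"

// Trimmed-down local stand-ins for the mvc types.
type Tunnel struct{ PublicUrl string }

type Controller interface {
	Shutdown(message string)
	PlayRequest(tunnel *Tunnel, payload []byte)
}

// consoleView is a hypothetical view: it only ever talks to the Controller interface.
type consoleView struct{ ctl Controller }

func (v *consoleView) onKeyCtrlC()        { v.ctl.Shutdown("interrupted") }
func (v *consoleView) onReplay(t *Tunnel) { v.ctl.PlayRequest(t, []byte("GET / HTTP/1.0\r\n\r\n")) }

// fakeCtl satisfies the interface so the example runs without the real controller.
type fakeCtl struct{}

func (fakeCtl) Shutdown(message string)            { fmt.Println("shutdown:", message) }
func (fakeCtl) PlayRequest(t *Tunnel, data []byte) { fmt.Println("replay on", t.PublicUrl) }

func main() {
	v := &consoleView{ctl: fakeCtl{}}
	v.onReplay(&Tunnel{PublicUrl: "http://example.invalid"})
	v.onKeyCtrlC()
}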

+ 13 - 0
src/ngrok/client/mvc/model.go

@@ -0,0 +1,13 @@
+package mvc
+
+import (
+	"ngrok/msg"
+	"sync"
+)
+
+type Model interface {
+	Run(serverAddr string, authToken string, ctl Controller, reg *msg.RegMsg)
+
+	Shutdown(wg *sync.WaitGroup)
+
+	PlayRequest(tunnel *Tunnel, payload []byte)
+}

+ 16 - 4
src/ngrok/client/ui/state.go → src/ngrok/client/mvc/state.go

@@ -1,4 +1,4 @@
-package ui
+package mvc
 
 import (
 	metrics "github.com/inconshreveable/go-metrics"
@@ -14,14 +14,26 @@ const (
 	UpdateError
 )
 
+type ConnStatus int
+
+const (
+	ConnConnecting ConnStatus = iota
+	ConnReconnecting
+	ConnOnline
+)
+
+type Tunnel struct {
+	PublicUrl string
+	Protocol  proto.Protocol
+	LocalAddr string
+}
+
 type State interface {
 	GetClientVersion() string
 	GetServerVersion() string
 	GetUpdate() UpdateStatus
-	GetPublicUrl() string
-	GetLocalAddr() string
+	GetTunnels() []Tunnel
 	GetStatus() string
-	GetProtocol() proto.Protocol
 	GetWebPort() int
 	GetConnectionMetrics() (metrics.Meter, metrics.Timer)
 	GetBytesInMetrics() (metrics.Counter, metrics.Histogram)

+ 9 - 0
src/ngrok/client/mvc/view.go

@@ -0,0 +1,9 @@
+package mvc
+
+import (
+	"sync"
+)
+
+type View interface {
+	Shutdown(*sync.WaitGroup)
+}

+ 0 - 44
src/ngrok/client/state.go

@@ -1,44 +0,0 @@
-package client
-
-import (
-	metrics "github.com/inconshreveable/go-metrics"
-	"ngrok/client/ui"
-	"ngrok/proto"
-	"ngrok/version"
-)
-
-// client state
-type State struct {
-	id            string
-	publicUrl     string
-	serverVersion string
-	update        ui.UpdateStatus
-	protocol      proto.Protocol
-	opts          *Options
-	metrics       *ClientMetrics
-
-	// just for UI purposes
-	status string
-}
-
-// implement client.ui.State
-func (s State) GetClientVersion() string    { return version.MajorMinor() }
-func (s State) GetServerVersion() string    { return s.serverVersion }
-func (s State) GetLocalAddr() string        { return s.opts.localaddr }
-func (s State) GetWebPort() int             { return s.opts.webport }
-func (s State) GetStatus() string           { return s.status }
-func (s State) GetProtocol() proto.Protocol { return s.protocol }
-func (s State) GetUpdate() ui.UpdateStatus  { return s.update }
-func (s State) GetPublicUrl() string        { return s.publicUrl }
-
-func (s State) GetConnectionMetrics() (metrics.Meter, metrics.Timer) {
-	return s.metrics.connMeter, s.metrics.connTimer
-}
-
-func (s State) GetBytesInMetrics() (metrics.Counter, metrics.Histogram) {
-	return s.metrics.bytesInCount, s.metrics.bytesIn
-}
-
-func (s State) GetBytesOutMetrics() (metrics.Counter, metrics.Histogram) {
-	return s.metrics.bytesOutCount, s.metrics.bytesOut
-}

+ 0 - 13
src/ngrok/client/ui/command.go

@@ -1,13 +0,0 @@
-package ui
-
-type Command interface{}
-
-type CmdQuit struct {
-	// display this message after quit
-	Message string
-}
-
-type CmdRequest struct {
-	// the bytes of the request to issue
-	Payload []byte
-}

+ 0 - 51
src/ngrok/client/ui/controller.go

@@ -1,51 +0,0 @@
-/* The controller in the MVC
- */
-
-package ui
-
-import (
-	"ngrok/util"
-	"sync"
-)
-
-type Controller struct {
-	// the model sends updates through this broadcast channel
-	Updates *util.Broadcast
-
-	// all views put any commands into this channel
-	Cmds chan Command
-
-	// all threads may add themself to this to wait for clean shutdown
-	Wait *sync.WaitGroup
-
-	// channel to signal shutdown
-	Shutdown chan int
-}
-
-func NewController() *Controller {
-	ctl := &Controller{
-		Updates:  util.NewBroadcast(),
-		Cmds:     make(chan Command),
-		Wait:     new(sync.WaitGroup),
-		Shutdown: make(chan int),
-	}
-
-	return ctl
-}
-
-func (ctl *Controller) Update(state State) {
-	ctl.Updates.In() <- state
-}
-
-func (ctl *Controller) DoShutdown() {
-	close(ctl.Shutdown)
-}
-
-func (ctl *Controller) IsShuttingDown() bool {
-	select {
-	case <-ctl.Shutdown:
-		return true
-	default:
-	}
-	return false
-}

+ 0 - 5
src/ngrok/client/ui/view.go

@@ -1,5 +0,0 @@
-package ui
-
-type View interface {
-	Render()
-}

+ 2 - 2
src/ngrok/client/update_debug.go

@@ -3,9 +3,9 @@
 package client
 
 import (
-	"ngrok/client/ui"
+	"ngrok/client/mvc"
 )
 
 // no auto-updating in debug mode
-func autoUpdate(s *State, ctl *ui.Controller, token string) {
+func autoUpdate(ctl mvc.Controller, token string) {
 }

+ 8 - 8
src/ngrok/client/update_release.go

@@ -6,9 +6,9 @@ import (
 	update "github.com/inconshreveable/go-update"
 	"net/http"
 	"net/url"
-	"ngrok/client/ui"
 	"ngrok/log"
 	"ngrok/version"
+	"ngrok/client/mvc"
 	"runtime"
 	"time"
 )
@@ -17,7 +17,7 @@ const (
 	updateEndpoint = "https://dl.ngrok.com/update"
 )
 
-func autoUpdate(s *State, ctl *ui.Controller, token string) {
+func autoUpdate(ctl mvc.Controller, token string) {
 	update := func() (updateSuccessful bool) {
 		params := make(url.Values)
 		params.Add("version", version.MajorMinor())
@@ -34,7 +34,7 @@ func autoUpdate(s *State, ctl *ui.Controller, token string) {
 						close(downloadComplete)
 						return
 					} else if progress == 100 {
-						s.update = ui.UpdateInstalling
+						s.update = mvc.UpdateInstalling
 						ctl.Update(s)
 						close(downloadComplete)
 						return
@@ -42,7 +42,7 @@ func autoUpdate(s *State, ctl *ui.Controller, token string) {
 						if progress%25 == 0 {
 							log.Info("Downloading update %d%% complete", progress)
 						}
-						s.update = ui.UpdateStatus(progress)
+						s.update = mvc.UpdateStatus(progress)
 						ctl.Update(s)
 					}
 				}
@@ -55,9 +55,9 @@ func autoUpdate(s *State, ctl *ui.Controller, token string) {
 		if err != nil {
 			log.Error("Error while updating ngrok: %v", err)
 			if download.Available {
-				s.update = ui.UpdateError
+				s.update = mvc.UpdateError
 			} else {
-				s.update = ui.UpdateNone
+				s.update = mvc.UpdateNone
 			}
 
 			// record the error to ngrok.com's servers for debugging purposes
@@ -71,11 +71,11 @@ func autoUpdate(s *State, ctl *ui.Controller, token string) {
 		} else {
 			if download.Available {
 				log.Info("Marked update ready")
-				s.update = ui.UpdateReady
+				s.update = mvc.UpdateReady
 				updateSuccessful = true
 			} else {
 				log.Info("No update available at this time")
-				s.update = ui.UpdateNone
+				s.update = mvc.UpdateNone
 			}
 		}
 

+ 16 - 12
src/ngrok/client/views/term/view.go

@@ -6,23 +6,23 @@ package term
 import (
 	"fmt"
 	termbox "github.com/nsf/termbox-go"
-	"ngrok/client/ui"
+	"ngrok/client/mvc"
 	"ngrok/log"
 	"ngrok/proto"
 	"time"
 )
 
 type TermView struct {
-	ctl      *ui.Controller
+	ctl      mvc.Controller
 	updates  chan interface{}
 	flush    chan int
-	subviews []ui.View
-	state    ui.State
+	subviews []mvc.View
+	state    mvc.State
 	log.Logger
 	*area
 }
 
-func New(ctl *ui.Controller, state ui.State) *TermView {
+func New(ctl mvc.Controller, state mvc.State) *TermView {
 	// initialize terminal display
 	termbox.Init()
 
@@ -35,7 +35,7 @@ func New(ctl *ui.Controller, state ui.State) *TermView {
 		ctl:      ctl,
 		updates:  ctl.Updates.Reg(),
 		flush:    make(chan int),
-		subviews: make([]ui.View, 0),
+		subviews: make([]mvc.View, 0),
 		state:    state,
 		Logger:   log.NewPrefixLogger(),
 		area:     NewArea(0, 0, w, 10),
@@ -81,13 +81,13 @@ func (v *TermView) Render() {
 	updateStatus := v.state.GetUpdate()
 	var updateMsg string
 	switch updateStatus {
-	case ui.UpdateNone:
+	case mvc.UpdateNone:
 		updateMsg = ""
-	case ui.UpdateInstalling:
+	case mvc.UpdateInstalling:
 		updateMsg = "ngrok is updating"
-	case ui.UpdateReady:
+	case mvc.UpdateReady:
 		updateMsg = "ngrok has updated: restart ngrok for the new version"
-	case ui.UpdateError:
+	case mvc.UpdateError:
 		updateMsg = "new version available at https://ngrok.com"
 	default:
 		pct := float64(updateStatus) / 100.0
@@ -142,7 +142,7 @@ func (v *TermView) run() {
 
 		case obj := <-v.updates:
 			if obj != nil {
-				v.state = obj.(ui.State)
+				v.state = obj.(mvc.State)
 			}
 			v.Render()
 
@@ -152,6 +152,10 @@ func (v *TermView) run() {
 	}
 }
 
+func (v *TermView) Shutdown(wg *sync.WaitGroup) {
+	wg.Done()
+}
+
 func (v *TermView) input() {
 	for {
 		ev := termbox.PollEvent()
@@ -160,7 +164,7 @@ func (v *TermView) input() {
 			switch ev.Key {
 			case termbox.KeyCtrlC:
 				v.Info("Got quit command")
-				v.ctl.Cmds <- ui.CmdQuit{}
+				v.ctl.Shutdown("")
 			}
 
 		case termbox.EventResize:

+ 2 - 3
src/ngrok/client/views/web/http.go

@@ -10,7 +10,7 @@ import (
 	"net/http/httputil"
 	"net/url"
 	"ngrok/client/assets"
-	"ngrok/client/ui"
+	"ngrok/client/mvc"
 	"ngrok/log"
 	"ngrok/proto"
 	"ngrok/util"
@@ -237,8 +237,7 @@ func (h *WebHttpView) register() {
 			h.ctl.Cmds <- ui.CmdRequest{Payload: bodyBytes}
 			w.Write([]byte(http.StatusText(200)))
 		} else {
-			// XXX: 400
-			http.NotFound(w, r)
+			http.Error(w, http.StatusText(400), 400)
 		}
 	})
 

+ 1 - 1
src/ngrok/client/views/web/view.go

@@ -6,7 +6,7 @@ import (
 	"github.com/garyburd/go-websocket/websocket"
 	"net/http"
 	"ngrok/client/assets"
-	"ngrok/client/ui"
+	"ngrok/client/mvc"
 	"ngrok/log"
 	"ngrok/proto"
 	"ngrok/util"

+ 9 - 0
src/ngrok/conn/conn.go

@@ -16,6 +16,7 @@ type Conn interface {
 	net.Conn
 	log.Logger
 	Id() string
+	SetType(string)
 }
 
 type tcpConn struct {
@@ -103,6 +104,14 @@ func (c *tcpConn) Id() string {
 	return fmt.Sprintf("%s:%x", c.typ, c.id)
 }
 
+func (c *tcpConn) SetType(typ string) {
+	oldId := c.Id()
+	c.typ = typ
+	c.ClearLogPrefixes()
+	c.AddLogPrefix(c.Id())
+	c.Info("Renamed connection %s", oldId)
+}
+
 func Join(c Conn, c2 Conn) (int64, int64) {
 	done := make(chan error)
 	pipe := func(to Conn, from Conn, bytesCopied *int64) {

+ 5 - 0
src/ngrok/log/logger.go

@@ -26,6 +26,7 @@ func LogTo(target string) {
 
 type Logger interface {
 	AddLogPrefix(string)
+	ClearLogPrefixes()
 	Debug(string, ...interface{})
 	Info(string, ...interface{})
 	Warn(string, ...interface{}) error
@@ -75,6 +76,10 @@ func (pl *PrefixLogger) AddLogPrefix(prefix string) {
 	pl.prefix += "[" + prefix + "]"
 }
 
+func (pl *PrefixLogger) ClearLogPrefixes() {
+	pl.prefix = ""
+}
+
 // we should never really use these . . . always prefer logging through a prefix logger
 func Debug(arg0 string, args ...interface{}) {
 	root.Debug(arg0, args...)

+ 6 - 2
src/ngrok/msg/msg.go

@@ -15,6 +15,7 @@ func init() {
 	TypeMap["RegAckMsg"] = t((*RegAckMsg)(nil))
 	TypeMap["RegProxyMsg"] = t((*RegProxyMsg)(nil))
 	TypeMap["ReqProxyMsg"] = t((*ReqProxyMsg)(nil))
+	TypeMap["StartProxyMsg"] = t((*StartProxyMsg)(nil))
 	TypeMap["PingMsg"] = t((*PingMsg)(nil))
 	TypeMap["PongMsg"] = t((*PongMsg)(nil))
 	TypeMap["VerisonMsg"] = t((*VersionMsg)(nil))
@@ -48,14 +49,17 @@ type RegAckMsg struct {
 	Url       string
 	ProxyAddr string
 	Error     string
+	ClientId  string
+}
+
+type ReqProxyMsg struct {
 }
 
 type RegProxyMsg struct {
-	Url      string
 	ClientId string
 }
 
-type ReqProxyMsg struct {
+type StartProxyMsg struct {
 	Url string
 }
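
These message changes flip which side binds a proxy connection to a tunnel: RegProxyMsg now carries only the client id, and the server decides which tunnel the proxy will serve by sending the new StartProxyMsg once a public connection arrives. A standalone sketch of that ordering, using JSON over net.Pipe as a stand-in for ngrok's actual msg framing:

package main

import (
	"encoding/json"
	"fmt"
	"net"
)

// Local stand-ins for the message types above.
type RegProxyMsg struct{ ClientId string }
type StartProxyMsg struct{ Url string }

func main() {
	clientSide, serverSide := net.Pipe()
	done := make(chan struct{})

	// client: register the proxy connection with only its client id,
	// then wait to be told which tunnel it should serve
	go func() {
		defer close(done)
		json.NewEncoder(clientSide).Encode(RegProxyMsg{ClientId: "abc123"})

		var start StartProxyMsg
		json.NewDecoder(clientSide).Decode(&start)
		fmt.Println("client: proxying for", start.Url)
		clientSide.Close()
	}()

	// server: read the registration, then bind the proxy to a tunnel by URL
	var reg RegProxyMsg
	json.NewDecoder(serverSide).Decode(&reg)
	json.NewEncoder(serverSide).Encode(StartProxyMsg{Url: "http://" + reg.ClientId + ".example.invalid"})

	<-done
	serverSide.Close()
}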
 

+ 109 - 3
src/ngrok/server/control.go

@@ -5,8 +5,10 @@ import (
 	"io"
 	"ngrok/conn"
 	"ngrok/msg"
+	"ngrok/util"
 	"ngrok/version"
 	"runtime/debug"
+	"sync/atomic"
 	"time"
 )
 
@@ -37,26 +39,53 @@ type Control struct {
 
 	// all of the tunnels this control connection handles
 	tunnels []*Tunnel
+
+	// proxy connections
+	proxies chan conn.Conn
+
+	// closing indicator
+	closing int32
+
+	// identifier
+	id string
 }
 
-func NewControl(conn conn.Conn, regMsg *msg.RegMsg) {
+func NewControl(ctlConn conn.Conn, regMsg *msg.RegMsg) {
 	// create the object
 	// channels are buffered because we read and write to them
 	// from the same goroutine in managerThread()
 	c := &Control{
-		conn:     conn,
+		conn:     ctlConn,
 		out:      make(chan msg.Message, 5),
 		in:       make(chan msg.Message, 5),
 		stop:     make(chan msg.Message, 5),
+		proxies:  make(chan conn.Conn, 10),
 		lastPing: time.Now(),
 	}
 
+	// assign the random id
+	serverId, err := util.RandId(8)
+	if err != nil {
+		c.stop <- &msg.RegAckMsg{Error: err.Error()}
+	}
+	c.id = fmt.Sprintf("%s-%s", regMsg.ClientId, serverId)
+
+	// register the control
+	err = controlRegistry.Add(c.id, c)
+	if err != nil {
+		c.stop <- &msg.RegAckMsg{Error: err.Error()}
+	}
+
+	// set logging prefix
+	ctlConn.SetType("ctl")
+
 	// register the first tunnel
 	c.in <- regMsg
 
 	// manage the connection
 	go c.managerThread()
 	go c.readThread()
 }
 
 // Register a new tunnel on this control connection
@@ -87,6 +116,7 @@ func (c *Control) registerTunnel(regMsg *msg.RegMsg) {
 		ProxyAddr: fmt.Sprintf("%s:%d", opts.domain, opts.tunnelPort),
 		Version:   version.Proto,
 		MmVersion: version.MajorMinor(),
+		ClientId:  c.id,
 	}
 
 	if regMsg.Protocol == "http" {
@@ -105,13 +135,32 @@ func (c *Control) managerThread() {
 			c.conn.Info("Control::managerThread failed with error %v: %s", err, debug.Stack())
 		}
 
+		// remove from the control registry
+		controlRegistry.Del(c.id)
+
+		// mark that we're shutting down
+		atomic.StoreInt32(&c.closing, 1)
+
+		// stop the reaping timer
 		reap.Stop()
+
+		// close the connection
 		c.conn.Close()
 
 		// shutdown all of the tunnels
 		for _, t := range c.tunnels {
-			t.shutdown()
+			t.Shutdown()
 		}
+
+		// we're safe to close(c.proxies) because c.closing
+		// protects us inside of RegisterProxy
+		close(c.proxies)
+
+		// shut down all of the proxy connections
+		for p := range c.proxies {
+			p.Close()
+		}
+
 	}()
 
 	for {
@@ -172,3 +221,60 @@ func (c *Control) readThread() {
 		}
 	}
 }
+
+func (c *Control) RegisterProxy(conn conn.Conn) {
+	if atomic.LoadInt32(&c.closing) == 1 {
+		c.conn.Debug("Can't register proxies for a control that is closing")
+		conn.Close()
+		return
+	}
+
+	select {
+	case c.proxies <- conn:
+		c.conn.Info("Registered proxy connection %s", conn.Id())
+	default:
+		// c.proxies buffer is full, discard this one
+		conn.Close()
+	}
+}
+
+// Remove a proxy connection from the pool and return it
+// If no proxy connections are in the pool, request one
+// and wait until it is available
+// Returns an error if we couldn't get a proxy because it took too long
+// or the tunnel is closing
+func (c *Control) GetProxy() (proxyConn conn.Conn, err error) {
+	// initial timeout is zero to try to get a proxy connection without asking for one
+	timeout := time.NewTimer(0)
+
+	// get a proxy connection. if we timeout, request one over the control channel
+	for proxyConn == nil {
+		var ok bool
+		select {
+		case proxyConn, ok = <-c.proxies:
+			if !ok {
+				err = fmt.Errorf("No proxy connections available, control is closing")
+				return
+			}
+			continue
+		case <-timeout.C:
+			c.conn.Debug("Requesting new proxy connection")
+			// request a proxy connection
+			c.out <- &msg.ReqProxyMsg{}
+			// timeout after 1 second if we don't get one
+			timeout.Reset(1 * time.Second)
+		}
+	}
+
+	// To try to reduce latency handling tunnel connections, we employ
+	// the following crude heuristic:
+	// If the proxy connection pool is empty, request a new one.
+	// The idea is to always have at least one proxy connection available for immediate use.
+	// There are two major issues with this strategy: it's not thread safe and it's not predictive.
+	// It should be a good start though.
+	if len(c.proxies) == 0 {
+		c.out <- &msg.ReqProxyMsg{}
+	}
+
+	return
+}

+ 1 - 1
src/ngrok/server/http.go

@@ -74,7 +74,7 @@ func httpHandler(tcpConn net.Conn, proto string) {
 	conn.Debug("Found hostname %s in request", host)
 
 	// multiplex to find the right backend host
-	tunnel := tunnels.Get(fmt.Sprintf("%s://%s", proto, host))
+	tunnel := tunnelRegistry.Get(fmt.Sprintf("%s://%s", proto, host))
 	if tunnel == nil {
 		conn.Info("No tunnel found for hostname %s", host)
 		conn.Write([]byte(fmt.Sprintf(NotFound, len(host)+18, host)))

+ 23 - 17
src/ngrok/server/main.go

@@ -1,18 +1,20 @@
 package server
 
 import (
-	"fmt"
+	"math/rand"
 	"net"
 	"ngrok/conn"
 	log "ngrok/log"
 	"ngrok/msg"
+	"ngrok/util"
 	"os"
 )
 
 // GLOBALS
 var (
 	opts              *Options
-	tunnels           *TunnelRegistry
+	tunnelRegistry    *TunnelRegistry
+	controlRegistry   *ControlRegistry
 	registryCacheSize uint64 = 1024 * 1024 // 1 MB
 	domain            string
 	publicPort        int
@@ -27,22 +29,18 @@ func NewProxy(pxyConn conn.Conn, regPxy *msg.RegProxyMsg) {
 		}
 	}()
 
-	// add log prefix
-	pxyConn.AddLogPrefix("pxy")
+	// set logging prefix
+	pxyConn.SetType("pxy")
 
-	// look up the tunnel for this proxy
-	pxyConn.Info("Registering new proxy for %s", regPxy.Url)
-	tunnel := tunnels.Get(regPxy.Url)
-	if tunnel == nil {
-		panic("No tunnel found for: " + regPxy.Url)
-	}
+	// look up the control connection for this proxy
+	pxyConn.Info("Registering new proxy for %s", regPxy.ClientId)
+	ctl := controlRegistry.Get(regPxy.ClientId)
 
-	if regPxy.ClientId != tunnel.regMsg.ClientId {
-		panic(fmt.Sprintf("Client identifier %s does not match tunnel's %s", regPxy.ClientId, tunnel.regMsg.ClientId))
+	if ctl == nil {
+		panic("No client found for identifier: " + regPxy.ClientId)
 	}
 
-	// register the proxy connection with the tunnel
-	tunnel.RegisterProxy(pxyConn)
+	ctl.RegisterProxy(pxyConn)
 }
 
 // Listen for incoming control and proxy connections
@@ -52,7 +50,7 @@ func NewProxy(pxyConn conn.Conn, regPxy *msg.RegProxyMsg) {
 // restrictive firewalls.
 func tunnelListener(addr *net.TCPAddr, domain string) {
 	// listen for incoming connections
-	listener, err := conn.Listen(addr, "ctl", tlsConfig)
+	listener, err := conn.Listen(addr, "tun", tlsConfig)
 	if err != nil {
 		panic(err)
 	}
@@ -82,9 +80,17 @@ func Main() {
 	// init logging
 	log.LogTo(opts.logto)
 
-	// init tunnel registry
+	// seed random number generator
+	seed, err := util.RandomSeed()
+	if err != nil {
+		panic(err)
+	}
+	rand.Seed(seed)
+
+	// init tunnel/control registry
 	registryCacheFile := os.Getenv("REGISTRY_CACHE_FILE")
-	tunnels = NewTunnelRegistry(registryCacheSize, registryCacheFile)
+	tunnelRegistry = NewTunnelRegistry(registryCacheSize, registryCacheFile)
+	controlRegistry = NewControlRegistry()
 
 	// ngrok clients
 	go tunnelListener(&net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: opts.tunnelPort}, opts.domain)

+ 45 - 1
src/ngrok/server/registry.go

@@ -32,7 +32,7 @@ func NewTunnelRegistry(cacheSize uint64, cacheFile string) *TunnelRegistry {
 	registry := &TunnelRegistry{
 		tunnels:  make(map[string]*Tunnel),
 		affinity: cache.NewLRUCache(cacheSize),
-		Logger:   log.NewPrefixLogger("registry"),
+		Logger:   log.NewPrefixLogger("registry", "tun"),
 	}
 
 	// LRUCache uses Gob encoding. Unfortunately, Gob is fickle and will fail
@@ -159,3 +159,47 @@ func (r *TunnelRegistry) Get(url string) *Tunnel {
 	defer r.RUnlock()
 	return r.tunnels[url]
 }
+
+// ControlRegistry maps a client ID to Control structures
+type ControlRegistry struct {
+	controls map[string]*Control
+	log.Logger
+	sync.RWMutex
+}
+
+func NewControlRegistry() *ControlRegistry {
+	return &ControlRegistry{
+		controls: make(map[string]*Control),
+		Logger:   log.NewPrefixLogger("registry", "ctl"),
+	}
+}
+
+func (r *ControlRegistry) Get(clientId string) *Control {
+	r.RLock()
+	defer r.RUnlock()
+	return r.controls[clientId]
+}
+
+func (r *ControlRegistry) Add(clientId string, ctl *Control) error {
+	r.Lock()
+	defer r.Unlock()
+	if r.controls[clientId] == nil {
+		r.Info("Registered control with id %s", clientId)
+		r.controls[clientId] = ctl
+		return nil
+	} else {
+		return fmt.Errorf("Client with id %s already registered!", clientId)
+	}
+}
+
+func (r *ControlRegistry) Del(clientId string) error {
+	r.Lock()
+	defer r.Unlock()
+	if r.controls[clientId] == nil {
+		return fmt.Errorf("No control found for client id: %s", clientId)
+	} else {
+		r.Info("Removed control registry id %s", clientId)
+		delete(r.controls, clientId)
+		return nil
+	}
+}

+ 22 - 64
src/ngrok/server/tunnel.go

@@ -41,9 +41,6 @@ type Tunnel struct {
 	// control connection
 	ctl *Control
 
-	// proxy connections
-	proxies chan conn.Conn
-
 	// logger
 	log.Logger
 
@@ -73,18 +70,18 @@ func registerVhost(t *Tunnel, protocol string, servingPort int) (err error) {
 	hostname := strings.TrimSpace(t.regMsg.Hostname)
 	if hostname != "" {
 		t.url = fmt.Sprintf("%s://%s", protocol, hostname)
-		return tunnels.Register(t.url, t)
+		return tunnelRegistry.Register(t.url, t)
 	}
 
 	// Register for specific subdomain
 	subdomain := strings.TrimSpace(t.regMsg.Subdomain)
 	if subdomain != "" {
 		t.url = fmt.Sprintf("%s://%s.%s", protocol, subdomain, vhost)
-		return tunnels.Register(t.url, t)
+		return tunnelRegistry.Register(t.url, t)
 	}
 
 	// Register for random URL
-	t.url, err = tunnels.RegisterRepeat(func() string {
+	t.url, err = tunnelRegistry.RegisterRepeat(func() string {
 		return fmt.Sprintf("%s://%x.%s", protocol, rand.Int31(), vhost)
 	}, t)
 
@@ -95,11 +92,10 @@ func registerVhost(t *Tunnel, protocol string, servingPort int) (err error) {
 // on a control channel
 func NewTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel, err error) {
 	t = &Tunnel{
-		regMsg:  m,
-		start:   time.Now(),
-		ctl:     ctl,
-		proxies: make(chan conn.Conn, 10),
-		Logger:  log.NewPrefixLogger(),
+		regMsg: m,
+		start:  time.Now(),
+		ctl:    ctl,
+		Logger: log.NewPrefixLogger(),
 	}
 
 	switch t.regMsg.Protocol {
@@ -107,7 +103,7 @@ func NewTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel, err error) {
 		var port int = 0
 
 		// try to return to you the same port you had before
-		cachedUrl := tunnels.GetCachedRegistration(t)
+		cachedUrl := tunnelRegistry.GetCachedRegistration(t)
 		if cachedUrl != "" {
 			parts := strings.Split(cachedUrl, ":")
 			portPart := parts[len(parts)-1]
@@ -139,7 +135,7 @@ func NewTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel, err error) {
 		t.url = fmt.Sprintf("tcp://%s:%d", domain, addr.Port)
 
 		// register it
-		if err = tunnels.RegisterAndCache(t.url, t); err != nil {
+		if err = tunnelRegistry.RegisterAndCache(t.url, t); err != nil {
 			// This should never be possible because the OS will
 			// only assign available ports to us.
 			t.listener.Close()
@@ -177,7 +173,7 @@ func NewTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel, err error) {
 	return
 }
 
-func (t *Tunnel) shutdown() {
+func (t *Tunnel) Shutdown() {
 	t.Info("Shutting down")
 
 	// mark that we're shutting down
@@ -189,22 +185,13 @@ func (t *Tunnel) shutdown() {
 	}
 
 	// remove ourselves from the tunnel registry
-	tunnels.Del(t.url)
+	tunnelRegistry.Del(t.url)
 
 	// let the control connection know we're shutting down
 	// currently, only the control connection shuts down tunnels,
 	// so it doesn't need to know about it
 	// t.ctl.stoptunnel <- t
 
-	// we're safe to close(t.proxies) because t.closing
-	// protects us inside of RegisterProxy
-	close(t.proxies)
-
-	// shut down all of the proxy connections
-	for c := range t.proxies {
-		c.Close()
-	}
-
 	metrics.CloseTunnel(t)
 }
 
@@ -253,49 +240,20 @@ func (t *Tunnel) HandlePublicConnection(publicConn conn.Conn) {
 	startTime := time.Now()
 	metrics.OpenConnection(t, publicConn)
 
-	// initial timeout is zero to try to get a proxy connection without asking for one
-	timeout := time.NewTimer(0)
-	var proxyConn conn.Conn
-
-	// get a proxy connection. if we timeout, request one over the control channel
-	for proxyConn == nil {
-		var ok bool
-		select {
-		case proxyConn, ok = <-t.proxies:
-			if !ok {
-				publicConn.Info("Dropping connection because tunnel is shutting down")
-				return
-			}
-			continue
-		case <-timeout.C:
-			t.Debug("Requesting new proxy connection")
-			// request a proxy connection
-			t.ctl.out <- &msg.ReqProxyMsg{Url: t.url}
-			// timeout after 1 second if we don't get one
-			timeout.Reset(1 * time.Second)
-		}
+	// get a proxy connection
+	proxyConn, err := t.ctl.GetProxy()
+	if err != nil {
+		t.Warn("Failed to get proxy connection: %v", err)
+		return
 	}
+	defer proxyConn.Close()
 	t.Info("Got proxy connection %s", proxyConn.Id())
+	proxyConn.AddLogPrefix(t.Id())
 
-	defer proxyConn.Close()
-	bytesIn, bytesOut := conn.Join(publicConn, proxyConn)
+	// tell the client we're going to start using this proxy connection
+	msg.WriteMsg(proxyConn, &msg.StartProxyMsg{Url: t.url})
 
+	// join the public and proxy connections
+	bytesIn, bytesOut := conn.Join(publicConn, proxyConn)
 	metrics.CloseConnection(t, publicConn, startTime, bytesIn, bytesOut)
 }
-
-func (t *Tunnel) RegisterProxy(conn conn.Conn) {
-	if atomic.LoadInt32(&t.closing) == 1 {
-		t.Debug("Can't register proxies for a tunnel that is closing")
-		conn.Close()
-		return
-	}
-
-	t.Info("Registered proxy connection %s", conn.Id())
-	conn.AddLogPrefix(t.Id())
-	select {
-	case t.proxies <- conn:
-	default:
-		// t.proxies buffer is full, discard this one
-		conn.Close()
-	}
-}

+ 20 - 0
src/ngrok/util/id.go

@@ -5,6 +5,26 @@ import (
 	"fmt"
 )
 
+func RandomSeed() (int64, error) {
+	b := make([]byte, 8)
+	n, err := rand.Read(b)
+	if err != nil {
+		return 0, err
+	}
+
+	if n != 8 {
+		return 0, fmt.Errorf("Only generated %d random bytes, %d requested", n, 8)
+	}
+
+	var seed int64
+	var i uint
+	for i = 0; i < 8; i++ {
+		seed |= int64(b[i]) << (i * 8)
+	}
+
+	return seed, nil
+}
+
 // create a random identifier for this client
 func RandId(idlen int) (id string, err error) {
 	b := make([]byte, idlen)
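
For reference, the byte-assembly loop in RandomSeed above can also be written with encoding/binary. A standalone sketch (not a drop-in for ngrok/util) that derives the same kind of seed from crypto/rand:

package main

import (
	"crypto/rand"
	"encoding/binary"
	"fmt"
)

// randomSeed reads 8 bytes from the OS CSPRNG and interprets them as a
// little-endian 64-bit integer, matching what the shift-and-or loop computes.
func randomSeed() (int64, error) {
	b := make([]byte, 8)
	if _, err := rand.Read(b); err != nil {
		return 0, err
	}
	return int64(binary.LittleEndian.Uint64(b)), nil
}

func main() {
	seed, err := randomSeed()
	if err != nil {
		panic(err)
	}
	fmt.Println("seed:", seed)
}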