Browse Source

Major ngrokd server refactor improving:
No longer fail to handle a public connection if a proxy request goes unanswered by requesting additional proxy connections
Buffer tunnel proxies channel so we allow a pool of available connections to be used
Don't tie up goroutines if proxy connection pool are full
Allow a single control connection to manage multiple tunnels
Fix a security issue where anyone could could register a proxy for any tunnel
Natively handle serving HTTPS traffic
Multiplex a single port for handling control and proxy connections in order to evade corporate firewalls
Tunnels now shut down all of their resources gracefully

Alan Shreve 12 years ago
parent
commit
3d17d52659

+ 8 - 14
src/ngrok/client/main.go

@@ -33,25 +33,17 @@ const (
 /**
  * Establishes and manages a tunnel proxy connection with the server
  */
-func proxy(proxyAddr string, s *State, ctl *ui.Controller) {
+func proxy(proxyAddr string, url string, s *State, ctl *ui.Controller) {
 	start := time.Now()
 	remoteConn, err := conn.Dial(proxyAddr, "pxy", tlsConfig)
 	if err != nil {
-		// XXX: What is the proper response here?
-		// display something to the user?
-		// retry?
-		// reset control connection?
 		log.Error("Failed to establish proxy connection: %v", err)
 		return
 	}
 
 	defer remoteConn.Close()
-	err = msg.WriteMsg(remoteConn, &msg.RegProxyMsg{Url: s.publicUrl})
+	err = msg.WriteMsg(remoteConn, &msg.RegProxyMsg{Url: url, ClientId: s.id})
 	if err != nil {
-		// XXX: What is the proper response here?
-		// display something to the user?
-		// retry?
-		// reset control connection?
 		log.Error("Failed to write RegProxyMsg: %v", err)
 		return
 	}
@@ -205,17 +197,19 @@ func control(s *State, ctl *ui.Controller) {
 
 	// main control loop
 	for {
-		var m msg.Message
-		if m, err = msg.ReadMsg(conn); err != nil {
+		var rawMsg msg.Message
+		if rawMsg, err = msg.ReadMsg(conn); err != nil {
 			panic(err)
 		}
 
-		switch m.(type) {
+		switch m := rawMsg.(type) {
 		case *msg.ReqProxyMsg:
-			go proxy(regAck.ProxyAddr, s, ctl)
+			go proxy(regAck.ProxyAddr, m.Url, s, ctl)
 
 		case *msg.PongMsg:
 			atomic.StoreInt64(&lastPong, time.Now().UnixNano())
+		default:
+			conn.Warn("Ignoring unknown control message %v ", m)
 		}
 	}
 }

+ 3 - 1
src/ngrok/msg/msg.go

@@ -51,10 +51,12 @@ type RegAckMsg struct {
 }
 
 type RegProxyMsg struct {
-	Url string
+	Url      string
+	ClientId string
 }
 
 type ReqProxyMsg struct {
+	Url string
 }
 
 type PingMsg struct {

+ 34 - 0
src/ngrok/server/cli.go

@@ -0,0 +1,34 @@
+package server
+
+import (
+	"flag"
+)
+
+type Options struct {
+	httpPort   int
+	httpsPort  int
+	tunnelPort int
+	domain     string
+	logto      string
+}
+
+func parseArgs() *Options {
+	httpPort := flag.Int("httpPort", 80, "Public HTTP port, -1 to disable")
+	httpsPort := flag.Int("httpsPort", 443, "Public HTTPS port, -1 to disable")
+	tunnelPort := flag.Int("tunnelPort", 4443, "Port to which ngrok clients connect")
+	domain := flag.String("domain", "ngrok.com", "Domain where the tunnels are hosted")
+	logto := flag.String(
+		"log",
+		"stdout",
+		"Write log messages to this file. 'stdout' and 'none' have special meanings")
+
+	flag.Parse()
+
+	return &Options{
+		httpPort:   *httpPort,
+		httpsPort:  *httpsPort,
+		tunnelPort: *tunnelPort,
+		domain:     *domain,
+		logto:      *logto,
+	}
+}

+ 74 - 21
src/ngrok/server/control.go

@@ -1,6 +1,7 @@
 package server
 
 import (
+	"fmt"
 	"io"
 	"ngrok/conn"
 	"ngrok/msg"
@@ -18,31 +19,83 @@ type Control struct {
 	// actual connection
 	conn conn.Conn
 
-	// channels for communicating messages over the connection
-	out  chan (interface{})
-	in   chan (msg.Message)
+	// put a message in this channel to send it over
+	// conn to the client
+	out chan (msg.Message)
+
+	// read from this channel to get the next message sent
+	// to us over conn by the client
+	in chan (msg.Message)
+
+	// put a message in this channel to send it over
+	// conn to the client and then terminate this
+	// control connection and all of its tunnels
 	stop chan (msg.Message)
 
-	// heartbeat
+	// the last time we received a ping from the client - for heartbeats
 	lastPing time.Time
 
-	// tunnel
-	tun *Tunnel
+	// all of the tunnels this control connection handles
+	tunnels []*Tunnel
 }
 
-func NewControl(conn conn.Conn) {
+func NewControl(conn conn.Conn, regMsg *msg.RegMsg) {
+	// create the object
+	// channels are buffered because we read and write to them
+	// from the same goroutine in managerThread()
 	c := &Control{
 		conn:     conn,
-		out:      make(chan (interface{}), 1),
-		in:       make(chan (msg.Message), 1),
-		stop:     make(chan (msg.Message), 1),
+		out:      make(chan msg.Message, 5),
+		in:       make(chan msg.Message, 5),
+		stop:     make(chan msg.Message, 5),
 		lastPing: time.Now(),
 	}
 
+	// register the first tunnel
+	c.in <- regMsg
+
+	// manage the connection
 	go c.managerThread()
 	go c.readThread()
 }
 
+// Register a new tunnel on this control connection
+func (c *Control) registerTunnel(regMsg *msg.RegMsg) {
+	c.conn.Debug("Registering new tunnel")
+	t, err := NewTunnel(regMsg, c)
+	if err != nil {
+		ack := &msg.RegAckMsg{Error: err.Error()}
+		if len(c.tunnels) == 0 {
+			// you can't fail your first tunnel registration
+			// terminate the control connection
+			c.stop <- ack
+		} else {
+			// inform client of failure
+			c.out <- ack
+		}
+
+		// we're done
+		return
+	}
+
+	// add it to the list of tunnels
+	c.tunnels = append(c.tunnels, t)
+
+	// acknowledge success
+	c.out <- &msg.RegAckMsg{
+		Url:       t.url,
+		ProxyAddr: fmt.Sprintf("%s:%d", opts.domain, opts.tunnelPort),
+		Version:   version.Proto,
+		MmVersion: version.MajorMinor(),
+	}
+
+	if regMsg.Protocol == "http" {
+		httpsRegMsg := *regMsg
+		httpsRegMsg.Protocol = "https"
+		c.in <- &httpsRegMsg
+	}
+}
+
 func (c *Control) managerThread() {
 	reap := time.NewTicker(connReapInterval)
 
@@ -51,12 +104,13 @@ func (c *Control) managerThread() {
 		if err := recover(); err != nil {
 			c.conn.Info("Control::managerThread failed with error %v: %s", err, debug.Stack())
 		}
+
 		reap.Stop()
 		c.conn.Close()
 
-		// shutdown the tunnel if it's open
-		if c.tun != nil {
-			c.tun.shutdown()
+		// shutdown all of the tunnels
+		for _, t := range c.tunnels {
+			t.shutdown()
 		}
 	}()
 
@@ -65,23 +119,22 @@ func (c *Control) managerThread() {
 		case m := <-c.out:
 			msg.WriteMsg(c.conn, m)
 
-		case <-reap.C:
-			if time.Since(c.lastPing) > pingTimeoutInterval {
-				c.conn.Info("Lost heartbeat")
-				return
-			}
-
 		case m := <-c.stop:
 			if m != nil {
 				msg.WriteMsg(c.conn, m)
 			}
 			return
 
+		case <-reap.C:
+			if time.Since(c.lastPing) > pingTimeoutInterval {
+				c.conn.Info("Lost heartbeat")
+				return
+			}
+
 		case mRaw := <-c.in:
 			switch m := mRaw.(type) {
 			case *msg.RegMsg:
-				c.conn.Info("Registering new tunnel")
-				c.tun = newTunnel(m, c)
+				c.registerTunnel(m)
 
 			case *msg.PingMsg:
 				c.lastPing = time.Now()

+ 15 - 13
src/ngrok/server/http.go

@@ -1,6 +1,7 @@
 package server
 
 import (
+	"crypto/tls"
 	"fmt"
 	"net"
 	"ngrok/conn"
@@ -28,26 +29,27 @@ Bad Request
 `
 )
 
-/**
- * Listens for new http connections from the public internet
- */
-func httpListener(addr *net.TCPAddr) {
+// Listens for new http(s) connections from the public internet
+func httpListener(addr *net.TCPAddr, tlsCfg *tls.Config) {
 	// bind/listen for incoming connections
-	listener, err := conn.Listen(addr, "pub", nil)
+	listener, err := conn.Listen(addr, "pub", tlsCfg)
 	if err != nil {
 		panic(err)
 	}
 
-	log.Info("Listening for public http connections on %v", listener.Port)
+	proto := "http"
+	if tlsCfg != nil {
+		proto = "https"
+	}
+
+	log.Info("Listening for public %s connections on %v", proto, listener.Port)
 	for conn := range listener.Conns {
-		go httpHandler(conn)
+		go httpHandler(conn, proto)
 	}
 }
 
-/**
- * Handles a new http connection from the public internet
- */
-func httpHandler(tcpConn net.Conn) {
+// Handles a new http connection from the public internet
+func httpHandler(tcpConn net.Conn, proto string) {
 	// wrap up the connection for logging
 	conn := conn.NewHttp(tcpConn, "pub")
 
@@ -62,7 +64,7 @@ func httpHandler(tcpConn net.Conn) {
 	// read out the http request
 	req, err := conn.ReadRequest()
 	if err != nil {
-		conn.Warn("Failed to read valid http request: %v", err)
+		conn.Warn("Failed to read valid %s request: %v", proto, err)
 		conn.Write([]byte(BadRequest))
 		return
 	}
@@ -72,7 +74,7 @@ func httpHandler(tcpConn net.Conn) {
 	conn.Debug("Found hostname %s in request", host)
 
 	// multiplex to find the right backend host
-	tunnel := tunnels.Get("http://" + host)
+	tunnel := tunnels.Get(fmt.Sprintf("%s://%s", proto, host))
 	if tunnel == nil {
 		conn.Info("No tunnel found for hostname %s", host)
 		conn.Write([]byte(fmt.Sprintf(NotFound, len(host)+18, host)))

+ 59 - 82
src/ngrok/server/main.go

@@ -1,7 +1,6 @@
 package server
 
 import (
-	"flag"
 	"fmt"
 	"net"
 	"ngrok/conn"
@@ -10,106 +9,75 @@ import (
 	"os"
 )
 
-type Options struct {
-	publicPort int
-	proxyPort  int
-	tunnelPort int
-	domain     string
-	logto      string
-}
-
-/* GLOBALS */
+// GLOBALS
 var (
-	proxyAddr         string
+	opts              *Options
 	tunnels           *TunnelRegistry
 	registryCacheSize uint64 = 1024 * 1024 // 1 MB
 	domain            string
 	publicPort        int
 )
 
-func parseArgs() *Options {
-	publicPort := flag.Int("publicport", 80, "Public port")
-	tunnelPort := flag.Int("tunnelport", 4443, "Tunnel port")
-	proxyPort := flag.Int("proxyPort", 0, "Proxy port")
-	domain := flag.String("domain", "ngrok.com", "Domain where the tunnels are hosted")
-	logto := flag.String(
-		"log",
-		"stdout",
-		"Write log messages to this file. 'stdout' and 'none' have special meanings")
-
-	flag.Parse()
-
-	return &Options{
-		publicPort: *publicPort,
-		tunnelPort: *tunnelPort,
-		proxyPort:  *proxyPort,
-		domain:     *domain,
-		logto:      *logto,
+func NewProxy(pxyConn conn.Conn, regPxy *msg.RegProxyMsg) {
+	// fail gracefully if the proxy connection fails to register
+	defer func() {
+		if r := recover(); r != nil {
+			pxyConn.Warn("Failed with error: %v", r)
+			pxyConn.Close()
+		}
+	}()
+
+	// add log prefix
+	pxyConn.AddLogPrefix("pxy")
+
+	// look up the tunnel for this proxy
+	pxyConn.Info("Registering new proxy for %s", regPxy.Url)
+	tunnel := tunnels.Get(regPxy.Url)
+	if tunnel == nil {
+		panic("No tunnel found for: " + regPxy.Url)
 	}
-}
 
-/**
- * Listens for new control connections from tunnel clients
- */
-func controlListener(addr *net.TCPAddr, domain string) {
-	// listen for incoming connections
-	listener, err := conn.Listen(addr, "ctl", tlsConfig)
-	if err != nil {
-		panic(err)
+	if regPxy.ClientId != tunnel.regMsg.ClientId {
+		panic(fmt.Sprintf("Client identifier %s does not match tunnel's %s", regPxy.ClientId, tunnel.regMsg.ClientId))
 	}
 
-	log.Info("Listening for control connections on %d", listener.Port)
-	for c := range listener.Conns {
-		NewControl(c)
-	}
+	// register the proxy connection with the tunnel
+	tunnel.RegisterProxy(pxyConn)
 }
 
-/**
- * Listens for new proxy connections from tunnel clients
- */
-func proxyListener(addr *net.TCPAddr, domain string) {
-	listener, err := conn.Listen(addr, "pxy", tlsConfig)
+// Listen for incoming control and proxy connections
+// We listen for incoming control and proxy connections on the same port
+// for ease of deployment. The hope is that by running on port 443, using
+// TLS and running all connections over the same port, we can bust through
+// restrictive firewalls.
+func tunnelListener(addr *net.TCPAddr, domain string) {
+	// listen for incoming connections
+	listener, err := conn.Listen(addr, "ctl", tlsConfig)
 	if err != nil {
 		panic(err)
 	}
 
-	// set global proxy addr variable
-	proxyAddr = fmt.Sprintf("%s:%d", domain, listener.Port)
-	log.Info("Listening for proxy connection on %d", listener.Port)
-	for proxyConn := range listener.Conns {
-		go func(conn conn.Conn) {
-			// fail gracefully if the proxy connection dies
-			defer func() {
-				if r := recover(); r != nil {
-					conn.Warn("Failed with error: %v", r)
-					conn.Close()
-				}
-			}()
-
-			// read the proxy register message
-			var regPxy msg.RegProxyMsg
-			if err = msg.ReadMsgInto(conn, &regPxy); err != nil {
-				panic(err)
-			}
-
-			// look up the tunnel for this proxy
-			conn.Info("Registering new proxy for %s", regPxy.Url)
-			tunnel := tunnels.Get(regPxy.Url)
-			if tunnel == nil {
-				panic("No tunnel found for: " + regPxy.Url)
-			}
-
-			// register the proxy connection with the tunnel
-			tunnel.RegisterProxy(conn)
-		}(proxyConn)
+	log.Info("Listening for control and proxy connections on %d", listener.Port)
+	for c := range listener.Conns {
+		var rawMsg msg.Message
+		if rawMsg, err = msg.ReadMsg(c); err != nil {
+			c.Error("Failed to read message: %v", err)
+			c.Close()
+		}
+
+		switch m := rawMsg.(type) {
+		case *msg.RegMsg:
+			go NewControl(c, m)
+
+		case *msg.RegProxyMsg:
+			go NewProxy(c, m)
+		}
 	}
 }
 
 func Main() {
 	// parse options
-	opts := parseArgs()
-	domain = opts.domain
-	publicPort = opts.publicPort
+	opts = parseArgs()
 
 	// init logging
 	log.LogTo(opts.logto)
@@ -118,9 +86,18 @@ func Main() {
 	registryCacheFile := os.Getenv("REGISTRY_CACHE_FILE")
 	tunnels = NewTunnelRegistry(registryCacheSize, registryCacheFile)
 
-	go proxyListener(&net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: opts.proxyPort}, opts.domain)
-	go controlListener(&net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: opts.tunnelPort}, opts.domain)
-	go httpListener(&net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: opts.publicPort})
+	// ngrok clients
+	go tunnelListener(&net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: opts.tunnelPort}, opts.domain)
+
+	// listen for http
+	if opts.httpPort != -1 {
+		go httpListener(&net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: opts.httpPort}, nil)
+	}
+
+	// listen for https
+	if opts.httpsPort != -1 {
+		go httpListener(&net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: opts.httpsPort}, tlsConfig)
+	}
 
 	// wait forever
 	done := make(chan int)

+ 1 - 2
src/ngrok/server/registry.go

@@ -42,14 +42,13 @@ func NewTunnelRegistry(cacheSize uint64, cacheFile string) *TunnelRegistry {
 	var urlobj cacheUrl
 	gob.Register(urlobj)
 
+	// try to load and then periodically save the affinity cache to file, if specified
 	if cacheFile != "" {
-		// load cache entries from file
 		err := registry.affinity.LoadItemsFromFile(cacheFile)
 		if err != nil {
 			registry.Error("Failed to load affinity cache %s: %v", cacheFile, err)
 		}
 
-		// save cache periodically to file
 		registry.SaveCacheThread(cacheFile, cacheSaveInterval)
 	} else {
 		registry.Info("No affinity cache specified")

+ 109 - 58
src/ngrok/server/tunnel.go

@@ -16,6 +16,12 @@ import (
 	"time"
 )
 
+var defaultPortMap = map[string]int{
+	"http":  80,
+	"https": 443,
+	"smtp":  25,
+}
+
 /**
  * Tunnel: A control connection, metadata and proxy connections which
  *         route public traffic to a firewalled endpoint.
@@ -45,21 +51,57 @@ type Tunnel struct {
 	closing int32
 }
 
-func newTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel) {
+// Common functionality for registering virtually hosted protocols
+func registerVhost(t *Tunnel, protocol string, servingPort int) (err error) {
+	vhost := os.Getenv("VHOST")
+	if vhost == "" {
+		vhost = fmt.Sprintf("%s:%d", opts.domain, servingPort)
+	}
+
+	// Canonicalize virtual host by removing default port (e.g. :80 on HTTP)
+	defaultPort, ok := defaultPortMap[protocol]
+	if !ok {
+		return fmt.Errorf("Couldn't find default port for protocol %s", protocol)
+	}
+
+	defaultPortSuffix := fmt.Sprintf(":%d", defaultPort)
+	if strings.HasSuffix(vhost, defaultPortSuffix) {
+		vhost = vhost[0 : len(vhost)-len(defaultPortSuffix)]
+	}
+
+	// Register for specific hostname
+	hostname := strings.TrimSpace(t.regMsg.Hostname)
+	if hostname != "" {
+		t.url = fmt.Sprintf("%s://%s", protocol, hostname)
+		return tunnels.Register(t.url, t)
+	}
+
+	// Register for specific subdomain
+	subdomain := strings.TrimSpace(t.regMsg.Subdomain)
+	if subdomain != "" {
+		t.url = fmt.Sprintf("%s://%s.%s", protocol, subdomain, vhost)
+		return tunnels.Register(t.url, t)
+	}
+
+	// Register for random URL
+	t.url, err = tunnels.RegisterRepeat(func() string {
+		return fmt.Sprintf("%s://%x.%s", protocol, rand.Int31(), vhost)
+	}, t)
+
+	return
+}
+
+// Create a new tunnel from aregistration message received
+// on a control channel
+func NewTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel, err error) {
 	t = &Tunnel{
 		regMsg:  m,
 		start:   time.Now(),
 		ctl:     ctl,
-		proxies: make(chan conn.Conn),
+		proxies: make(chan conn.Conn, 10),
 		Logger:  log.NewPrefixLogger(),
 	}
 
-	failReg := func(err error) {
-		t.ctl.stop <- &msg.RegAckMsg{Error: err.Error()}
-	}
-
-	var err error
-
 	switch t.regMsg.Protocol {
 	case "tcp":
 		var port int = 0
@@ -88,7 +130,7 @@ func newTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel) {
 
 		// we tried to bind with a random port and failed (no more ports available?)
 		if err != nil {
-			failReg(t.ctl.conn.Error("Error binding TCP listener: %v", err))
+			err = t.ctl.conn.Error("Error binding TCP listener: %v", err)
 			return
 		}
 
@@ -101,48 +143,25 @@ func newTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel) {
 			// This should never be possible because the OS will
 			// only assign available ports to us.
 			t.listener.Close()
-			failReg(fmt.Errorf("TCP listener bound, but failed to register %s", t.url))
+			err = fmt.Errorf("TCP listener bound, but failed to register %s", t.url)
 			return
 		}
 
 		go t.listenTcp(t.listener)
 
 	case "http":
-		vhost := os.Getenv("VHOST")
-		if vhost == "" {
-			vhost = fmt.Sprintf("%s:%d", domain, publicPort)
-		}
-
-		// Canonicalize virtual host on default port 80
-		if strings.HasSuffix(vhost, ":80") {
-			vhost = vhost[0 : len(vhost)-3]
-		}
-
-		if strings.TrimSpace(t.regMsg.Hostname) != "" {
-			t.url = fmt.Sprintf("http://%s", t.regMsg.Hostname)
-		} else if strings.TrimSpace(t.regMsg.Subdomain) != "" {
-			t.url = fmt.Sprintf("http://%s.%s", t.regMsg.Subdomain, vhost)
+		if err = registerVhost(t, "http", opts.httpPort); err != nil {
+			return
 		}
 
-		if t.url != "" {
-			if err := tunnels.Register(t.url, t); err != nil {
-				failReg(err)
-				return
-			}
-		} else {
-			t.url, err = tunnels.RegisterRepeat(func() string {
-				return fmt.Sprintf("http://%x.%s", rand.Int31(), vhost)
-			}, t)
-
-			if err != nil {
-				failReg(err)
-				return
-			}
+	case "https":
+		if err = registerVhost(t, "https", opts.httpsPort); err != nil {
+			return
 		}
 	}
 
 	if m.Version != version.Proto {
-		failReg(fmt.Errorf("Incompatible versions. Server %s, client %s. Download a new version at http://ngrok.com", version.MajorMinor(), m.Version))
+		err = fmt.Errorf("Incompatible versions. Server %s, client %s. Download a new version at http://ngrok.com", version.MajorMinor(), m.Version)
 		return
 	}
 
@@ -151,15 +170,8 @@ func newTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel) {
 		m.HttpAuth = "Basic " + base64.StdEncoding.EncodeToString([]byte(m.HttpAuth))
 	}
 
-	t.ctl.conn.AddLogPrefix(t.Id())
 	t.AddLogPrefix(t.Id())
-	t.Info("Registered new tunnel")
-	t.ctl.out <- &msg.RegAckMsg{
-		Url:       t.url,
-		ProxyAddr: fmt.Sprintf("%s", proxyAddr),
-		Version:   version.Proto,
-		MmVersion: version.MajorMinor(),
-	}
+	t.Info("Registered new tunnel on: %s", t.ctl.conn.Id())
 
 	metrics.OpenTunnel(t)
 	return
@@ -171,7 +183,7 @@ func (t *Tunnel) shutdown() {
 	// mark that we're shutting down
 	atomic.StoreInt32(&t.closing, 1)
 
-	// if we have a public listener (this is a raw TCP tunnel, shut it down
+	// if we have a public listener (this is a raw TCP tunnel), shut it down
 	if t.listener != nil {
 		t.listener.Close()
 	}
@@ -179,7 +191,19 @@ func (t *Tunnel) shutdown() {
 	// remove ourselves from the tunnel registry
 	tunnels.Del(t.url)
 
-	// XXX: shut down all of the proxy connections?
+	// let the control connection know we're shutting down
+	// currently, only the control connection shuts down tunnels,
+	// so it doesn't need to know about it
+	// t.ctl.stoptunnel <- t
+
+	// we're safe to close(t.proxies) because t.closing
+	// protects us inside of RegisterProxy
+	close(t.proxies)
+
+	// shut down all of the proxy connections
+	for c := range t.proxies {
+		c.Close()
+	}
 
 	metrics.CloseTunnel(t)
 }
@@ -188,9 +212,7 @@ func (t *Tunnel) Id() string {
 	return t.url
 }
 
-/**
- * Listens for new public tcp connections from the internet.
- */
+// Listens for new public tcp connections from the internet.
 func (t *Tunnel) listenTcp(listener *net.TCPListener) {
 	for {
 		defer func() {
@@ -231,11 +253,29 @@ func (t *Tunnel) HandlePublicConnection(publicConn conn.Conn) {
 	startTime := time.Now()
 	metrics.OpenConnection(t, publicConn)
 
-	t.Debug("Requesting new proxy connection")
-	t.ctl.out <- &msg.ReqProxyMsg{}
-
-	proxyConn := <-t.proxies
-	t.Info("Returning proxy connection %s", proxyConn.Id())
+	// initial timeout is zero to try to get a proxy connection without asking for one
+	timeout := time.NewTimer(0)
+	var proxyConn conn.Conn
+
+	// get a proxy connection. if we timeout, request one over the control channel
+	for proxyConn == nil {
+		var ok bool
+		select {
+		case proxyConn, ok = <-t.proxies:
+			if !ok {
+				publicConn.Info("Dropping connection because tunnel is shutting down")
+				return
+			}
+			continue
+		case <-timeout.C:
+			t.Debug("Requesting new proxy connection")
+			// request a proxy connection
+			t.ctl.out <- &msg.ReqProxyMsg{Url: t.url}
+			// timeout after 1 second if we don't get one
+			timeout.Reset(1 * time.Second)
+		}
+	}
+	t.Info("Got proxy connection %s", proxyConn.Id())
 
 	defer proxyConn.Close()
 	bytesIn, bytesOut := conn.Join(publicConn, proxyConn)
@@ -244,7 +284,18 @@ func (t *Tunnel) HandlePublicConnection(publicConn conn.Conn) {
 }
 
 func (t *Tunnel) RegisterProxy(conn conn.Conn) {
+	if atomic.LoadInt32(&t.closing) == 1 {
+		t.Debug("Can't register proxies for a tunnel that is closing")
+		conn.Close()
+		return
+	}
+
 	t.Info("Registered proxy connection %s", conn.Id())
 	conn.AddLogPrefix(t.Id())
-	t.proxies <- conn
+	select {
+	case t.proxies <- conn:
+	default:
+		// t.proxies buffer is full, discard this one
+		conn.Close()
+	}
 }