
Fix a critical bug that could cause tunnels to freeze indefinitely and never clean up their resources. Also fix a bug where a client reconnecting over an intermittent connection would be told that its tunnels were already registered by another user.

Alan Shreve · 12 years ago · commit 46dd79a949

+ 135 - 87
src/ngrok/server/control.go

@@ -9,13 +9,13 @@ import (
 	"ngrok/version"
 	"runtime/debug"
 	"strings"
-	"sync/atomic"
 	"time"
 )
 
 const (
 	pingTimeoutInterval = 30 * time.Second
 	connReapInterval    = 10 * time.Second
+	controlWriteTimeout = 10 * time.Second
 )
 
 type Control struct {
@@ -33,11 +33,6 @@ type Control struct {
 	// to us over conn by the client
 	in chan (msg.Message)
 
-	// put a message in this channel to send it over
-	// conn to the client and then terminate this
-	// control connection and all of its tunnels
-	stop chan (msg.Message)
-
 	// the last time we received a ping from the client - for heartbeats
 	lastPing time.Time
 
@@ -47,27 +42,34 @@ type Control struct {
 	// proxy connections
 	proxies chan conn.Conn
 
-	// closing indicator
-	closing int32
-
 	// identifier
 	id string
+
+	// synchronizer for controlled shutdown of writer()
+	writerShutdown *util.Shutdown
+
+	// synchronizer for controlled shutdown of manager()
+	managerShutdown *util.Shutdown
+
+	// synchronizer for controlled shutdown of the entire Control
+	shutdown *util.Shutdown
+
 }
 
 func NewControl(ctlConn conn.Conn, authMsg *msg.Auth) {
 	var err error
 
 	// create the object
-	// channels are buffered because we read and write to them
-	// from the same goroutine in managerThread()
 	c := &Control{
 		auth:     authMsg,
 		conn:     ctlConn,
-		out:      make(chan msg.Message, 5),
-		in:       make(chan msg.Message, 5),
-		stop:     make(chan msg.Message, 5),
+		out:      make(chan msg.Message),
+		in:       make(chan msg.Message),
 		proxies:  make(chan conn.Conn, 10),
 		lastPing: time.Now(),
+		writerShutdown: util.NewShutdown(),
+		managerShutdown: util.NewShutdown(),
+		shutdown: util.NewShutdown(),
 	}
 
 	failAuth := func(e error) {
@@ -91,8 +93,20 @@ func NewControl(ctlConn conn.Conn, authMsg *msg.Auth) {
 	}
 
 	// register the control
-	controlRegistry.Add(c.id, c)
+	if replaced := controlRegistry.Add(c.id, c); replaced != nil {
+		replaced.shutdown.WaitComplete()
+	}
 
+	// set logging prefix
+	ctlConn.SetType("ctl")
+
+	// manage the connection
+	go c.manager()
+	go c.reader()
+	go c.writer()
+	go c.stopper()
+
+	// Respond to authentication
 	c.out <- &msg.AuthResp{
 		Version:   version.Proto,
 		MmVersion: version.MajorMinor(),
@@ -101,13 +115,6 @@ func NewControl(ctlConn conn.Conn, authMsg *msg.Auth) {
 
 	// As a performance optimization, ask for a proxy connection up front
 	c.out <- &msg.ReqProxy{}
-
-	// set logging prefix
-	ctlConn.SetType("ctl")
-
-	// manage the connection
-	go c.managerThread()
-	go c.readThread()
 }
 
 // Register a new tunnel on this control connection
@@ -119,14 +126,9 @@ func (c *Control) registerTunnel(rawTunnelReq *msg.ReqTunnel) {
 		c.conn.Debug("Registering new tunnel")
 		t, err := NewTunnel(&tunnelReq, c)
 		if err != nil {
-			ack := &msg.NewTunnel{Error: err.Error()}
+			c.out <- &msg.NewTunnel{Error: err.Error()}
 			if len(c.tunnels) == 0 {
-				// you can't fail your first tunnel registration
-				// terminate the control connection
-				c.stop <- ack
-			} else {
-				// inform client of failure
-				c.out <- ack
+				c.shutdown.Begin()
 			}
 
 			// we're done
@@ -147,61 +149,36 @@ func (c *Control) registerTunnel(rawTunnelReq *msg.ReqTunnel) {
 	}
 }
 
-func (c *Control) managerThread() {
-	reap := time.NewTicker(connReapInterval)
-
-	// all shutdown functionality in here
+func (c *Control) manager() {
+	// don't crash on panics
 	defer func() {
 		if err := recover(); err != nil {
-			c.conn.Info("Control::managerThread failed with error %v: %s", err, debug.Stack())
-		}
-
-		// remove from the control registry
-		controlRegistry.Del(c.id)
-
-		// mark that we're shutting down
-		atomic.StoreInt32(&c.closing, 1)
-
-		// stop the reaping timer
-		reap.Stop()
-
-		// close the connection
-		c.conn.Close()
-
-		// shutdown all of the tunnels
-		for _, t := range c.tunnels {
-			t.Shutdown()
+			c.conn.Info("Control::manager failed with error %v: %s", err, debug.Stack())
 		}
+	}()
 
-		// we're safe to close(c.proxies) because c.closing
-		// protects us inside of RegisterProxy
-		close(c.proxies)
+	// kill everything if the control manager stops
+	defer c.shutdown.Begin()
 
-		// shut down all of the proxy connections
-		for p := range c.proxies {
-			p.Close()
-		}
+	// notify that manager() has shutdown
+	defer c.managerShutdown.Complete()
 
-	}()
+	// reaping timer for detecting heartbeat failure
+	reap := time.NewTicker(connReapInterval)
+	defer reap.Stop()
 
 	for {
 		select {
-		case m := <-c.out:
-			msg.WriteMsg(c.conn, m)
-
-		case m := <-c.stop:
-			if m != nil {
-				msg.WriteMsg(c.conn, m)
-			}
-			return
-
 		case <-reap.C:
 			if time.Since(c.lastPing) > pingTimeoutInterval {
 				c.conn.Info("Lost heartbeat")
-				return
+				c.shutdown.Begin()
 			}
 
-		case mRaw := <-c.in:
+		case mRaw, ok := <-c.in:
+			// c.in closes to indicate shutdown
+			if !ok {
+				return
+			}
+
 			switch m := mRaw.(type) {
 			case *msg.ReqTunnel:
 				c.registerTunnel(m)
@@ -214,14 +191,39 @@ func (c *Control) managerThread() {
 	}
 }
 
-func (c *Control) readThread() {
+func (c *Control) writer() {
+	defer func() {
+		if err := recover(); err != nil {
+			c.conn.Info("Control::writer failed with error %v: %s", err, debug.Stack())
+		}
+	}()
+
+	// kill everything if the writer() stops
+	defer c.shutdown.Begin()
+
+	// notify that we've flushed all messages
+	defer c.writerShutdown.Complete()
+
+	// write messages to the control channel
+	for m := range c.out {
+		c.conn.SetWriteDeadline(time.Now().Add(controlWriteTimeout))
+		if err := msg.WriteMsg(c.conn, m); err != nil {
+			panic(err)
+		}
+	}
+}
+
+func (c *Control) reader() {
 	defer func() {
 		if err := recover(); err != nil {
-			c.conn.Info("Control::readThread failed with error %v: %s", err, debug.Stack())
+			c.conn.Info("Control::reader failed with error %v: %s", err, debug.Stack())
 		}
-		c.stop <- nil
 	}()
 
+	// kill everything if the reader stops
+	defer c.shutdown.Begin()
+
 	// read messages from the control channel
 	for {
 		if msg, err := msg.ReadMsg(c.conn); err != nil {
@@ -232,18 +234,53 @@ func (c *Control) readThread() {
 				panic(err)
 			}
 		} else {
+			// this can also panic during shutdown
 			c.in <- msg
 		}
 	}
 }
 
-func (c *Control) RegisterProxy(conn conn.Conn) {
-	if atomic.LoadInt32(&c.closing) == 1 {
-		c.conn.Debug("Can't register proxies for a control that is closing")
-		conn.Close()
-		return
+func (c *Control) stopper() {
+	defer func() {
+		if r := recover(); r != nil {
+			c.conn.Error("Failed to shut down control: %v", r)
+		}
+	}()
+
+	// wait until we're instructed to shutdown
+	c.shutdown.WaitBegin()
+
+	// remove ourself from the control registry
+	controlRegistry.Del(c.id)
+
+	// shutdown manager() so that we have no more work to do
+	close(c.in)
+	c.managerShutdown.WaitComplete()
+
+	// shutdown writer()
+	close(c.out)
+	c.writerShutdown.WaitComplete()
+
+	// close the connection
+	// XXX: this will kill reader() ungracefully
+	c.conn.Close()
+
+	// shutdown all of the tunnels
+	for _, t := range c.tunnels {
+		t.Shutdown()
 	}
 
+	// shutdown all of the proxy connections
+	close(c.proxies)
+	for p := range c.proxies {
+		p.Close()
+	}
+
+	c.shutdown.Complete()
+	c.conn.Info("Shutdown complete")
+}
+
+func (c *Control) RegisterProxy(conn conn.Conn) {
 	select {
 	case c.proxies <- conn:
 		c.conn.Info("Registered proxy connection %s", conn.Id())
@@ -275,7 +312,10 @@ func (c *Control) GetProxy() (proxyConn conn.Conn, err error) {
 		case <-timeout.C:
 			c.conn.Debug("Requesting new proxy connection")
 			// request a proxy connection
-			c.out <- &msg.ReqProxy{}
+			if err = util.PanicToError(func() { c.out <- &msg.ReqProxy{} }); err != nil {
+				return
+			}
+
 			// timeout after 1 second if we don't get one
 			timeout.Reset(1 * time.Second)
 		}
@@ -283,13 +323,21 @@ func (c *Control) GetProxy() (proxyConn conn.Conn, err error) {
 
 	// To try to reduce latency handling tunnel connections, we employ
 	// the following crude heuristic:
-	// If the proxy connection pool is empty, request a new one.
-	// The idea is to always have at least one proxy connection available for immediate use.
-	// There are two major issues with this strategy: it's not thread safe and it's not predictive.
-	// It should be a good start though.
-	if len(c.proxies) == 0 {
-		c.out <- &msg.ReqProxy{}
-	}
-
+	// Whenever we take a proxy connection from the pool, replace it with a new one
+	err = util.PanicToError(func() { c.out <- &msg.ReqProxy{} })
 	return
 }
+
+// Called when this control is replaced by another control
+// this can happen if the network drops out and the client reconnects
+// before the old tunnel has lost its heartbeat
+func (c *Control) Replaced(replacement *Control) {
+	c.conn.Info("Replaced by control: %s", replacement.conn.Id())
+
+	// set the control id to empty string so that when stopper()
+	// calls registry.Del it won't delete the replacement
+	c.id = ""
+
+	// tell the old one to shutdown
+	c.shutdown.Begin()
+}
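
The ordering inside stopper() is the heart of the freeze fix and is worth spelling out: c.in is closed and manager() is drained before c.out is closed, so manager() can never send on a closed channel, and writer() gets to flush every pending message before the connection dies. Below is a minimal, stand-alone sketch of that drain ordering; plain done channels stand in for util.Shutdown and all names are illustrative, not taken from the commit.

package main

import "fmt"

func main() {
	in := make(chan string)
	out := make(chan string)
	managerDone := make(chan struct{})
	writerDone := make(chan struct{})

	// manager: turn inbound requests into outbound replies until in closes
	go func() {
		defer close(managerDone)
		for m := range in {
			out <- "ack:" + m
		}
	}()

	// writer: flush every queued outbound message until out closes
	go func() {
		defer close(writerDone)
		for m := range out {
			fmt.Println("wrote", m)
		}
	}()

	in <- "ReqTunnel"

	// the shutdown sequence stopper() follows:
	close(in)     // 1. no more inbound work
	<-managerDone // 2. manager can no longer send on out
	close(out)    // 3. now safe to close the outbound channel
	<-writerDone  // 4. writer has flushed everything
	fmt.Println("safe to close the connection")
}

If the two close calls were swapped, the manager goroutine could panic sending on an already-closed channel; the strict ordering, together with the new controlWriteTimeout write deadline in writer(), is what keeps shutdown from hanging.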

+ 1 - 1
src/ngrok/server/main.go

@@ -67,7 +67,7 @@ func tunnelListener(addr string, tlsConfig *tls.Config) {
 			tunnelConn.SetReadDeadline(time.Now().Add(connReadTimeout))
 			var rawMsg msg.Message
 			if rawMsg, err = msg.ReadMsg(tunnelConn); err != nil {
-				tunnelConn.Error("Failed to read message: %v", err)
+				tunnelConn.Warn("Failed to read message: %v", err)
 				tunnelConn.Close()
 				return
 			}

+ 8 - 1
src/ngrok/server/registry.go

@@ -180,11 +180,18 @@ func (r *ControlRegistry) Get(clientId string) *Control {
 	return r.controls[clientId]
 }
 
-func (r *ControlRegistry) Add(clientId string, ctl *Control) {
+func (r *ControlRegistry) Add(clientId string, ctl *Control) (oldCtl *Control) {
 	r.Lock()
 	defer r.Unlock()
+
+	oldCtl = r.controls[clientId]
+	if oldCtl != nil {
+		oldCtl.Replaced(ctl)
+	}
+
 	r.controls[clientId] = ctl
 	r.Info("Registered control with id %s", clientId)
+	return
 }
 
 func (r *ControlRegistry) Del(clientId string) error {
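
For the reconnect fix, the important part is that Add now returns the control it displaced and notifies it via Replaced(). A toy, self-contained sketch of that flow follows; every name in it (control, registry, the shutdown channel) is a hypothetical stand-in, not ngrok's actual types.

package main

import (
	"fmt"
	"sync"
)

type control struct {
	id       string
	shutdown chan struct{} // closed when this control has fully shut down
}

type registry struct {
	sync.Mutex
	controls map[string]*control
}

// Add registers ctl under id and returns the control it replaced, if any.
func (r *registry) Add(id string, ctl *control) (old *control) {
	r.Lock()
	defer r.Unlock()
	old = r.controls[id]
	r.controls[id] = ctl
	return
}

func main() {
	r := &registry{controls: make(map[string]*control)}

	oldCtl := &control{id: "client-1", shutdown: make(chan struct{})}
	r.Add("client-1", oldCtl)

	// the same client reconnects before the old control's heartbeat expires
	newCtl := &control{id: "client-1", shutdown: make(chan struct{})}
	if replaced := r.Add("client-1", newCtl); replaced != nil {
		go close(replaced.shutdown) // stand-in for Replaced() triggering stopper()
		<-replaced.shutdown         // stand-in for shutdown.WaitComplete()
	}
	fmt.Println("reconnect succeeded for", newCtl.id)
}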

+ 14 - 0
src/ngrok/util/trace.go → src/ngrok/util/errors.go

@@ -14,8 +14,22 @@ panic: %v
 
 %s`
 
+
 func MakePanicTrace(err interface{}) string {
 	stackBuf := make([]byte, 4096)
 	n := runtime.Stack(stackBuf, false)
 	return fmt.Sprintf(crashMessage, err, stackBuf[:n])
 }
+
+// Runs the given function and converts any panic encountered while doing so
+// into an error. Useful for sending to channels that will close
+func PanicToError(fn func()) (err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			err = fmt.Errorf("Panic: %v", r)
+		}
+	}()
+	fn()
+	return
+}
+
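
One subtlety behind the new helper: GetProxy may still try to send msg.ReqProxy on c.out after stopper() has closed it, and a send on a closed channel panics. Wrapping the send in PanicToError turns that panic into an ordinary error the caller can return. A small self-contained illustration (the helper is copied inline so the snippet runs on its own):

package main

import "fmt"

// PanicToError runs fn and converts any panic into an error.
func PanicToError(fn func()) (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("Panic: %v", r)
		}
	}()
	fn()
	return
}

func main() {
	out := make(chan string)
	close(out) // simulates a control connection that has already shut down

	// a bare `out <- "ReqProxy"` would crash this goroutine; wrapped, it returns an error
	if err := PanicToError(func() { out <- "ReqProxy" }); err != nil {
		fmt.Println("send failed:", err)
	}
}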

+ 43 - 0
src/ngrok/util/shutdown.go

@@ -0,0 +1,43 @@
+package util
+
+import (
+	"sync"
+)
+
+// A small utility type for managing controlled shutdowns
+type Shutdown struct {
+	sync.Mutex
+	inProgress bool
+	begin chan int    // closed when the shutdown begins
+	complete chan int // closed when the shutdown completes
+}
+
+func NewShutdown() *Shutdown {
+	return &Shutdown{
+		begin: make(chan int),
+		complete: make(chan int),
+	}
+}
+
+func (s *Shutdown) Begin() {
+	s.Lock()
+	defer s.Unlock()
+	if s.inProgress {
+		return
+	}
+	s.inProgress = true
+	close(s.begin)
+}
+
+func (s *Shutdown) WaitBegin() {
+	<-s.begin
+}
+
+func (s *Shutdown) Complete() {
+	close(s.complete)
+}
+
+func (s *Shutdown) WaitComplete() {
+	<-s.complete
+}
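
For reference, here is a hypothetical, minimal usage sketch of the new Shutdown type (it assumes the file is built inside the ngrok GOPATH so the "ngrok/util" import resolves): one goroutine waits for Begin(), does its cleanup, and announces Complete(); the caller requests the shutdown and blocks until that cleanup is done.

package main

import (
	"fmt"
	"ngrok/util"
)

func main() {
	shutdown := util.NewShutdown()

	go func() {
		// worker: wait until someone requests shutdown, then clean up
		// and announce that cleanup is finished
		shutdown.WaitBegin()
		fmt.Println("worker: cleaning up")
		shutdown.Complete()
	}()

	shutdown.Begin()        // request the shutdown (the inProgress guard makes repeat calls safe)
	shutdown.WaitComplete() // block until the worker has finished
	fmt.Println("shutdown complete")
}

This is the same wiring control.go now uses: reader(), writer() and manager() each call c.shutdown.Begin() when they exit, and stopper() is the single goroutine that waits for Begin, performs the teardown in order, and finally calls Complete() so a replacement control can safely take over.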