Browse Source

make sure every new tunnel connection and every new http connection has a read/write timeouts until handled by a tunnel. fix a bug where a slow tunnel or proxy connection could block all others. fix a bug where sending a non-expected valid ngrok protocol message over a new connection could leak a connection

Alan Shreve 12 years ago
parent
commit
13be54d4e7
2 changed files with 35 additions and 14 deletions
  1. 8 0
      src/ngrok/server/http.go
  2. 27 14
      src/ngrok/server/main.go

+ 8 - 0
src/ngrok/server/http.go

@@ -7,6 +7,7 @@ import (
 	"ngrok/conn"
 	"ngrok/log"
 	"strings"
+	"time"
 )
 
 const (
@@ -66,6 +67,9 @@ func httpHandler(tcpConn net.Conn, proto string) {
 		}
 	}()
 
+	// Make sure we detect dead connections while we decide how to multiplex
+	conn.SetDeadline(time.Now().Add(connReadTimeout))
+
 	// read out the http request
 	req, err := conn.ReadRequest()
 	if err != nil {
@@ -95,5 +99,9 @@ func httpHandler(tcpConn net.Conn, proto string) {
 		return
 	}
 
+	// dead connections will now be handled by tunnel heartbeating and the client
+	conn.SetDeadline(time.Time{})
+
+	// let the tunnel handle the connection now
 	tunnel.HandlePublicConnection(conn)
 }

+ 27 - 14
src/ngrok/server/main.go

@@ -7,10 +7,12 @@ import (
 	"ngrok/msg"
 	"ngrok/util"
 	"os"
+	"time"
 )
 
 const (
-	registryCacheSize uint64 = 1024 * 1024 // 1 MB
+	registryCacheSize uint64        = 1024 * 1024 // 1 MB
+	connReadTimeout   time.Duration = 10 * time.Second
 )
 
 // GLOBALS
@@ -60,19 +62,30 @@ func tunnelListener(addr string) {
 
 	log.Info("Listening for control and proxy connections on %s", listener.Addr.String())
 	for c := range listener.Conns {
-		var rawMsg msg.Message
-		if rawMsg, err = msg.ReadMsg(c); err != nil {
-			c.Error("Failed to read message: %v", err)
-			c.Close()
-		}
-
-		switch m := rawMsg.(type) {
-		case *msg.Auth:
-			go NewControl(c, m)
-
-		case *msg.RegProxy:
-			go NewProxy(c, m)
-		}
+		go func(tunnelConn conn.Conn) {
+			tunnelConn.SetReadDeadline(time.Now().Add(connReadTimeout))
+			var rawMsg msg.Message
+			if rawMsg, err = msg.ReadMsg(tunnelConn); err != nil {
+				tunnelConn.Error("Failed to read message: %v", err)
+				tunnelConn.Close()
+				return
+			}
+
+			// don't timeout after the initial read, tunnel heartbeating will kill
+			// dead connections
+			tunnelConn.SetReadDeadline(time.Time{})
+
+			switch m := rawMsg.(type) {
+			case *msg.Auth:
+				NewControl(tunnelConn, m)
+
+			case *msg.RegProxy:
+				NewProxy(tunnelConn, m)
+
+			default:
+				tunnelConn.Close()
+			}
+		}(c)
 	}
 }