Browse Source

use vhost library to multiplex connections. prevent debug builds from inlining for easier debugging

Alan Shreve 11 years ago
parent
commit
5b14f7b832
3 changed files with 35 additions and 59 deletions
  1. 2 2
      Makefile
  2. 6 35
      src/ngrok/conn/conn.go
  3. 27 22
      src/ngrok/server/http.go

+ 2 - 2
Makefile

@@ -7,13 +7,13 @@ deps:
 	go get -tags '$(BUILDTAGS)' -d -v ngrok/...
 	go get -tags '$(BUILDTAGS)' -d -v ngrok/...
 
 
 server: deps
 server: deps
-	go install -tags '$(BUILDTAGS)' ngrok/main/ngrokd
+	go install -gcflags "-N -l" -tags '$(BUILDTAGS)' ngrok/main/ngrokd
 
 
 fmt:
 fmt:
 	go fmt ngrok/...
 	go fmt ngrok/...
 
 
 client: deps
 client: deps
-	go install -tags '$(BUILDTAGS)' ngrok/main/ngrok
+	go install -gcflags "-N -l" -tags '$(BUILDTAGS)' ngrok/main/ngrok
 
 
 client-assets:
 client-assets:
 	go get github.com/inconshreveable/go-bindata
 	go get github.com/inconshreveable/go-bindata

+ 6 - 35
src/ngrok/conn/conn.go

@@ -2,10 +2,10 @@ package conn
 
 
 import (
 import (
 	"bufio"
 	"bufio"
-	"bytes"
 	"crypto/tls"
 	"crypto/tls"
 	"encoding/base64"
 	"encoding/base64"
 	"fmt"
 	"fmt"
+	vhost "github.com/inconshreveable/go-vhost"
 	"io"
 	"io"
 	"math/rand"
 	"math/rand"
 	"net"
 	"net"
@@ -33,11 +33,14 @@ type loggedConn struct {
 
 
 type Listener struct {
 type Listener struct {
 	net.Addr
 	net.Addr
-	Conns chan Conn
+	Conns chan *loggedConn
 }
 }
 
 
 func wrapConn(conn net.Conn, typ string) *loggedConn {
 func wrapConn(conn net.Conn, typ string) *loggedConn {
 	switch c := conn.(type) {
 	switch c := conn.(type) {
+	case *vhost.HTTPConn:
+		wrapped := c.Conn.(*loggedConn)
+		return &loggedConn{wrapped.tcp, conn, wrapped.Logger, wrapped.id, wrapped.typ}
 	case *loggedConn:
 	case *loggedConn:
 		return c
 		return c
 	case *net.TCPConn:
 	case *net.TCPConn:
@@ -58,7 +61,7 @@ func Listen(addr, typ string, tlsCfg *tls.Config) (l *Listener, err error) {
 
 
 	l = &Listener{
 	l = &Listener{
 		Addr:  listener.Addr(),
 		Addr:  listener.Addr(),
-		Conns: make(chan Conn),
+		Conns: make(chan *loggedConn),
 	}
 	}
 
 
 	go func() {
 	go func() {
@@ -214,35 +217,3 @@ func Join(c Conn, c2 Conn) (int64, int64) {
 	wait.Wait()
 	wait.Wait()
 	return fromBytes, toBytes
 	return fromBytes, toBytes
 }
 }
-
-type httpConn struct {
-	*loggedConn
-	reqBuf *bytes.Buffer
-}
-
-func NewHttp(conn net.Conn, typ string) *httpConn {
-	return &httpConn{
-		wrapConn(conn, typ),
-		bytes.NewBuffer(make([]byte, 0, 1024)),
-	}
-}
-
-func (c *httpConn) ReadRequest() (*http.Request, error) {
-	r := io.TeeReader(c.loggedConn, c.reqBuf)
-	return http.ReadRequest(bufio.NewReader(r))
-}
-
-func (c *loggedConn) ReadFrom(r io.Reader) (n int64, err error) {
-	// special case when we're reading from an http request where
-	// we had to parse the request and consume bytes from the socket
-	// and store them in a temporary request buffer
-	if httpConn, ok := r.(*httpConn); ok {
-		if n, err = httpConn.reqBuf.WriteTo(c); err != nil {
-			return
-		}
-	}
-
-	nCopied, err := io.Copy(c.Conn, r)
-	n += nCopied
-	return
-}

+ 27 - 22
src/ngrok/server/http.go

@@ -3,7 +3,8 @@ package server
 import (
 import (
 	"crypto/tls"
 	"crypto/tls"
 	"fmt"
 	"fmt"
-	"net"
+	vhost "github.com/inconshreveable/go-vhost"
+	//"net"
 	"ngrok/conn"
 	"ngrok/conn"
 	"ngrok/log"
 	"ngrok/log"
 	"strings"
 	"strings"
@@ -55,53 +56,57 @@ func startHttpListener(addr string, tlsCfg *tls.Config) (listener *conn.Listener
 }
 }
 
 
 // Handles a new http connection from the public internet
 // Handles a new http connection from the public internet
-func httpHandler(tcpConn net.Conn, proto string) {
-	// wrap up the connection for logging
-	conn := conn.NewHttp(tcpConn, "pub")
-
-	defer conn.Close()
+func httpHandler(c conn.Conn, proto string) {
+	defer c.Close()
 	defer func() {
 	defer func() {
 		// recover from failures
 		// recover from failures
 		if r := recover(); r != nil {
 		if r := recover(); r != nil {
-			conn.Warn("httpHandler failed with error %v", r)
+			c.Warn("httpHandler failed with error %v", r)
 		}
 		}
 	}()
 	}()
 
 
 	// Make sure we detect dead connections while we decide how to multiplex
 	// Make sure we detect dead connections while we decide how to multiplex
-	conn.SetDeadline(time.Now().Add(connReadTimeout))
+	c.SetDeadline(time.Now().Add(connReadTimeout))
 
 
-	// read out the http request
-	req, err := conn.ReadRequest()
+	// multiplex by extracting the Host header, the vhost library
+	vhostConn, err := vhost.HTTP(c)
 	if err != nil {
 	if err != nil {
-		conn.Warn("Failed to read valid %s request: %v", proto, err)
-		conn.Write([]byte(BadRequest))
+		c.Warn("Failed to read valid %s request: %v", proto, err)
+		c.Write([]byte(BadRequest))
 		return
 		return
 	}
 	}
 
 
-	// read out the Host header from the request
-	host := strings.ToLower(req.Host)
-	conn.Debug("Found hostname %s in request", host)
+	// read out the Host header and auth from the request
+	host := strings.ToLower(vhostConn.Host())
+	auth := vhostConn.Request.Header.Get("Autorization")
+
+	// done reading mux data, free up the request memory
+	vhostConn.Free()
+
+	// We need to read from the vhost conn now since it mucked around reading the stream
+	c = conn.Wrap(vhostConn, "pub")
 
 
 	// multiplex to find the right backend host
 	// multiplex to find the right backend host
+	c.Debug("Found hostname %s in request", host)
 	tunnel := tunnelRegistry.Get(fmt.Sprintf("%s://%s", proto, host))
 	tunnel := tunnelRegistry.Get(fmt.Sprintf("%s://%s", proto, host))
 	if tunnel == nil {
 	if tunnel == nil {
-		conn.Info("No tunnel found for hostname %s", host)
-		conn.Write([]byte(fmt.Sprintf(NotFound, len(host)+18, host)))
+		c.Info("No tunnel found for hostname %s", host)
+		c.Write([]byte(fmt.Sprintf(NotFound, len(host)+18, host)))
 		return
 		return
 	}
 	}
 
 
 	// If the client specified http auth and it doesn't match this request's auth
 	// If the client specified http auth and it doesn't match this request's auth
 	// then fail the request with 401 Not Authorized and request the client reissue the
 	// then fail the request with 401 Not Authorized and request the client reissue the
 	// request with basic authdeny the request
 	// request with basic authdeny the request
-	if tunnel.req.HttpAuth != "" && req.Header.Get("Authorization") != tunnel.req.HttpAuth {
-		conn.Info("Authentication failed: %s", req.Header.Get("Authorization"))
-		conn.Write([]byte(NotAuthorized))
+	if tunnel.req.HttpAuth != "" && auth != tunnel.req.HttpAuth {
+		c.Info("Authentication failed: %s", auth)
+		c.Write([]byte(NotAuthorized))
 		return
 		return
 	}
 	}
 
 
 	// dead connections will now be handled by tunnel heartbeating and the client
 	// dead connections will now be handled by tunnel heartbeating and the client
-	conn.SetDeadline(time.Time{})
+	c.SetDeadline(time.Time{})
 
 
 	// let the tunnel handle the connection now
 	// let the tunnel handle the connection now
-	tunnel.HandlePublicConnection(conn)
+	tunnel.HandlePublicConnection(c)
 }
 }