Browse Source

allow ngrokd to bind to specific interfaces

Alan Shreve 12 years ago
parent
commit
ca72fcb05c
5 changed files with 76 additions and 72 deletions
  1. 26 34
      src/ngrok/conn/conn.go
  2. 9 9
      src/ngrok/server/cli.go
  3. 11 7
      src/ngrok/server/http.go
  4. 18 16
      src/ngrok/server/main.go
  5. 12 6
      src/ngrok/server/tunnel.go

+ 26 - 34
src/ngrok/conn/conn.go

@@ -19,7 +19,7 @@ type Conn interface {
 	SetType(string)
 }
 
-type tcpConn struct {
+type loggedConn struct {
 	net.Conn
 	log.Logger
 	id  int32
@@ -27,84 +27,76 @@ type tcpConn struct {
 }
 
 type Listener struct {
-	*net.TCPAddr
+	net.Addr
 	Conns chan Conn
 }
 
-func wrapTcpConn(conn net.Conn, typ string) *tcpConn {
-	c := &tcpConn{conn, log.NewPrefixLogger(), rand.Int31(), typ}
+func wrapConn(conn net.Conn, typ string) *loggedConn {
+	c := &loggedConn{conn, log.NewPrefixLogger(), rand.Int31(), typ}
 	c.AddLogPrefix(c.Id())
 	return c
 }
 
-func Listen(addr *net.TCPAddr, typ string, tlsCfg *tls.Config) (l *Listener, err error) {
+func Listen(addr, typ string, tlsCfg *tls.Config) (l *Listener, err error) {
 	// listen for incoming connections
-	listener, err := net.ListenTCP("tcp", addr)
+	listener, err := net.Listen("tcp", addr)
 	if err != nil {
 		return
 	}
 
 	l = &Listener{
-		TCPAddr: listener.Addr().(*net.TCPAddr),
+		Addr: listener.Addr(),
 		Conns:   make(chan Conn),
 	}
 
 	go func() {
 		for {
-			tcpConn, err := listener.AcceptTCP()
+			rawConn, err := listener.Accept()
 			if err != nil {
 				log.Error("Failed to accept new TCP connection of type %s: %v", typ, err)
 				continue
 			}
 
-			c := wrapTcpConn(tcpConn, typ)
+			c := wrapConn(rawConn, typ)
 			if tlsCfg != nil {
 				c.Conn = tls.Server(c.Conn, tlsCfg)
 			}
-			c.Info("New connection from %v", tcpConn.RemoteAddr())
+			c.Info("New connection from %v", c.RemoteAddr())
 			l.Conns <- c
 		}
 	}()
 	return
 }
 
-func Wrap(conn net.Conn, typ string) *tcpConn {
-	return wrapTcpConn(conn, typ)
+func Wrap(conn net.Conn, typ string) *loggedConn {
+	return wrapConn(conn, typ)
 }
 
-func Dial(addr, typ string, tlsCfg *tls.Config) (conn *tcpConn, err error) {
-	var (
-		tcpAddr *net.TCPAddr
-		tcpConn net.Conn
-	)
-
-	if tcpAddr, err = net.ResolveTCPAddr("tcp", addr); err != nil {
-		return
-	}
-
-	if tcpConn, err = net.DialTCP("tcp", nil, tcpAddr); err != nil {
+func Dial(addr, typ string, tlsCfg *tls.Config) (conn *loggedConn, err error) {
+	var rawConn net.Conn
+	if rawConn, err = net.Dial("tcp", addr); err != nil {
 		return
 	}
 
 	if tlsCfg != nil {
-		tcpConn = tls.Client(tcpConn, tlsCfg)
+		rawConn = tls.Client(rawConn, tlsCfg)
 	}
 
-	conn = wrapTcpConn(tcpConn, typ)
-	conn.Debug("New connection to: %v", tcpAddr)
-	return conn, nil
+	conn = wrapConn(rawConn, typ)
+	conn.Debug("New connection to: %v", rawConn.RemoteAddr())
+	return
 }
 
-func (c *tcpConn) Close() error {
+func (c *loggedConn) Close() error {
 	c.Debug("Closing")
 	return c.Conn.Close()
 }
 
-func (c *tcpConn) Id() string {
+func (c *loggedConn) Id() string {
 	return fmt.Sprintf("%s:%x", c.typ, c.id)
 }
 
-func (c *tcpConn) SetType(typ string) {
+func (c *loggedConn) SetType(typ string) {
 	oldId := c.Id()
 	c.typ = typ
 	c.ClearLogPrefixes()
@@ -138,23 +130,23 @@ func Join(c Conn, c2 Conn) (int64, int64) {
 }
 
 type httpConn struct {
-	*tcpConn
+	*loggedConn
 	reqBuf *bytes.Buffer
 }
 
 func NewHttp(conn net.Conn, typ string) *httpConn {
 	return &httpConn{
-		wrapTcpConn(conn, typ),
+		wrapConn(conn, typ),
 		bytes.NewBuffer(make([]byte, 0, 1024)),
 	}
 }
 
 func (c *httpConn) ReadRequest() (*http.Request, error) {
-	r := io.TeeReader(c.tcpConn, c.reqBuf)
+	r := io.TeeReader(c.loggedConn, c.reqBuf)
 	return http.ReadRequest(bufio.NewReader(r))
 }
 
-func (c *tcpConn) ReadFrom(r io.Reader) (n int64, err error) {
+func (c *loggedConn) ReadFrom(r io.Reader) (n int64, err error) {
 	// special case when we're reading from an http request where
 	// we had to parse the request and consume bytes from the socket
 	// and store them in a temporary request buffer

+ 9 - 9
src/ngrok/server/cli.go

@@ -5,17 +5,17 @@ import (
 )
 
 type Options struct {
-	httpPort   int
-	httpsPort  int
-	tunnelPort int
+	httpAddr   string
+	httpsAddr  string
+	tunnelAddr string
 	domain     string
 	logto      string
 }
 
 func parseArgs() *Options {
-	httpPort := flag.Int("httpPort", 80, "Public HTTP port, -1 to disable")
-	httpsPort := flag.Int("httpsPort", 443, "Public HTTPS port, -1 to disable")
-	tunnelPort := flag.Int("tunnelPort", 4443, "Port to which ngrok clients connect")
+	httpAddr := flag.String("httpAddr", ":80", "Public address for HTTP connections, empty string to disable")
+	httpsAddr := flag.String("httpsAddr", ":443", "Public address listening for HTTPS connections, emptry string to disable")
+	tunnelAddr := flag.String("tunnelAddr", ":4443", "Public address listening for ngrok client")
 	domain := flag.String("domain", "ngrok.com", "Domain where the tunnels are hosted")
 	logto := flag.String(
 		"log",
@@ -25,9 +25,9 @@ func parseArgs() *Options {
 	flag.Parse()
 
 	return &Options{
-		httpPort:   *httpPort,
-		httpsPort:  *httpsPort,
-		tunnelPort: *tunnelPort,
+		httpAddr:   *httpAddr,
+		httpsAddr:  *httpsAddr,
+		tunnelAddr: *tunnelAddr,
 		domain:     *domain,
 		logto:      *logto,
 	}

+ 11 - 7
src/ngrok/server/http.go

@@ -31,10 +31,10 @@ Bad Request
 )
 
 // Listens for new http(s) connections from the public internet
-func httpListener(addr *net.TCPAddr, tlsCfg *tls.Config) {
+func startHttpListener(addr string, tlsCfg *tls.Config) (listener *conn.Listener) {
 	// bind/listen for incoming connections
-	listener, err := conn.Listen(addr, "pub", tlsCfg)
-	if err != nil {
+	var err error
+	if listener, err = conn.Listen(addr, "pub", tlsCfg); err != nil {
 		panic(err)
 	}
 
@@ -43,10 +43,14 @@ func httpListener(addr *net.TCPAddr, tlsCfg *tls.Config) {
 		proto = "https"
 	}
 
-	log.Info("Listening for public %s connections on %v", proto, listener.Port)
-	for conn := range listener.Conns {
-		go httpHandler(conn, proto)
-	}
+	log.Info("Listening for public %s connections on %v", proto, listener.Addr.String())
+	go func() {
+		for conn := range listener.Conns {
+			go httpHandler(conn, proto)
+		}
+	}()
+
+	return
 }
 
 // Handles a new http connection from the public internet

+ 18 - 16
src/ngrok/server/main.go

@@ -2,7 +2,6 @@ package server
 
 import (
 	"math/rand"
-	"net"
 	"ngrok/conn"
 	log "ngrok/log"
 	"ngrok/msg"
@@ -10,14 +9,18 @@ import (
 	"os"
 )
 
+const (
+	registryCacheSize uint64 = 1024 * 1024 // 1 MB
+)
+
 // GLOBALS
 var (
-	opts              *Options
 	tunnelRegistry    *TunnelRegistry
 	controlRegistry   *ControlRegistry
-	registryCacheSize uint64 = 1024 * 1024 // 1 MB
-	domain            string
-	publicPort        int
+
+	// XXX: kill these global variables - they're only used in tunnel.go for constructing forwarding URLs
+	opts              *Options
+	listeners         map[string] *conn.Listener
 )
 
 func NewProxy(pxyConn conn.Conn, regPxy *msg.RegProxyMsg) {
@@ -48,14 +51,14 @@ func NewProxy(pxyConn conn.Conn, regPxy *msg.RegProxyMsg) {
 // for ease of deployment. The hope is that by running on port 443, using
 // TLS and running all connections over the same port, we can bust through
 // restrictive firewalls.
-func tunnelListener(addr *net.TCPAddr, domain string) {
+func tunnelListener(addr string) {
 	// listen for incoming connections
 	listener, err := conn.Listen(addr, "tun", tlsConfig)
 	if err != nil {
 		panic(err)
 	}
 
-	log.Info("Listening for control and proxy connections on %d", listener.Port)
+	log.Info("Listening for control and proxy connections on %d", listener.Addr.String())
 	for c := range listener.Conns {
 		var rawMsg msg.Message
 		if rawMsg, err = msg.ReadMsg(c); err != nil {
@@ -92,20 +95,19 @@ func Main() {
 	tunnelRegistry = NewTunnelRegistry(registryCacheSize, registryCacheFile)
 	controlRegistry = NewControlRegistry()
 
-	// ngrok clients
-	go tunnelListener(&net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: opts.tunnelPort}, opts.domain)
+	// start listeners
+	listeners = make(map[string] *conn.Listener)
 
 	// listen for http
-	if opts.httpPort != -1 {
-		go httpListener(&net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: opts.httpPort}, nil)
+	if opts.httpAddr != "" {
+		listeners["http"] = startHttpListener(opts.httpAddr, nil)
 	}
 
 	// listen for https
-	if opts.httpsPort != -1 {
-		go httpListener(&net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: opts.httpsPort}, tlsConfig)
+	if opts.httpsAddr != "" {
+		listeners["https"] = startHttpListener(opts.httpsAddr, tlsConfig)
 	}
 
-	// wait forever
-	done := make(chan int)
-	<-done
+	// ngrok clients
+	tunnelListener(opts.tunnelAddr)
 }

+ 12 - 6
src/ngrok/server/tunnel.go

@@ -102,7 +102,8 @@ func NewTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel, err error) {
 		Logger: log.NewPrefixLogger(),
 	}
 
-	switch t.regMsg.Protocol {
+	proto := t.regMsg.Protocol
+	switch proto {
 	case "tcp":
 		var port int = 0
 
@@ -136,7 +137,7 @@ func NewTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel, err error) {
 
 		// create the url
 		addr := t.listener.Addr().(*net.TCPAddr)
-		t.url = fmt.Sprintf("tcp://%s:%d", domain, addr.Port)
+		t.url = fmt.Sprintf("tcp://%s:%d", opts.domain, addr.Port)
 
 		// register it
 		if err = tunnelRegistry.RegisterAndCache(t.url, t); err != nil {
@@ -149,15 +150,20 @@ func NewTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel, err error) {
 
 		go t.listenTcp(t.listener)
 
-	case "http":
-		if err = registerVhost(t, "http", opts.httpPort); err != nil {
+	case "http", "https":
+		l, ok := listeners[proto]
+		if !ok {
+			err = fmt.Errorf("Not listeneing for %s connections", proto)
 			return
 		}
 
-	case "https":
-		if err = registerVhost(t, "https", opts.httpsPort); err != nil {
+		if err = registerVhost(t, proto, l.Addr.(*net.TCPAddr).Port); err != nil {
 			return
 		}
+
+	default:
+		err = fmt.Errorf("Protocol %s is not supported", proto)
+		return
 	}
 
 	if m.Version != version.Proto {