Browse Source

move more net code into conn, factor out a Listen() pattern

Alan Shreve 12 years ago
parent
commit
6a28c1b645
5 changed files with 82 additions and 68 deletions
  1. 17 40
      src/ngrok/client/main.go
  2. 49 3
      src/ngrok/conn/conn.go
  3. 2 3
      src/ngrok/server/control.go
  4. 12 21
      src/ngrok/server/main.go
  5. 2 1
      src/ngrok/server/tunnel.go

+ 17 - 40
src/ngrok/client/main.go

@@ -4,7 +4,6 @@ import (
 	log "code.google.com/p/log4go"
 	"fmt"
 	"io/ioutil"
-	"net"
 	"ngrok/client/ui"
 	"ngrok/client/views/term"
 	"ngrok/client/views/web"
@@ -23,56 +22,33 @@ const (
 	maxPongLatency = 15 * time.Second
 )
 
-/** 
- * Connect to the ngrok server
- */
-func connect(addr string, typ string) (c conn.Conn, err error) {
-	var (
-		tcpAddr *net.TCPAddr
-		tcpConn *net.TCPConn
-	)
-
-	if tcpAddr, err = net.ResolveTCPAddr("tcp", addr); err != nil {
-		return
-	}
-
-	log.Debug("Dialing %v", addr)
-	if tcpConn, err = net.DialTCP("tcp", nil, tcpAddr); err != nil {
-		return
-	}
-
-	c = conn.NewTCP(tcpConn, typ)
-	c.Debug("Connected to: %v", tcpAddr)
-	return c, nil
-}
-
 /**
  * Establishes and manages a tunnel proxy connection with the server
  */
 func proxy(proxyAddr string, s *State, ctl *ui.Controller) {
 	start := time.Now()
-	remoteConn, err := connect(proxyAddr, "pxy")
+	remoteConn, err := conn.Dial(proxyAddr, "pxy")
 	if err != nil {
-                // XXX: What is the proper response here?
-                // display something to the user?
-                // retry?
-                // reset control connection?
-                log.Error("Failed to establish proxy connection: %v", err)
-                return
+		// XXX: What is the proper response here?
+		// display something to the user?
+		// retry?
+		// reset control connection?
+		log.Error("Failed to establish proxy connection: %v", err)
+		return
 	}
 
 	defer remoteConn.Close()
 	err = msg.WriteMsg(remoteConn, &msg.RegProxyMsg{Url: s.publicUrl})
 	if err != nil {
-                // XXX: What is the proper response here?
-                // display something to the user?
-                // retry?
-                // reset control connection?
-                log.Error("Failed to write RegProxyMsg: %v", err)
-                return
+		// XXX: What is the proper response here?
+		// display something to the user?
+		// retry?
+		// reset control connection?
+		log.Error("Failed to write RegProxyMsg: %v", err)
+		return
 	}
 
-	localConn, err := connect(s.opts.localaddr, "prv")
+	localConn, err := conn.Dial(s.opts.localaddr, "prv")
 	if err != nil {
 		remoteConn.Warn("Failed to open private leg %s: %v", s.opts.localaddr, err)
 		return
@@ -147,7 +123,7 @@ func control(s *State, ctl *ui.Controller) {
 	}()
 
 	// establish control channel
-	conn, err := connect(s.opts.server, "ctl")
+	conn, err := conn.Dial(s.opts.server, "ctl")
 	if err != nil {
 		panic(err)
 	}
@@ -258,7 +234,8 @@ func Main() {
 				case ui.REPLAY:
 					go func() {
 						payload := cmd.Payload.([]byte)
-						localConn, err := connect(s.opts.localaddr, "prv")
+						var localConn conn.Conn
+						localConn, err := conn.Dial(s.opts.localaddr, "prv")
 						if err != nil {
 							log.Warn("Failed to open private leg %s: %v", s.opts.localaddr, err)
 							return

+ 49 - 3
src/ngrok/conn/conn.go

@@ -24,13 +24,59 @@ type tcpConn struct {
 	typ string
 }
 
-func NewTCP(conn net.Conn, typ string) *tcpConn {
+func wrapTcpConn(conn net.Conn, typ string) *tcpConn {
 	c := &tcpConn{conn, log.NewPrefixLogger(), rand.Int31(), typ}
 	c.AddLogPrefix(c.Id())
-	c.Info("New connection from %v", conn.RemoteAddr())
 	return c
 }
 
+func Listen(addr *net.TCPAddr, typ string) (conns chan Conn, err error) {
+	// listen for incoming connections
+	listener, err := net.ListenTCP("tcp", addr)
+	if err != nil {
+		return
+	}
+
+	conns = make(chan Conn)
+	go func() {
+		for {
+			tcpConn, err := listener.AcceptTCP()
+			if err != nil {
+				panic(err)
+			}
+
+			c := wrapTcpConn(tcpConn, typ)
+			c.Info("New connection from %v", tcpConn.RemoteAddr())
+			conns <- c
+		}
+	}()
+	return
+}
+
+func Wrap(conn net.Conn, typ string) *tcpConn {
+	return wrapTcpConn(conn, typ)
+}
+
+func Dial(addr, typ string) (conn *tcpConn, err error) {
+	var (
+		tcpAddr *net.TCPAddr
+		tcpConn *net.TCPConn
+	)
+
+	if tcpAddr, err = net.ResolveTCPAddr("tcp", addr); err != nil {
+		return
+	}
+
+	//log.Debug("Dialing %v", addr)
+	if tcpConn, err = net.DialTCP("tcp", nil, tcpAddr); err != nil {
+		return
+	}
+
+	conn = wrapTcpConn(tcpConn, typ)
+	conn.Debug("New connection to: %v", tcpAddr)
+	return conn, nil
+}
+
 func (c *tcpConn) Close() error {
 	c.Debug("Closing")
 	return c.Conn.Close()
@@ -72,7 +118,7 @@ type httpConn struct {
 
 func NewHttp(conn net.Conn, typ string) *httpConn {
 	return &httpConn{
-		NewTCP(conn, typ),
+		wrapTcpConn(conn, typ),
 		bytes.NewBuffer(make([]byte, 0, 1024)),
 	}
 }

+ 2 - 3
src/ngrok/server/control.go

@@ -2,7 +2,6 @@ package server
 
 import (
 	"io"
-	"net"
 	"ngrok/conn"
 	"ngrok/msg"
 	"runtime/debug"
@@ -30,9 +29,9 @@ type Control struct {
 	tun *Tunnel
 }
 
-func NewControl(tcpConn *net.TCPConn) {
+func NewControl(conn conn.Conn) {
 	c := &Control{
-		conn:     conn.NewTCP(tcpConn, "ctl"),
+		conn:     conn,
 		out:      make(chan (interface{}), 1),
 		in:       make(chan (msg.Message), 1),
 		stop:     make(chan (msg.Message), 1),

+ 12 - 21
src/ngrok/server/main.go

@@ -55,19 +55,14 @@ func getTCPPort(addr net.Addr) int {
  */
 func controlListener(addr *net.TCPAddr, domain string) {
 	// listen for incoming connections
-	listener, err := net.ListenTCP("tcp", addr)
+	conns, err := conn.Listen(addr, "ctl")
 	if err != nil {
 		panic(err)
 	}
 
 	log.Info("Listening for control connections on %d", getTCPPort(addr))
-	for {
-		tcpConn, err := listener.AcceptTCP()
-		if err != nil {
-			panic(err)
-		}
-
-		NewControl(tcpConn)
+	for c := range conns {
+		NewControl(c)
 	}
 }
 
@@ -75,23 +70,17 @@ func controlListener(addr *net.TCPAddr, domain string) {
  * Listens for new proxy connections from tunnel clients
  */
 func proxyListener(addr *net.TCPAddr, domain string) {
-	listener, err := net.ListenTCP("tcp", addr)
-	proxyAddr = fmt.Sprintf("%s:%d", domain, getTCPPort(listener.Addr()))
-
+	conns, err := conn.Listen(addr, "pxy")
 	if err != nil {
 		panic(err)
 	}
 
-	log.Info("Listening for proxy connection on %d", getTCPPort(listener.Addr()))
-	for {
-		tcpConn, err := listener.AcceptTCP()
-		if err != nil {
-			panic(err)
-		}
-
-		conn := conn.NewTCP(tcpConn, "pxy")
-
+	// set global proxy addr variable
+	proxyAddr = fmt.Sprintf("%s:%d", domain, getTCPPort(addr))
+	log.Info("Listening for proxy connection on %d", getTCPPort(addr))
+	for conn := range conns {
 		go func() {
+			// fail gracefully if the proxy connection dies
 			defer func() {
 				if r := recover(); r != nil {
 					conn.Warn("Failed with error: %v", r)
@@ -99,18 +88,20 @@ func proxyListener(addr *net.TCPAddr, domain string) {
 				}
 			}()
 
+			// read the proxy register message
 			var regPxy msg.RegProxyMsg
 			if err = msg.ReadMsgInto(conn, &regPxy); err != nil {
 				panic(err)
 			}
 
+			// look up the tunnel for this proxy
 			conn.Info("Registering new proxy for %s", regPxy.Url)
-
 			tunnel := tunnels.Get(regPxy.Url)
 			if tunnel == nil {
 				panic("No tunnel found for: " + regPxy.Url)
 			}
 
+			// register the proxy connection with the tunnel
 			tunnel.RegisterProxy(conn)
 		}()
 	}

+ 2 - 1
src/ngrok/server/tunnel.go

@@ -120,8 +120,9 @@ func (t *Tunnel) listenTcp(listener *net.TCPListener) {
 			panic(err)
 		}
 
-		conn := conn.NewTCP(tcpConn, "pub")
+		conn := conn.Wrap(tcpConn, "pub")
 		conn.AddLogPrefix(t.Id())
+		conn.Info("New connection from %v", conn.RemoteAddr())
 
 		go t.HandlePublicConnection(conn)
 	}