Browse Source

safely shutdown connections by using TCPConn.CloseRead() so that they don't throw errors

Alan Shreve 12 years ago
parent
commit
ef9aa6251d
1 changed files with 30 additions and 11 deletions
  1. 30 11
      src/ngrok/conn/conn.go

+ 30 - 11
src/ngrok/conn/conn.go

@@ -12,6 +12,7 @@ import (
 	"net/http"
 	"net/url"
 	"ngrok/log"
+	"sync"
 )
 
 type Conn interface {
@@ -19,9 +20,11 @@ type Conn interface {
 	log.Logger
 	Id() string
 	SetType(string)
+	CloseRead() error
 }
 
 type loggedConn struct {
+	tcp *net.TCPConn
 	net.Conn
 	log.Logger
 	id  int32
@@ -34,9 +37,16 @@ type Listener struct {
 }
 
 func wrapConn(conn net.Conn, typ string) *loggedConn {
-	c := &loggedConn{conn, log.NewPrefixLogger(), rand.Int31(), typ}
-	c.AddLogPrefix(c.Id())
-	return c
+	switch c := conn.(type) {
+	case *loggedConn:
+		return c
+	case *net.TCPConn:
+		wrapped := &loggedConn{c, conn, log.NewPrefixLogger(), rand.Int31(), typ}
+		wrapped.AddLogPrefix(wrapped.Id())
+		return wrapped
+	}
+
+	return nil
 }
 
 func Listen(addr, typ string, tlsCfg *tls.Config) (l *Listener, err error) {
@@ -152,9 +162,11 @@ func (c *loggedConn) StartTLS(tlsCfg *tls.Config) {
 	c.Conn = tls.Client(c.Conn, tlsCfg)
 }
 
-func (c *loggedConn) Close() error {
-	c.Debug("Closing")
-	return c.Conn.Close()
+func (c *loggedConn) Close() (err error) {
+	if err := c.Conn.Close(); err != nil {
+		c.Debug("Closing")
+	}
+	return
 }
 
 func (c *loggedConn) Id() string {
@@ -169,28 +181,35 @@ func (c *loggedConn) SetType(typ string) {
 	c.Info("Renamed connection %s", oldId)
 }
 
+func (c *loggedConn) CloseRead() error {
+	return c.tcp.CloseRead()
+}
+
 func Join(c Conn, c2 Conn) (int64, int64) {
-	done := make(chan error)
+	var wait sync.WaitGroup
+
 	pipe := func(to Conn, from Conn, bytesCopied *int64) {
+		defer to.CloseRead()
+		defer from.CloseRead()
+		defer wait.Done()
+
 		var err error
 		*bytesCopied, err = io.Copy(to, from)
 		if err != nil {
 			from.Warn("Copied %d bytes to %s before failing with error %v", *bytesCopied, to.Id(), err)
-			done <- err
 		} else {
 			from.Debug("Copied %d bytes from to %s", *bytesCopied, to.Id())
-			done <- nil
 		}
 	}
 
+	wait.Add(2)
 	var fromBytes, toBytes int64
 	go pipe(c, c2, &fromBytes)
 	go pipe(c2, c, &toBytes)
 	c.Info("Joined with connection %s", c2.Id())
-	<-done
+	wait.Wait()
 	c.Close()
 	c2.Close()
-	<-done
 	return fromBytes, toBytes
 }