|
@@ -12,6 +12,7 @@ import (
|
|
|
"net/http"
|
|
|
"net/url"
|
|
|
"ngrok/log"
|
|
|
+ "sync"
|
|
|
)
|
|
|
|
|
|
type Conn interface {
|
|
@@ -19,9 +20,11 @@ type Conn interface {
|
|
|
log.Logger
|
|
|
Id() string
|
|
|
SetType(string)
|
|
|
+ CloseRead() error
|
|
|
}
|
|
|
|
|
|
type loggedConn struct {
|
|
|
+ tcp *net.TCPConn
|
|
|
net.Conn
|
|
|
log.Logger
|
|
|
id int32
|
|
@@ -34,9 +37,16 @@ type Listener struct {
|
|
|
}
|
|
|
|
|
|
func wrapConn(conn net.Conn, typ string) *loggedConn {
|
|
|
- c := &loggedConn{conn, log.NewPrefixLogger(), rand.Int31(), typ}
|
|
|
- c.AddLogPrefix(c.Id())
|
|
|
- return c
|
|
|
+ switch c := conn.(type) {
|
|
|
+ case *loggedConn:
|
|
|
+ return c
|
|
|
+ case *net.TCPConn:
|
|
|
+ wrapped := &loggedConn{c, conn, log.NewPrefixLogger(), rand.Int31(), typ}
|
|
|
+ wrapped.AddLogPrefix(wrapped.Id())
|
|
|
+ return wrapped
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
func Listen(addr, typ string, tlsCfg *tls.Config) (l *Listener, err error) {
|
|
@@ -152,9 +162,11 @@ func (c *loggedConn) StartTLS(tlsCfg *tls.Config) {
|
|
|
c.Conn = tls.Client(c.Conn, tlsCfg)
|
|
|
}
|
|
|
|
|
|
-func (c *loggedConn) Close() error {
|
|
|
- c.Debug("Closing")
|
|
|
- return c.Conn.Close()
|
|
|
+func (c *loggedConn) Close() (err error) {
|
|
|
+ if err := c.Conn.Close(); err != nil {
|
|
|
+ c.Debug("Closing")
|
|
|
+ }
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
func (c *loggedConn) Id() string {
|
|
@@ -169,28 +181,35 @@ func (c *loggedConn) SetType(typ string) {
|
|
|
c.Info("Renamed connection %s", oldId)
|
|
|
}
|
|
|
|
|
|
+func (c *loggedConn) CloseRead() error {
|
|
|
+ return c.tcp.CloseRead()
|
|
|
+}
|
|
|
+
|
|
|
func Join(c Conn, c2 Conn) (int64, int64) {
|
|
|
- done := make(chan error)
|
|
|
+ var wait sync.WaitGroup
|
|
|
+
|
|
|
pipe := func(to Conn, from Conn, bytesCopied *int64) {
|
|
|
+ defer to.CloseRead()
|
|
|
+ defer from.CloseRead()
|
|
|
+ defer wait.Done()
|
|
|
+
|
|
|
var err error
|
|
|
*bytesCopied, err = io.Copy(to, from)
|
|
|
if err != nil {
|
|
|
from.Warn("Copied %d bytes to %s before failing with error %v", *bytesCopied, to.Id(), err)
|
|
|
- done <- err
|
|
|
} else {
|
|
|
from.Debug("Copied %d bytes from to %s", *bytesCopied, to.Id())
|
|
|
- done <- nil
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ wait.Add(2)
|
|
|
var fromBytes, toBytes int64
|
|
|
go pipe(c, c2, &fromBytes)
|
|
|
go pipe(c2, c, &toBytes)
|
|
|
c.Info("Joined with connection %s", c2.Id())
|
|
|
- <-done
|
|
|
+ wait.Wait()
|
|
|
c.Close()
|
|
|
c2.Close()
|
|
|
- <-done
|
|
|
return fromBytes, toBytes
|
|
|
}
|
|
|
|