|
@@ -19,7 +19,7 @@ type Conn interface {
|
|
|
SetType(string)
|
|
|
}
|
|
|
|
|
|
-type tcpConn struct {
|
|
|
+type loggedConn struct {
|
|
|
net.Conn
|
|
|
log.Logger
|
|
|
id int32
|
|
@@ -27,84 +27,76 @@ type tcpConn struct {
|
|
|
}
|
|
|
|
|
|
type Listener struct {
|
|
|
- *net.TCPAddr
|
|
|
+ net.Addr
|
|
|
Conns chan Conn
|
|
|
}
|
|
|
|
|
|
-func wrapTcpConn(conn net.Conn, typ string) *tcpConn {
|
|
|
- c := &tcpConn{conn, log.NewPrefixLogger(), rand.Int31(), typ}
|
|
|
+func wrapConn(conn net.Conn, typ string) *loggedConn {
|
|
|
+ c := &loggedConn{conn, log.NewPrefixLogger(), rand.Int31(), typ}
|
|
|
c.AddLogPrefix(c.Id())
|
|
|
return c
|
|
|
}
|
|
|
|
|
|
-func Listen(addr *net.TCPAddr, typ string, tlsCfg *tls.Config) (l *Listener, err error) {
|
|
|
+func Listen(addr, typ string, tlsCfg *tls.Config) (l *Listener, err error) {
|
|
|
// listen for incoming connections
|
|
|
- listener, err := net.ListenTCP("tcp", addr)
|
|
|
+ listener, err := net.Listen("tcp", addr)
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
l = &Listener{
|
|
|
- TCPAddr: listener.Addr().(*net.TCPAddr),
|
|
|
+ Addr: listener.Addr(),
|
|
|
Conns: make(chan Conn),
|
|
|
}
|
|
|
|
|
|
go func() {
|
|
|
for {
|
|
|
- tcpConn, err := listener.AcceptTCP()
|
|
|
+ rawConn, err := listener.Accept()
|
|
|
if err != nil {
|
|
|
log.Error("Failed to accept new TCP connection of type %s: %v", typ, err)
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
- c := wrapTcpConn(tcpConn, typ)
|
|
|
+ c := wrapConn(rawConn, typ)
|
|
|
if tlsCfg != nil {
|
|
|
c.Conn = tls.Server(c.Conn, tlsCfg)
|
|
|
}
|
|
|
- c.Info("New connection from %v", tcpConn.RemoteAddr())
|
|
|
+ c.Info("New connection from %v", c.RemoteAddr())
|
|
|
l.Conns <- c
|
|
|
}
|
|
|
}()
|
|
|
return
|
|
|
}
|
|
|
|
|
|
-func Wrap(conn net.Conn, typ string) *tcpConn {
|
|
|
- return wrapTcpConn(conn, typ)
|
|
|
+func Wrap(conn net.Conn, typ string) *loggedConn {
|
|
|
+ return wrapConn(conn, typ)
|
|
|
}
|
|
|
|
|
|
-func Dial(addr, typ string, tlsCfg *tls.Config) (conn *tcpConn, err error) {
|
|
|
- var (
|
|
|
- tcpAddr *net.TCPAddr
|
|
|
- tcpConn net.Conn
|
|
|
- )
|
|
|
-
|
|
|
- if tcpAddr, err = net.ResolveTCPAddr("tcp", addr); err != nil {
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- if tcpConn, err = net.DialTCP("tcp", nil, tcpAddr); err != nil {
|
|
|
+func Dial(addr, typ string, tlsCfg *tls.Config) (conn *loggedConn, err error) {
|
|
|
+ var rawConn net.Conn
|
|
|
+ if rawConn, err = net.Dial("tcp", addr); err != nil {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
if tlsCfg != nil {
|
|
|
- tcpConn = tls.Client(tcpConn, tlsCfg)
|
|
|
+ rawConn = tls.Client(rawConn, tlsCfg)
|
|
|
}
|
|
|
|
|
|
- conn = wrapTcpConn(tcpConn, typ)
|
|
|
- conn.Debug("New connection to: %v", tcpAddr)
|
|
|
- return conn, nil
|
|
|
+ conn = wrapConn(rawConn, typ)
|
|
|
+ conn.Debug("New connection to: %v", rawConn.RemoteAddr())
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
-func (c *tcpConn) Close() error {
|
|
|
+func (c *loggedConn) Close() error {
|
|
|
c.Debug("Closing")
|
|
|
return c.Conn.Close()
|
|
|
}
|
|
|
|
|
|
-func (c *tcpConn) Id() string {
|
|
|
+func (c *loggedConn) Id() string {
|
|
|
return fmt.Sprintf("%s:%x", c.typ, c.id)
|
|
|
}
|
|
|
|
|
|
-func (c *tcpConn) SetType(typ string) {
|
|
|
+func (c *loggedConn) SetType(typ string) {
|
|
|
oldId := c.Id()
|
|
|
c.typ = typ
|
|
|
c.ClearLogPrefixes()
|
|
@@ -138,23 +130,23 @@ func Join(c Conn, c2 Conn) (int64, int64) {
|
|
|
}
|
|
|
|
|
|
type httpConn struct {
|
|
|
- *tcpConn
|
|
|
+ *loggedConn
|
|
|
reqBuf *bytes.Buffer
|
|
|
}
|
|
|
|
|
|
func NewHttp(conn net.Conn, typ string) *httpConn {
|
|
|
return &httpConn{
|
|
|
- wrapTcpConn(conn, typ),
|
|
|
+ wrapConn(conn, typ),
|
|
|
bytes.NewBuffer(make([]byte, 0, 1024)),
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func (c *httpConn) ReadRequest() (*http.Request, error) {
|
|
|
- r := io.TeeReader(c.tcpConn, c.reqBuf)
|
|
|
+ r := io.TeeReader(c.loggedConn, c.reqBuf)
|
|
|
return http.ReadRequest(bufio.NewReader(r))
|
|
|
}
|
|
|
|
|
|
-func (c *tcpConn) ReadFrom(r io.Reader) (n int64, err error) {
|
|
|
+func (c *loggedConn) ReadFrom(r io.Reader) (n int64, err error) {
|
|
|
// special case when we're reading from an http request where
|
|
|
// we had to parse the request and consume bytes from the socket
|
|
|
// and store them in a temporary request buffer
|