conn.go 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. package conn
  2. import (
  3. "bufio"
  4. "bytes"
  5. "fmt"
  6. "io"
  7. "math/rand"
  8. "net"
  9. "net/http"
  10. "ngrok"
  11. )
  12. type Conn interface {
  13. net.Conn
  14. ngrok.Logger
  15. Id() string
  16. }
  17. type loggedConn struct {
  18. net.Conn
  19. ngrok.Logger
  20. id int32
  21. typ string
  22. }
  23. func NewLogged(conn net.Conn, typ string) *loggedConn {
  24. c := &loggedConn{conn, ngrok.NewPrefixLogger(), rand.Int31(), typ}
  25. c.AddLogPrefix(c.Id())
  26. c.Info("New connection from %v", conn.RemoteAddr())
  27. return c
  28. }
  29. func (c *loggedConn) Close() error {
  30. c.Debug("Closing")
  31. return c.Conn.Close()
  32. }
  33. func (c *loggedConn) Id() string {
  34. return fmt.Sprintf("%s:%x", c.typ, c.id)
  35. }
  36. func Join(c Conn, c2 Conn) (int64, int64) {
  37. done := make(chan error)
  38. pipe := func(to Conn, from Conn, bytesCopied *int64) {
  39. var err error
  40. *bytesCopied, err = io.Copy(to, from)
  41. if err != nil {
  42. from.Warn("Copied %d bytes to %s before failing with error %v", *bytesCopied, to.Id(), err)
  43. done <- err
  44. } else {
  45. from.Debug("Copied %d bytes from to %s", *bytesCopied, to.Id())
  46. done <- nil
  47. }
  48. }
  49. var fromBytes, toBytes int64
  50. go pipe(c, c2, &fromBytes)
  51. go pipe(c2, c, &toBytes)
  52. c.Info("Joined with connection %s", c2.Id())
  53. <-done
  54. c.Close()
  55. c2.Close()
  56. <-done
  57. return fromBytes, toBytes
  58. }
  59. type loggedHttpConn struct {
  60. *loggedConn
  61. reqBuf *bytes.Buffer
  62. }
  63. func NewHttp(conn net.Conn, typ string) *loggedHttpConn {
  64. return &loggedHttpConn{
  65. NewLogged(conn, typ),
  66. bytes.NewBuffer(make([]byte, 0, 1024)),
  67. }
  68. }
  69. func (c *loggedHttpConn) ReadRequest() (*http.Request, error) {
  70. r := io.TeeReader(c.loggedConn, c.reqBuf)
  71. return http.ReadRequest(bufio.NewReader(r))
  72. }
  73. func (c *loggedConn) ReadFrom(r io.Reader) (n int64, err error) {
  74. // special case when we're reading from an http request where
  75. // we had to parse the request and consume bytes from the socket
  76. // and store them in a temporary request buffer
  77. if httpConn, ok := r.(*loggedHttpConn); ok {
  78. if n, err = httpConn.reqBuf.WriteTo(c); err != nil {
  79. return
  80. }
  81. }
  82. nCopied, err := io.Copy(c.Conn, r)
  83. n += nCopied
  84. return
  85. }