control.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. package server
  2. import (
  3. "io"
  4. "net"
  5. "ngrok/conn"
  6. "ngrok/proto"
  7. "runtime/debug"
  8. "sync/atomic"
  9. "time"
  10. )
  // pingInterval is how often managerThread writes a PingMsg to the connection.
const pingInterval = 30 * time.Second

// connReapInterval is how often managerThread checks the age of the last
// received pong; it is five ping intervals.
const connReapInterval = 5 * pingInterval
  15. type Control struct {
  16. // actual connection
  17. conn conn.Conn
  18. // channels for communicating messages over the connection
  19. out chan (interface{})
  20. in chan (proto.Message)
  21. stop chan (proto.Message)
  22. // heartbeat
  23. lastPong int64
  24. // tunnel
  25. tun *Tunnel
  26. }
  27. func NewControl(tcpConn *net.TCPConn) {
  28. c := &Control{
  29. conn: conn.NewTCP(tcpConn, "ctl"),
  30. out: make(chan (interface{}), 1),
  31. in: make(chan (proto.Message), 1),
  32. stop: make(chan (proto.Message), 1),
  33. lastPong: time.Now().Unix(),
  34. }
  35. go c.managerThread()
  36. go c.readThread()
  37. }
  38. func (c *Control) managerThread() {
  39. ping := time.NewTicker(pingInterval)
  40. reap := time.NewTicker(connReapInterval)
  41. // all shutdown functionality in here
  42. defer func() {
  43. if err := recover(); err != nil {
  44. c.conn.Info("Control::managerThread failed with error %v: %s", err, debug.Stack())
  45. }
  46. ping.Stop()
  47. reap.Stop()
  48. c.conn.Close()
  49. // shutdown the tunnel if it's open
  50. if c.tun != nil {
  51. c.tun.shutdown()
  52. }
  53. }()
  54. for {
  55. select {
  56. case m := <-c.out:
  57. proto.WriteMsg(c.conn, m)
  58. case <-ping.C:
  59. proto.WriteMsg(c.conn, &proto.PingMsg{})
  60. case <-reap.C:
  61. if (time.Now().Unix() - c.lastPong) > 60 {
  62. c.conn.Info("Lost heartbeat")
  63. metrics.lostHeartbeatMeter.Mark(1)
  64. return
  65. }
  66. case m := <-c.stop:
  67. if m != nil {
  68. proto.WriteMsg(c.conn, m)
  69. }
  70. return
  71. case msg := <-c.in:
  72. switch msg.GetType() {
  73. case "RegMsg":
  74. c.conn.Info("Registering new tunnel")
  75. c.tun = newTunnel(msg.(*proto.RegMsg), c)
  76. case "PongMsg":
  77. atomic.StoreInt64(&c.lastPong, time.Now().Unix())
  78. case "VersionReqMsg":
  79. c.out <- &proto.VersionRespMsg{Version: version}
  80. }
  81. }
  82. }
  83. }
  84. func (c *Control) readThread() {
  85. defer func() {
  86. if err := recover(); err != nil {
  87. c.conn.Info("Control::readThread failed with error %v: %s", err, debug.Stack())
  88. }
  89. c.stop <- nil
  90. }()
  91. // read messages from the control channel
  92. for {
  93. if msg, err := proto.ReadMsg(c.conn); err != nil {
  94. if err == io.EOF {
  95. c.conn.Info("EOF")
  96. return
  97. } else {
  98. panic(err)
  99. }
  100. } else {
  101. c.in <- msg
  102. }
  103. }
  104. }