Browse Source

Handle errors creating tunnels more gracefully and allow the client to cleanly shutdown and print an error message

Alan Shreve 12 years ago
parent
commit
5d2f48f285
8 changed files with 37 additions and 22 deletions
  1. 10 1
      client/main.go
  2. 1 1
      client/ui/terminal.go
  3. 4 1
      client/ui/ui.go
  4. 1 0
      proto/msg.go
  5. 7 7
      server/control.go
  6. 1 1
      server/http.go
  7. 6 4
      server/manager.go
  8. 7 7
      server/tunnel.go

+ 10 - 1
client/main.go

@@ -119,6 +119,12 @@ func control(s *State) {
 		panic(err)
 		panic(err)
 	}
 	}
 
 
+	if regAck.Error != "" {
+		emsg := fmt.Sprintf("Server failed to allocate tunnel: %s", regAck.Error)
+		s.ui.Cmds <- ui.Command{ui.QUIT, emsg}
+		return
+	}
+
 	// update UI state
 	// update UI state
 	conn.Info("Tunnel established at %v", regAck.Url)
 	conn.Info("Tunnel established at %v", regAck.Url)
 	//state.version = regAck.Version
 	//state.version = regAck.Version
@@ -195,14 +201,16 @@ func Main() {
 
 
 	go control(s)
 	go control(s)
 
 
+	quitMessage := ""
 	s.ui.Wait.Add(1)
 	s.ui.Wait.Add(1)
 	go func() {
 	go func() {
 		defer s.ui.Wait.Done()
 		defer s.ui.Wait.Done()
 		for {
 		for {
 			select {
 			select {
 			case cmd := <-s.ui.Cmds:
 			case cmd := <-s.ui.Cmds:
-				switch cmd {
+				switch cmd.Code {
 				case ui.QUIT:
 				case ui.QUIT:
+					quitMessage = cmd.Payload.(string)
 					s.stopping = true
 					s.stopping = true
 					s.Update()
 					s.Update()
 					return
 					return
@@ -212,4 +220,5 @@ func Main() {
 	}()
 	}()
 
 
 	s.ui.Wait.Wait()
 	s.ui.Wait.Wait()
+	fmt.Println(quitMessage)
 }
 }

+ 1 - 1
client/ui/terminal.go

@@ -128,7 +128,7 @@ func (t *Term) input() {
 		case termbox.EventKey:
 		case termbox.EventKey:
 			switch ev.Key {
 			switch ev.Key {
 			case termbox.KeyCtrlC:
 			case termbox.KeyCtrlC:
-				t.ui.Cmds <- QUIT
+				t.ui.Cmds <- Command{QUIT, ""}
 				return
 				return
 			}
 			}
 		}
 		}

+ 4 - 1
client/ui/ui.go

@@ -4,7 +4,10 @@ import (
 	"sync"
 	"sync"
 )
 )
 
 
-type Command int
+type Command struct {
+	Code    int
+	Payload interface{}
+}
 
 
 const (
 const (
 	QUIT = iota
 	QUIT = iota

+ 1 - 0
proto/msg.go

@@ -61,6 +61,7 @@ type RegAckMsg struct {
 	Type      string
 	Type      string
 	Url       string
 	Url       string
 	ProxyAddr string
 	ProxyAddr string
+	Error     string
 }
 }
 
 
 type RegProxyMsg struct {
 type RegProxyMsg struct {

+ 7 - 7
server/control.go

@@ -22,7 +22,7 @@ type Control struct {
 	// channels for communicating messages over the connection
 	// channels for communicating messages over the connection
 	out  chan (interface{})
 	out  chan (interface{})
 	in   chan (proto.Message)
 	in   chan (proto.Message)
-	stop chan (int)
+	stop chan (proto.Message)
 
 
 	// heartbeat
 	// heartbeat
 	lastPong int64
 	lastPong int64
@@ -36,7 +36,7 @@ func NewControl(tcpConn *net.TCPConn) {
 		conn:     conn.NewTCP(tcpConn, "ctl"),
 		conn:     conn.NewTCP(tcpConn, "ctl"),
 		out:      make(chan (interface{}), 1),
 		out:      make(chan (interface{}), 1),
 		in:       make(chan (proto.Message), 1),
 		in:       make(chan (proto.Message), 1),
-		stop:     make(chan (int), 1),
+		stop:     make(chan (proto.Message), 1),
 		lastPong: time.Now().Unix(),
 		lastPong: time.Now().Unix(),
 	}
 	}
 
 
@@ -55,9 +55,6 @@ func (c *Control) managerThread() {
 		}
 		}
 		ping.Stop()
 		ping.Stop()
 		reap.Stop()
 		reap.Stop()
-		close(c.out)
-		close(c.in)
-		close(c.stop)
 		c.conn.Close()
 		c.conn.Close()
 	}()
 	}()
 
 
@@ -76,7 +73,10 @@ func (c *Control) managerThread() {
 				return
 				return
 			}
 			}
 
 
-		case <-c.stop:
+		case m := <-c.stop:
+			if m != nil {
+				proto.WriteMsg(c.conn, m)
+			}
 			return
 			return
 
 
 		case msg := <-c.in:
 		case msg := <-c.in:
@@ -100,7 +100,7 @@ func (c *Control) readThread() {
 		if err := recover(); err != nil {
 		if err := recover(); err != nil {
 			c.conn.Info("Control::readThread failed with error %v: %s", err, debug.Stack())
 			c.conn.Info("Control::readThread failed with error %v: %s", err, debug.Stack())
 		}
 		}
-		c.stop <- 1
+		c.stop <- nil
 	}()
 	}()
 
 
 	// read messages from the control channel
 	// read messages from the control channel

+ 1 - 1
server/http.go

@@ -55,7 +55,7 @@ func httpHandler(tcpConn net.Conn) {
 	tunnel := tunnels.Get("http://" + req.Host)
 	tunnel := tunnels.Get("http://" + req.Host)
 
 
 	if tunnel == nil {
 	if tunnel == nil {
-		conn.Info("Not tunnel found for hostname %s", req.Host)
+		conn.Info("No tunnel found for hostname %s", req.Host)
 		return
 		return
 	}
 	}
 
 

+ 6 - 4
server/manager.go

@@ -28,7 +28,7 @@ func NewTunnelManager(domain string) *TunnelManager {
 	}
 	}
 }
 }
 
 
-func (m *TunnelManager) Add(t *Tunnel) {
+func (m *TunnelManager) Add(t *Tunnel) error {
 	assignTunnel := func(url string) bool {
 	assignTunnel := func(url string) bool {
 		m.Lock()
 		m.Lock()
 		defer m.Unlock()
 		defer m.Unlock()
@@ -47,7 +47,7 @@ func (m *TunnelManager) Add(t *Tunnel) {
 		addr := t.listener.Addr().(*net.TCPAddr)
 		addr := t.listener.Addr().(*net.TCPAddr)
 		url = fmt.Sprintf("tcp://%s:%d", m.domain, addr.Port)
 		url = fmt.Sprintf("tcp://%s:%d", m.domain, addr.Port)
 		if !assignTunnel(url) {
 		if !assignTunnel(url) {
-			panic("TCP at %s already registered!")
+			return t.Error("TCP at %s already registered!", url)
 		}
 		}
 		metrics.tcpTunnelMeter.Mark(1)
 		metrics.tcpTunnelMeter.Mark(1)
 
 
@@ -60,7 +60,7 @@ func (m *TunnelManager) Add(t *Tunnel) {
 
 
 		if url != "" {
 		if url != "" {
 			if !assignTunnel(url) {
 			if !assignTunnel(url) {
-				panic(fmt.Sprintf("The tunnel address %s is already registered!", url))
+				return t.Warn("The tunnel address %s is already registered!", url)
 			}
 			}
 		} else {
 		} else {
 			// try to give the same subdomain back if it's available
 			// try to give the same subdomain back if it's available
@@ -85,7 +85,7 @@ func (m *TunnelManager) Add(t *Tunnel) {
 		}
 		}
 
 
 	default:
 	default:
-		panic(t.Error("Unrecognized protocol type %s", t.regMsg.Protocol))
+		return t.Error("Unrecognized protocol type %s", t.regMsg.Protocol)
 	}
 	}
 
 
 	t.url = url
 	t.url = url
@@ -102,6 +102,8 @@ func (m *TunnelManager) Add(t *Tunnel) {
 	default:
 	default:
 		metrics.otherCounter.Inc(1)
 		metrics.otherCounter.Inc(1)
 	}
 	}
+
+	return nil
 }
 }
 
 
 func (m *TunnelManager) Del(url string) {
 func (m *TunnelManager) Del(url string) {

+ 7 - 7
server/tunnel.go

@@ -53,30 +53,30 @@ func newTunnel(msg *proto.RegMsg, ctl *Control) {
 	default:
 	default:
 	}
 	}
 
 
-	tunnels.Add(t)
+	if err := tunnels.Add(t); err != nil {
+		t.ctl.stop <- &proto.RegAckMsg{Error: fmt.Sprint(err)}
+		return
+	}
+
 	t.ctl.conn.AddLogPrefix(t.Id())
 	t.ctl.conn.AddLogPrefix(t.Id())
 	t.AddLogPrefix(t.Id())
 	t.AddLogPrefix(t.Id())
 	t.Info("Registered new tunnel")
 	t.Info("Registered new tunnel")
 	t.ctl.out <- &proto.RegAckMsg{Url: t.url, ProxyAddr: fmt.Sprintf("%s", proxyAddr)}
 	t.ctl.out <- &proto.RegAckMsg{Url: t.url, ProxyAddr: fmt.Sprintf("%s", proxyAddr)}
-
-	//go t.managerThread()
 }
 }
 
 
 func (t *Tunnel) shutdown() {
 func (t *Tunnel) shutdown() {
+	// XXX: this is completely unused right now
 	t.Info("Shutting down")
 	t.Info("Shutting down")
 	// stop any go routines
 	// stop any go routines
 	// close all proxy and public connections
 	// close all proxy and public connections
 	// stop any metrics
 	// stop any metrics
-	t.ctl.stop <- 1
+	t.ctl.stop <- nil
 }
 }
 
 
 func (t *Tunnel) Id() string {
 func (t *Tunnel) Id() string {
 	return t.url
 	return t.url
 }
 }
 
 
-func (t *Tunnel) managerThread() {
-}
-
 /**
 /**
  * Listens for new public tcp connections from the internet.
  * Listens for new public tcp connections from the internet.
  */
  */