Browse Source

add support for registering multiple protocols in a single tunnel request message

Alan Shreve 12 years ago
parent
commit
1fa6bb644e
4 changed files with 34 additions and 40 deletions
  1. 3 5
      src/ngrok/client/cli.go
  2. 0 11
      src/ngrok/client/model.go
  3. 29 21
      src/ngrok/server/control.go
  4. 2 3
      src/ngrok/server/tunnel.go

+ 3 - 5
src/ngrok/client/cli.go

@@ -81,9 +81,7 @@ func parseLocalAddr() string {
 
 func parseProtocol(proto string) string {
 	switch proto {
-	case "http":
-		fallthrough
-	case "tcp":
+	case "http", "https", "http+https", "tcp":
 		return proto
 	default:
 		fail("%s is not a valid protocol", proto)
@@ -119,8 +117,8 @@ func parseArgs() *Options {
 
 	protocol := flag.String(
 		"proto",
-		"http",
-		"The protocol of the traffic over the tunnel {'http', 'tcp'} (default: 'http')")
+		"http+https",
+		"The protocol of the traffic over the tunnel {'http', 'https', 'tcp'} (default: 'http+https')")
 
 	webport := flag.Int(
 		"webport",

+ 0 - 11
src/ngrok/client/model.go

@@ -210,17 +210,6 @@ func (c *ClientModel) control(reqTunnel *msg.ReqTunnel, localaddr string) {
 		panic(err)
 	}
 
-	// register an https tunnel as well for http tunnels
-	if reqTunnel.Protocol == "http" {
-		httpsReqTunnel := *reqTunnel
-		httpsReqTunnel.Protocol = "https"
-		// httpsReqTunnel.ReqId =
-
-		if err = msg.WriteMsg(conn, &httpsReqTunnel); err != nil {
-			panic(err)
-		}
-	}
-
 	// start the heartbeat
 	lastPong := time.Now().UnixNano()
 	c.ctl.Go(func() { c.heartbeat(&lastPong, conn) })

+ 29 - 21
src/ngrok/server/control.go

@@ -8,6 +8,7 @@ import (
 	"ngrok/util"
 	"ngrok/version"
 	"runtime/debug"
+	"strings"
 	"sync/atomic"
 	"time"
 )
@@ -110,31 +111,38 @@ func NewControl(ctlConn conn.Conn, authMsg *msg.Auth) {
 }
 
 // Register a new tunnel on this control connection
-func (c *Control) registerTunnel(reqTunnel *msg.ReqTunnel) {
-	c.conn.Debug("Registering new tunnel")
-	t, err := NewTunnel(reqTunnel, c)
-	if err != nil {
-		ack := &msg.NewTunnel{Error: err.Error()}
-		if len(c.tunnels) == 0 {
-			// you can't fail your first tunnel registration
-			// terminate the control connection
-			c.stop <- ack
-		} else {
-			// inform client of failure
-			c.out <- ack
+func (c *Control) registerTunnel(rawTunnelReq *msg.ReqTunnel) {
+	for _, proto := range strings.Split(rawTunnelReq.Protocol, "+") {
+		tunnelReq := *rawTunnelReq
+		tunnelReq.Protocol = proto
+
+		c.conn.Debug("Registering new tunnel")
+		t, err := NewTunnel(&tunnelReq, c)
+		if err != nil {
+			ack := &msg.NewTunnel{Error: err.Error()}
+			if len(c.tunnels) == 0 {
+				// you can't fail your first tunnel registration
+				// terminate the control connection
+				c.stop <- ack
+			} else {
+				// inform client of failure
+				c.out <- ack
+			}
+
+			// we're done
+			return
 		}
 
-		// we're done
-		return
-	}
+		// add it to the list of tunnels
+		c.tunnels = append(c.tunnels, t)
 
-	// add it to the list of tunnels
-	c.tunnels = append(c.tunnels, t)
+		// acknowledge success
+		c.out <- &msg.NewTunnel{
+			Url:      t.url,
+			Protocol: proto,
+		}
 
-	// acknowledge success
-	c.out <- &msg.NewTunnel{
-		Url:      t.url,
-		Protocol: reqTunnel.Protocol,
+		rawTunnelReq.Hostname = strings.Replace(t.url, proto+"://", "", 1)
 	}
 }
 

+ 2 - 3
src/ngrok/server/tunnel.go

@@ -68,17 +68,16 @@ func registerVhost(t *Tunnel, protocol string, servingPort int) (err error) {
 
 	// Canonicalize by always using lower-case
 	vhost = strings.ToLower(vhost)
-	t.url = strings.ToLower(t.url)
 
 	// Register for specific hostname
-	hostname := strings.TrimSpace(t.req.Hostname)
+	hostname := strings.ToLower(strings.TrimSpace(t.req.Hostname))
 	if hostname != "" {
 		t.url = fmt.Sprintf("%s://%s", protocol, hostname)
 		return tunnelRegistry.Register(t.url, t)
 	}
 
 	// Register for specific subdomain
-	subdomain := strings.TrimSpace(t.req.Subdomain)
+	subdomain := strings.ToLower(strings.TrimSpace(t.req.Subdomain))
 	if subdomain != "" {
 		t.url = fmt.Sprintf("%s://%s.%s", protocol, subdomain, vhost)
 		return tunnelRegistry.Register(t.url, t)