Browse Source

correctly handle virtual hosting on non default ports and report the correct public URL back to the client without modifying the protocol. add the VHOST environment variable for more flexibility around how virtual hosting is implemented

Alan Shreve 12 years ago
parent
commit
61dd957018
5 changed files with 18 additions and 18 deletions
  1. 0 1
      src/ngrok/client/main.go
  2. 1 10
      src/ngrok/client/state.go
  3. 0 1
      src/ngrok/msg/msg.go
  4. 4 3
      src/ngrok/server/http.go
  5. 13 3
      src/ngrok/server/tunnel.go

+ 0 - 1
src/ngrok/client/main.go

@@ -192,7 +192,6 @@ func control(s *State, ctl *ui.Controller) {
 
 	// update UI state
 	s.publicUrl = regAck.Url
-	s.publicPort = regAck.Port
 	conn.Info("Tunnel established at %v", s.GetPublicUrl())
 	s.status = "online"
 	s.serverVersion = regAck.MmVersion

+ 1 - 10
src/ngrok/client/state.go

@@ -1,7 +1,6 @@
 package client
 
 import (
-	"fmt"
 	metrics "github.com/inconshreveable/go-metrics"
 	"ngrok/client/ui"
 	"ngrok/proto"
@@ -12,7 +11,6 @@ import (
 type State struct {
 	id            string
 	publicUrl     string
-	publicPort    int
 	serverVersion string
 	update        ui.UpdateStatus
 	protocol      proto.Protocol
@@ -31,14 +29,7 @@ func (s State) GetWebPort() int             { return s.opts.webport }
 func (s State) GetStatus() string           { return s.status }
 func (s State) GetProtocol() proto.Protocol { return s.protocol }
 func (s State) GetUpdate() ui.UpdateStatus  { return s.update }
-
-func (s State) GetPublicUrl() string {
-	publicUrl := s.publicUrl
-	if s.publicPort != 80 && s.publicPort != 443 {
-		publicUrl += fmt.Sprintf(":%d", s.publicPort)
-	}
-	return publicUrl
-}
+func (s State) GetPublicUrl() string        { return s.publicUrl }
 
 func (s State) GetConnectionMetrics() (metrics.Meter, metrics.Timer) {
 	return s.metrics.connMeter, s.metrics.connTimer

+ 0 - 1
src/ngrok/msg/msg.go

@@ -46,7 +46,6 @@ type RegAckMsg struct {
 	Version   string
 	MmVersion string
 	Url       string
-	Port      int
 	ProxyAddr string
 	Error     string
 }

+ 4 - 3
src/ngrok/server/http.go

@@ -5,7 +5,6 @@ import (
 	"net"
 	"ngrok/conn"
 	"ngrok/log"
-	"strings"
 )
 
 const (
@@ -68,9 +67,11 @@ func httpHandler(tcpConn net.Conn) {
 		return
 	}
 
-	// multiplex to find the right backend host
-	host := strings.Split(req.Host, ":")[0]
+	// read out the Host header from the request
+	host := req.Host
 	conn.Debug("Found hostname %s in request", host)
+
+	// multiplex to find the right backend host
 	tunnel := tunnels.Get("http://" + host)
 	if tunnel == nil {
 		conn.Info("No tunnel found for hostname %s", host)

+ 13 - 3
src/ngrok/server/tunnel.go

@@ -9,6 +9,7 @@ import (
 	"ngrok/log"
 	"ngrok/msg"
 	"ngrok/version"
+	"os"
 	"strconv"
 	"strings"
 	"sync/atomic"
@@ -107,10 +108,20 @@ func newTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel) {
 		go t.listenTcp(t.listener)
 
 	case "http":
+		vhost := os.Getenv("VHOST")
+		if vhost == "" {
+			vhost = fmt.Sprintf("%s:%d", domain, publicPort)
+		}
+
+		// Canonicalize virtual host on default port 80
+		if strings.HasSuffix(vhost, ":80") {
+			vhost = vhost[0 : len(vhost)-3]
+		}
+
 		if strings.TrimSpace(t.regMsg.Hostname) != "" {
 			t.url = fmt.Sprintf("http://%s", t.regMsg.Hostname)
 		} else if strings.TrimSpace(t.regMsg.Subdomain) != "" {
-			t.url = fmt.Sprintf("http://%s.%s", t.regMsg.Subdomain, domain)
+			t.url = fmt.Sprintf("http://%s.%s", t.regMsg.Subdomain, vhost)
 		}
 
 		if t.url != "" {
@@ -120,7 +131,7 @@ func newTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel) {
 			}
 		} else {
 			t.url, err = tunnels.RegisterRepeat(func() string {
-				return fmt.Sprintf("http://%x.%s", rand.Int31(), domain)
+				return fmt.Sprintf("http://%x.%s", rand.Int31(), vhost)
 			}, t)
 
 			if err != nil {
@@ -145,7 +156,6 @@ func newTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel) {
 	t.Info("Registered new tunnel")
 	t.ctl.out <- &msg.RegAckMsg{
 		Url:       t.url,
-		Port:      publicPort,
 		ProxyAddr: fmt.Sprintf("%s", proxyAddr),
 		Version:   version.Proto,
 		MmVersion: version.MajorMinor(),