Browse Source

correctly handle virtual hosting on non default ports and report the correct public URL back to the client without modifying the protocol. add the VHOST environment variable for more flexibility around how virtual hosting is implemented

Alan Shreve 12 years ago
parent
commit
61dd957018

+ 0 - 1
src/ngrok/client/main.go

@@ -192,7 +192,6 @@ func control(s *State, ctl *ui.Controller) {
 
 
 	// update UI state
 	// update UI state
 	s.publicUrl = regAck.Url
 	s.publicUrl = regAck.Url
-	s.publicPort = regAck.Port
 	conn.Info("Tunnel established at %v", s.GetPublicUrl())
 	conn.Info("Tunnel established at %v", s.GetPublicUrl())
 	s.status = "online"
 	s.status = "online"
 	s.serverVersion = regAck.MmVersion
 	s.serverVersion = regAck.MmVersion

+ 1 - 10
src/ngrok/client/state.go

@@ -1,7 +1,6 @@
 package client
 package client
 
 
 import (
 import (
-	"fmt"
 	metrics "github.com/inconshreveable/go-metrics"
 	metrics "github.com/inconshreveable/go-metrics"
 	"ngrok/client/ui"
 	"ngrok/client/ui"
 	"ngrok/proto"
 	"ngrok/proto"
@@ -12,7 +11,6 @@ import (
 type State struct {
 type State struct {
 	id            string
 	id            string
 	publicUrl     string
 	publicUrl     string
-	publicPort    int
 	serverVersion string
 	serverVersion string
 	update        ui.UpdateStatus
 	update        ui.UpdateStatus
 	protocol      proto.Protocol
 	protocol      proto.Protocol
@@ -31,14 +29,7 @@ func (s State) GetWebPort() int             { return s.opts.webport }
 func (s State) GetStatus() string           { return s.status }
 func (s State) GetStatus() string           { return s.status }
 func (s State) GetProtocol() proto.Protocol { return s.protocol }
 func (s State) GetProtocol() proto.Protocol { return s.protocol }
 func (s State) GetUpdate() ui.UpdateStatus  { return s.update }
 func (s State) GetUpdate() ui.UpdateStatus  { return s.update }
-
-func (s State) GetPublicUrl() string {
-	publicUrl := s.publicUrl
-	if s.publicPort != 80 && s.publicPort != 443 {
-		publicUrl += fmt.Sprintf(":%d", s.publicPort)
-	}
-	return publicUrl
-}
+func (s State) GetPublicUrl() string        { return s.publicUrl }
 
 
 func (s State) GetConnectionMetrics() (metrics.Meter, metrics.Timer) {
 func (s State) GetConnectionMetrics() (metrics.Meter, metrics.Timer) {
 	return s.metrics.connMeter, s.metrics.connTimer
 	return s.metrics.connMeter, s.metrics.connTimer

+ 0 - 1
src/ngrok/msg/msg.go

@@ -46,7 +46,6 @@ type RegAckMsg struct {
 	Version   string
 	Version   string
 	MmVersion string
 	MmVersion string
 	Url       string
 	Url       string
-	Port      int
 	ProxyAddr string
 	ProxyAddr string
 	Error     string
 	Error     string
 }
 }

+ 4 - 3
src/ngrok/server/http.go

@@ -5,7 +5,6 @@ import (
 	"net"
 	"net"
 	"ngrok/conn"
 	"ngrok/conn"
 	"ngrok/log"
 	"ngrok/log"
-	"strings"
 )
 )
 
 
 const (
 const (
@@ -68,9 +67,11 @@ func httpHandler(tcpConn net.Conn) {
 		return
 		return
 	}
 	}
 
 
-	// multiplex to find the right backend host
-	host := strings.Split(req.Host, ":")[0]
+	// read out the Host header from the request
+	host := req.Host
 	conn.Debug("Found hostname %s in request", host)
 	conn.Debug("Found hostname %s in request", host)
+
+	// multiplex to find the right backend host
 	tunnel := tunnels.Get("http://" + host)
 	tunnel := tunnels.Get("http://" + host)
 	if tunnel == nil {
 	if tunnel == nil {
 		conn.Info("No tunnel found for hostname %s", host)
 		conn.Info("No tunnel found for hostname %s", host)

+ 13 - 3
src/ngrok/server/tunnel.go

@@ -9,6 +9,7 @@ import (
 	"ngrok/log"
 	"ngrok/log"
 	"ngrok/msg"
 	"ngrok/msg"
 	"ngrok/version"
 	"ngrok/version"
+	"os"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
 	"sync/atomic"
 	"sync/atomic"
@@ -107,10 +108,20 @@ func newTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel) {
 		go t.listenTcp(t.listener)
 		go t.listenTcp(t.listener)
 
 
 	case "http":
 	case "http":
+		vhost := os.Getenv("VHOST")
+		if vhost == "" {
+			vhost = fmt.Sprintf("%s:%d", domain, publicPort)
+		}
+
+		// Canonicalize virtual host on default port 80
+		if strings.HasSuffix(vhost, ":80") {
+			vhost = vhost[0 : len(vhost)-3]
+		}
+
 		if strings.TrimSpace(t.regMsg.Hostname) != "" {
 		if strings.TrimSpace(t.regMsg.Hostname) != "" {
 			t.url = fmt.Sprintf("http://%s", t.regMsg.Hostname)
 			t.url = fmt.Sprintf("http://%s", t.regMsg.Hostname)
 		} else if strings.TrimSpace(t.regMsg.Subdomain) != "" {
 		} else if strings.TrimSpace(t.regMsg.Subdomain) != "" {
-			t.url = fmt.Sprintf("http://%s.%s", t.regMsg.Subdomain, domain)
+			t.url = fmt.Sprintf("http://%s.%s", t.regMsg.Subdomain, vhost)
 		}
 		}
 
 
 		if t.url != "" {
 		if t.url != "" {
@@ -120,7 +131,7 @@ func newTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel) {
 			}
 			}
 		} else {
 		} else {
 			t.url, err = tunnels.RegisterRepeat(func() string {
 			t.url, err = tunnels.RegisterRepeat(func() string {
-				return fmt.Sprintf("http://%x.%s", rand.Int31(), domain)
+				return fmt.Sprintf("http://%x.%s", rand.Int31(), vhost)
 			}, t)
 			}, t)
 
 
 			if err != nil {
 			if err != nil {
@@ -145,7 +156,6 @@ func newTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel) {
 	t.Info("Registered new tunnel")
 	t.Info("Registered new tunnel")
 	t.ctl.out <- &msg.RegAckMsg{
 	t.ctl.out <- &msg.RegAckMsg{
 		Url:       t.url,
 		Url:       t.url,
-		Port:      publicPort,
 		ProxyAddr: fmt.Sprintf("%s", proxyAddr),
 		ProxyAddr: fmt.Sprintf("%s", proxyAddr),
 		Version:   version.Proto,
 		Version:   version.Proto,
 		MmVersion: version.MajorMinor(),
 		MmVersion: version.MajorMinor(),