|
@@ -16,6 +16,12 @@ import (
|
|
|
"time"
|
|
|
)
|
|
|
|
|
|
+var defaultPortMap = map[string]int{
|
|
|
+ "http": 80,
|
|
|
+ "https": 443,
|
|
|
+ "smtp": 25,
|
|
|
+}
|
|
|
+
|
|
|
/**
|
|
|
* Tunnel: A control connection, metadata and proxy connections which
|
|
|
* route public traffic to a firewalled endpoint.
|
|
@@ -45,21 +51,57 @@ type Tunnel struct {
|
|
|
closing int32
|
|
|
}
|
|
|
|
|
|
-func newTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel) {
|
|
|
+// Common functionality for registering virtually hosted protocols
|
|
|
+func registerVhost(t *Tunnel, protocol string, servingPort int) (err error) {
|
|
|
+ vhost := os.Getenv("VHOST")
|
|
|
+ if vhost == "" {
|
|
|
+ vhost = fmt.Sprintf("%s:%d", opts.domain, servingPort)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Canonicalize virtual host by removing default port (e.g. :80 on HTTP)
|
|
|
+ defaultPort, ok := defaultPortMap[protocol]
|
|
|
+ if !ok {
|
|
|
+ return fmt.Errorf("Couldn't find default port for protocol %s", protocol)
|
|
|
+ }
|
|
|
+
|
|
|
+ defaultPortSuffix := fmt.Sprintf(":%d", defaultPort)
|
|
|
+ if strings.HasSuffix(vhost, defaultPortSuffix) {
|
|
|
+ vhost = vhost[0 : len(vhost)-len(defaultPortSuffix)]
|
|
|
+ }
|
|
|
+
|
|
|
+ // Register for specific hostname
|
|
|
+ hostname := strings.TrimSpace(t.regMsg.Hostname)
|
|
|
+ if hostname != "" {
|
|
|
+ t.url = fmt.Sprintf("%s://%s", protocol, hostname)
|
|
|
+ return tunnels.Register(t.url, t)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Register for specific subdomain
|
|
|
+ subdomain := strings.TrimSpace(t.regMsg.Subdomain)
|
|
|
+ if subdomain != "" {
|
|
|
+ t.url = fmt.Sprintf("%s://%s.%s", protocol, subdomain, vhost)
|
|
|
+ return tunnels.Register(t.url, t)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Register for random URL
|
|
|
+ t.url, err = tunnels.RegisterRepeat(func() string {
|
|
|
+ return fmt.Sprintf("%s://%x.%s", protocol, rand.Int31(), vhost)
|
|
|
+ }, t)
|
|
|
+
|
|
|
+ return
|
|
|
+}
|
|
|
+
|
|
|
+// Create a new tunnel from aregistration message received
|
|
|
+// on a control channel
|
|
|
+func NewTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel, err error) {
|
|
|
t = &Tunnel{
|
|
|
regMsg: m,
|
|
|
start: time.Now(),
|
|
|
ctl: ctl,
|
|
|
- proxies: make(chan conn.Conn),
|
|
|
+ proxies: make(chan conn.Conn, 10),
|
|
|
Logger: log.NewPrefixLogger(),
|
|
|
}
|
|
|
|
|
|
- failReg := func(err error) {
|
|
|
- t.ctl.stop <- &msg.RegAckMsg{Error: err.Error()}
|
|
|
- }
|
|
|
-
|
|
|
- var err error
|
|
|
-
|
|
|
switch t.regMsg.Protocol {
|
|
|
case "tcp":
|
|
|
var port int = 0
|
|
@@ -88,7 +130,7 @@ func newTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel) {
|
|
|
|
|
|
// we tried to bind with a random port and failed (no more ports available?)
|
|
|
if err != nil {
|
|
|
- failReg(t.ctl.conn.Error("Error binding TCP listener: %v", err))
|
|
|
+ err = t.ctl.conn.Error("Error binding TCP listener: %v", err)
|
|
|
return
|
|
|
}
|
|
|
|
|
@@ -101,48 +143,25 @@ func newTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel) {
|
|
|
// This should never be possible because the OS will
|
|
|
// only assign available ports to us.
|
|
|
t.listener.Close()
|
|
|
- failReg(fmt.Errorf("TCP listener bound, but failed to register %s", t.url))
|
|
|
+ err = fmt.Errorf("TCP listener bound, but failed to register %s", t.url)
|
|
|
return
|
|
|
}
|
|
|
|
|
|
go t.listenTcp(t.listener)
|
|
|
|
|
|
case "http":
|
|
|
- vhost := os.Getenv("VHOST")
|
|
|
- if vhost == "" {
|
|
|
- vhost = fmt.Sprintf("%s:%d", domain, publicPort)
|
|
|
- }
|
|
|
-
|
|
|
- // Canonicalize virtual host on default port 80
|
|
|
- if strings.HasSuffix(vhost, ":80") {
|
|
|
- vhost = vhost[0 : len(vhost)-3]
|
|
|
- }
|
|
|
-
|
|
|
- if strings.TrimSpace(t.regMsg.Hostname) != "" {
|
|
|
- t.url = fmt.Sprintf("http://%s", t.regMsg.Hostname)
|
|
|
- } else if strings.TrimSpace(t.regMsg.Subdomain) != "" {
|
|
|
- t.url = fmt.Sprintf("http://%s.%s", t.regMsg.Subdomain, vhost)
|
|
|
+ if err = registerVhost(t, "http", opts.httpPort); err != nil {
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
- if t.url != "" {
|
|
|
- if err := tunnels.Register(t.url, t); err != nil {
|
|
|
- failReg(err)
|
|
|
- return
|
|
|
- }
|
|
|
- } else {
|
|
|
- t.url, err = tunnels.RegisterRepeat(func() string {
|
|
|
- return fmt.Sprintf("http://%x.%s", rand.Int31(), vhost)
|
|
|
- }, t)
|
|
|
-
|
|
|
- if err != nil {
|
|
|
- failReg(err)
|
|
|
- return
|
|
|
- }
|
|
|
+ case "https":
|
|
|
+ if err = registerVhost(t, "https", opts.httpsPort); err != nil {
|
|
|
+ return
|
|
|
}
|
|
|
}
|
|
|
|
|
|
if m.Version != version.Proto {
|
|
|
- failReg(fmt.Errorf("Incompatible versions. Server %s, client %s. Download a new version at http://ngrok.com", version.MajorMinor(), m.Version))
|
|
|
+ err = fmt.Errorf("Incompatible versions. Server %s, client %s. Download a new version at http://ngrok.com", version.MajorMinor(), m.Version)
|
|
|
return
|
|
|
}
|
|
|
|
|
@@ -151,15 +170,8 @@ func newTunnel(m *msg.RegMsg, ctl *Control) (t *Tunnel) {
|
|
|
m.HttpAuth = "Basic " + base64.StdEncoding.EncodeToString([]byte(m.HttpAuth))
|
|
|
}
|
|
|
|
|
|
- t.ctl.conn.AddLogPrefix(t.Id())
|
|
|
t.AddLogPrefix(t.Id())
|
|
|
- t.Info("Registered new tunnel")
|
|
|
- t.ctl.out <- &msg.RegAckMsg{
|
|
|
- Url: t.url,
|
|
|
- ProxyAddr: fmt.Sprintf("%s", proxyAddr),
|
|
|
- Version: version.Proto,
|
|
|
- MmVersion: version.MajorMinor(),
|
|
|
- }
|
|
|
+ t.Info("Registered new tunnel on: %s", t.ctl.conn.Id())
|
|
|
|
|
|
metrics.OpenTunnel(t)
|
|
|
return
|
|
@@ -171,7 +183,7 @@ func (t *Tunnel) shutdown() {
|
|
|
// mark that we're shutting down
|
|
|
atomic.StoreInt32(&t.closing, 1)
|
|
|
|
|
|
- // if we have a public listener (this is a raw TCP tunnel, shut it down
|
|
|
+ // if we have a public listener (this is a raw TCP tunnel), shut it down
|
|
|
if t.listener != nil {
|
|
|
t.listener.Close()
|
|
|
}
|
|
@@ -179,7 +191,19 @@ func (t *Tunnel) shutdown() {
|
|
|
// remove ourselves from the tunnel registry
|
|
|
tunnels.Del(t.url)
|
|
|
|
|
|
- // XXX: shut down all of the proxy connections?
|
|
|
+ // let the control connection know we're shutting down
|
|
|
+ // currently, only the control connection shuts down tunnels,
|
|
|
+ // so it doesn't need to know about it
|
|
|
+ // t.ctl.stoptunnel <- t
|
|
|
+
|
|
|
+ // we're safe to close(t.proxies) because t.closing
|
|
|
+ // protects us inside of RegisterProxy
|
|
|
+ close(t.proxies)
|
|
|
+
|
|
|
+ // shut down all of the proxy connections
|
|
|
+ for c := range t.proxies {
|
|
|
+ c.Close()
|
|
|
+ }
|
|
|
|
|
|
metrics.CloseTunnel(t)
|
|
|
}
|
|
@@ -188,9 +212,7 @@ func (t *Tunnel) Id() string {
|
|
|
return t.url
|
|
|
}
|
|
|
|
|
|
-/**
|
|
|
- * Listens for new public tcp connections from the internet.
|
|
|
- */
|
|
|
+// Listens for new public tcp connections from the internet.
|
|
|
func (t *Tunnel) listenTcp(listener *net.TCPListener) {
|
|
|
for {
|
|
|
defer func() {
|
|
@@ -231,11 +253,29 @@ func (t *Tunnel) HandlePublicConnection(publicConn conn.Conn) {
|
|
|
startTime := time.Now()
|
|
|
metrics.OpenConnection(t, publicConn)
|
|
|
|
|
|
- t.Debug("Requesting new proxy connection")
|
|
|
- t.ctl.out <- &msg.ReqProxyMsg{}
|
|
|
-
|
|
|
- proxyConn := <-t.proxies
|
|
|
- t.Info("Returning proxy connection %s", proxyConn.Id())
|
|
|
+ // initial timeout is zero to try to get a proxy connection without asking for one
|
|
|
+ timeout := time.NewTimer(0)
|
|
|
+ var proxyConn conn.Conn
|
|
|
+
|
|
|
+ // get a proxy connection. if we timeout, request one over the control channel
|
|
|
+ for proxyConn == nil {
|
|
|
+ var ok bool
|
|
|
+ select {
|
|
|
+ case proxyConn, ok = <-t.proxies:
|
|
|
+ if !ok {
|
|
|
+ publicConn.Info("Dropping connection because tunnel is shutting down")
|
|
|
+ return
|
|
|
+ }
|
|
|
+ continue
|
|
|
+ case <-timeout.C:
|
|
|
+ t.Debug("Requesting new proxy connection")
|
|
|
+ // request a proxy connection
|
|
|
+ t.ctl.out <- &msg.ReqProxyMsg{Url: t.url}
|
|
|
+ // timeout after 1 second if we don't get one
|
|
|
+ timeout.Reset(1 * time.Second)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ t.Info("Got proxy connection %s", proxyConn.Id())
|
|
|
|
|
|
defer proxyConn.Close()
|
|
|
bytesIn, bytesOut := conn.Join(publicConn, proxyConn)
|
|
@@ -244,7 +284,18 @@ func (t *Tunnel) HandlePublicConnection(publicConn conn.Conn) {
|
|
|
}
|
|
|
|
|
|
func (t *Tunnel) RegisterProxy(conn conn.Conn) {
|
|
|
+ if atomic.LoadInt32(&t.closing) == 1 {
|
|
|
+ t.Debug("Can't register proxies for a tunnel that is closing")
|
|
|
+ conn.Close()
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
t.Info("Registered proxy connection %s", conn.Id())
|
|
|
conn.AddLogPrefix(t.Id())
|
|
|
- t.proxies <- conn
|
|
|
+ select {
|
|
|
+ case t.proxies <- conn:
|
|
|
+ default:
|
|
|
+ // t.proxies buffer is full, discard this one
|
|
|
+ conn.Close()
|
|
|
+ }
|
|
|
}
|