summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorLeonard Hecker <[email protected]>2016-12-26 20:52:36 +0100
committerLeonard Hecker <[email protected]>2016-12-26 20:52:36 +0100
commit9f9ad21aaa526768b638105fbc6675c80fdfa0ce (patch)
tree21242590948431a3c949536a599d4badcf71833b
parent53635ba538fb41bbfd38b70282cd6d59e693c0b3 (diff)
downloadcaddy-9f9ad21aaa526768b638105fbc6675c80fdfa0ce.tar.gz
caddy-9f9ad21aaa526768b638105fbc6675c80fdfa0ce.zip
Fixed #1292: Failure to proxy WebSockets over HTTPS
This issue was caused by connHijackerTransport trying to record HTTP response headers by "hijacking" the Read() method of the plain net.Conn. This does not simply work over TLS though since this will record the TLS handshake and encrypted data instead of the actual content. This commit fixes the problem by providing an alternative transport.DialTLS which correctly hijacks the overlying tls.Conn instead.
-rw-r--r--caddyhttp/proxy/reverseproxy.go184
1 files changed, 140 insertions, 44 deletions
diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go
index 49df16036..c980e9051 100644
--- a/caddyhttp/proxy/reverseproxy.go
+++ b/caddyhttp/proxy/reverseproxy.go
@@ -27,6 +27,11 @@ import (
"github.com/mholt/caddy/caddyhttp/httpserver"
)
+var defaultDialer = &net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 30 * time.Second,
+}
+
var bufferPool = sync.Pool{New: createBuffer}
func createBuffer() interface{} {
@@ -135,11 +140,8 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
// just use default transport, to avoid creating
// a brand new transport
transport := &http.Transport{
- Proxy: http.ProxyFromEnvironment,
- Dial: (&net.Dialer{
- Timeout: 30 * time.Second,
- KeepAlive: 30 * time.Second,
- }).Dial,
+ Proxy: http.ProxyFromEnvironment,
+ Dial: defaultDialer.Dial,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
@@ -162,11 +164,8 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
func (rp *ReverseProxy) UseInsecureTransport() {
if rp.Transport == nil {
transport := &http.Transport{
- Proxy: http.ProxyFromEnvironment,
- Dial: (&net.Dialer{
- Timeout: 30 * time.Second,
- KeepAlive: 30 * time.Second,
- }).Dial,
+ Proxy: http.ProxyFromEnvironment,
+ Dial: defaultDialer.Dial,
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
@@ -341,51 +340,148 @@ type connHijackerTransport struct {
}
func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
- transport := &http.Transport{
- Proxy: http.ProxyFromEnvironment,
- Dial: (&net.Dialer{
- Timeout: 30 * time.Second,
- KeepAlive: 30 * time.Second,
- }).Dial,
- TLSHandshakeTimeout: 10 * time.Second,
+ t := &http.Transport{
MaxIdleConnsPerHost: -1,
}
- if base != nil {
- if baseTransport, ok := base.(*http.Transport); ok {
- transport.Proxy = baseTransport.Proxy
- transport.TLSClientConfig = baseTransport.TLSClientConfig
- transport.TLSHandshakeTimeout = baseTransport.TLSHandshakeTimeout
- transport.Dial = baseTransport.Dial
- transport.DialTLS = baseTransport.DialTLS
- transport.MaxIdleConnsPerHost = -1
+ if b, _ := base.(*http.Transport); b != nil {
+ t.Proxy = b.Proxy
+ t.TLSClientConfig = cloneTLSClientConfig(b.TLSClientConfig)
+ t.TLSClientConfig.NextProtos = nil
+ t.TLSHandshakeTimeout = b.TLSHandshakeTimeout
+ t.Dial = b.Dial
+ t.DialTLS = b.DialTLS
+ } else {
+ t.Proxy = http.ProxyFromEnvironment
+ t.TLSHandshakeTimeout = 10 * time.Second
+ }
+ hj := &connHijackerTransport{t, nil, bufferPool.Get().([]byte)[:0]}
+
+ dial := getTransportDial(t)
+ dialTLS := getTransportDialTLS(t)
+
+ t.Dial = func(network, addr string) (net.Conn, error) {
+ c, err := dial(network, addr)
+ hj.Conn = c
+ return &hijackedConn{c, hj}, err
+ }
+
+ if dialTLS != nil {
+ t.DialTLS = func(network, addr string) (net.Conn, error) {
+ c, err := dialTLS(network, addr)
+ hj.Conn = c
+ return &hijackedConn{c, hj}, err
}
}
- hjTransport := &connHijackerTransport{transport, nil, bufferPool.Get().([]byte)[:0]}
- oldDial := transport.Dial
- oldDialTLS := transport.DialTLS
- if oldDial == nil {
- oldDial = (&net.Dialer{
- Timeout: 30 * time.Second,
- KeepAlive: 30 * time.Second,
- }).Dial
+
+ return hj
+}
+
+// getTransportDial always returns a plain Dialer
+// and defaults to the existing t.Dial.
+func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, error) {
+ if t.Dial != nil {
+ return t.Dial
+ }
+ return defaultDialer.Dial
+}
+
+// getTransportDial returns a TLS Dialer if TLSClientConfig is non-nil
+// and defaults to the existing t.DialTLS.
+func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) {
+ if t.DialTLS != nil {
+ return t.DialTLS
}
- hjTransport.Dial = func(network, addr string) (net.Conn, error) {
- c, err := oldDial(network, addr)
- hjTransport.Conn = c
- return &hijackedConn{c, hjTransport}, err
+ if t.TLSClientConfig == nil {
+ return nil
}
- if oldDialTLS != nil {
- hjTransport.DialTLS = func(network, addr string) (net.Conn, error) {
- c, err := oldDialTLS(network, addr)
- hjTransport.Conn = c
- return &hijackedConn{c, hjTransport}, err
+
+ // newConnHijackerTransport will modify t.Dial after calling this method
+ // => Create a backup reference.
+ plainDial := getTransportDial(t)
+
+ return func(network, addr string) (net.Conn, error) {
+ plainConn, err := plainDial(network, addr)
+ if err != nil {
+ return nil, err
}
+
+ tlsConn := tls.Client(plainConn, t.TLSClientConfig)
+ errc := make(chan error, 2)
+ var timer *time.Timer
+ if d := t.TLSHandshakeTimeout; d != 0 {
+ timer = time.AfterFunc(d, func() {
+ errc <- tlsHandshakeTimeoutError{}
+ })
+ }
+ go func() {
+ err := tlsConn.Handshake()
+ if timer != nil {
+ timer.Stop()
+ }
+ errc <- err
+ }()
+ if err := <-errc; err != nil {
+ plainConn.Close()
+ return nil, err
+ }
+ if !t.TLSClientConfig.InsecureSkipVerify {
+ serverName := t.TLSClientConfig.ServerName
+ if serverName == "" {
+ serverName = addr
+ idx := strings.LastIndex(serverName, ":")
+ if idx != -1 {
+ serverName = serverName[:idx]
+ }
+ }
+ if err := tlsConn.VerifyHostname(serverName); err != nil {
+ plainConn.Close()
+ return nil, err
+ }
+ }
+
+ return tlsConn, nil
+ }
+}
+
+type tlsHandshakeTimeoutError struct{}
+
+func (tlsHandshakeTimeoutError) Timeout() bool { return true }
+func (tlsHandshakeTimeoutError) Temporary() bool { return true }
+func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }
+
+// cloneTLSClientConfig is like cloneTLSConfig but omits
+// the fields SessionTicketsDisabled and SessionTicketKey.
+// This makes it safe to call cloneTLSClientConfig on a config
+// in active use by a server.
+func cloneTLSClientConfig(cfg *tls.Config) *tls.Config {
+ if cfg == nil {
+ return &tls.Config{}
+ }
+ return &tls.Config{
+ Rand: cfg.Rand,
+ Time: cfg.Time,
+ Certificates: cfg.Certificates,
+ NameToCertificate: cfg.NameToCertificate,
+ GetCertificate: cfg.GetCertificate,
+ RootCAs: cfg.RootCAs,
+ NextProtos: cfg.NextProtos,
+ ServerName: cfg.ServerName,
+ ClientAuth: cfg.ClientAuth,
+ ClientCAs: cfg.ClientCAs,
+ InsecureSkipVerify: cfg.InsecureSkipVerify,
+ CipherSuites: cfg.CipherSuites,
+ PreferServerCipherSuites: cfg.PreferServerCipherSuites,
+ ClientSessionCache: cfg.ClientSessionCache,
+ MinVersion: cfg.MinVersion,
+ MaxVersion: cfg.MaxVersion,
+ CurvePreferences: cfg.CurvePreferences,
+ DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
+ Renegotiation: cfg.Renegotiation,
}
- return hjTransport
}
func requestIsWebsocket(req *http.Request) bool {
- return !(strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade"))
+ return strings.ToLower(req.Header.Get("Upgrade")) == "websocket" && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")
}
type writeFlusher interface {