diff options
-rw-r--r-- | caddyhttp/proxy/reverseproxy.go | 184 |
1 files changed, 140 insertions, 44 deletions
diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index 49df16036..c980e9051 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -27,6 +27,11 @@ import ( "github.com/mholt/caddy/caddyhttp/httpserver" ) +var defaultDialer = &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, +} + var bufferPool = sync.Pool{New: createBuffer} func createBuffer() interface{} { @@ -135,11 +140,8 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * // just use default transport, to avoid creating // a brand new transport transport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - Dial: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial, + Proxy: http.ProxyFromEnvironment, + Dial: defaultDialer.Dial, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } @@ -162,11 +164,8 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * func (rp *ReverseProxy) UseInsecureTransport() { if rp.Transport == nil { transport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - Dial: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial, + Proxy: http.ProxyFromEnvironment, + Dial: defaultDialer.Dial, TLSHandshakeTimeout: 10 * time.Second, TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } @@ -341,51 +340,148 @@ type connHijackerTransport struct { } func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport { - transport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - Dial: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial, - TLSHandshakeTimeout: 10 * time.Second, + t := &http.Transport{ MaxIdleConnsPerHost: -1, } - if base != nil { - if baseTransport, ok := base.(*http.Transport); ok { - transport.Proxy = baseTransport.Proxy - transport.TLSClientConfig = baseTransport.TLSClientConfig - transport.TLSHandshakeTimeout = baseTransport.TLSHandshakeTimeout - transport.Dial = baseTransport.Dial - transport.DialTLS = baseTransport.DialTLS - transport.MaxIdleConnsPerHost = -1 + if b, _ := base.(*http.Transport); b != nil { + t.Proxy = b.Proxy + t.TLSClientConfig = cloneTLSClientConfig(b.TLSClientConfig) + t.TLSClientConfig.NextProtos = nil + t.TLSHandshakeTimeout = b.TLSHandshakeTimeout + t.Dial = b.Dial + t.DialTLS = b.DialTLS + } else { + t.Proxy = http.ProxyFromEnvironment + t.TLSHandshakeTimeout = 10 * time.Second + } + hj := &connHijackerTransport{t, nil, bufferPool.Get().([]byte)[:0]} + + dial := getTransportDial(t) + dialTLS := getTransportDialTLS(t) + + t.Dial = func(network, addr string) (net.Conn, error) { + c, err := dial(network, addr) + hj.Conn = c + return &hijackedConn{c, hj}, err + } + + if dialTLS != nil { + t.DialTLS = func(network, addr string) (net.Conn, error) { + c, err := dialTLS(network, addr) + hj.Conn = c + return &hijackedConn{c, hj}, err } } - hjTransport := &connHijackerTransport{transport, nil, bufferPool.Get().([]byte)[:0]} - oldDial := transport.Dial - oldDialTLS := transport.DialTLS - if oldDial == nil { - oldDial = (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial + + return hj +} + +// getTransportDial always returns a plain Dialer +// and defaults to the existing t.Dial. +func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, error) { + if t.Dial != nil { + return t.Dial + } + return defaultDialer.Dial +} + +// getTransportDial returns a TLS Dialer if TLSClientConfig is non-nil +// and defaults to the existing t.DialTLS. +func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) { + if t.DialTLS != nil { + return t.DialTLS } - hjTransport.Dial = func(network, addr string) (net.Conn, error) { - c, err := oldDial(network, addr) - hjTransport.Conn = c - return &hijackedConn{c, hjTransport}, err + if t.TLSClientConfig == nil { + return nil } - if oldDialTLS != nil { - hjTransport.DialTLS = func(network, addr string) (net.Conn, error) { - c, err := oldDialTLS(network, addr) - hjTransport.Conn = c - return &hijackedConn{c, hjTransport}, err + + // newConnHijackerTransport will modify t.Dial after calling this method + // => Create a backup reference. + plainDial := getTransportDial(t) + + return func(network, addr string) (net.Conn, error) { + plainConn, err := plainDial(network, addr) + if err != nil { + return nil, err } + + tlsConn := tls.Client(plainConn, t.TLSClientConfig) + errc := make(chan error, 2) + var timer *time.Timer + if d := t.TLSHandshakeTimeout; d != 0 { + timer = time.AfterFunc(d, func() { + errc <- tlsHandshakeTimeoutError{} + }) + } + go func() { + err := tlsConn.Handshake() + if timer != nil { + timer.Stop() + } + errc <- err + }() + if err := <-errc; err != nil { + plainConn.Close() + return nil, err + } + if !t.TLSClientConfig.InsecureSkipVerify { + serverName := t.TLSClientConfig.ServerName + if serverName == "" { + serverName = addr + idx := strings.LastIndex(serverName, ":") + if idx != -1 { + serverName = serverName[:idx] + } + } + if err := tlsConn.VerifyHostname(serverName); err != nil { + plainConn.Close() + return nil, err + } + } + + return tlsConn, nil + } +} + +type tlsHandshakeTimeoutError struct{} + +func (tlsHandshakeTimeoutError) Timeout() bool { return true } +func (tlsHandshakeTimeoutError) Temporary() bool { return true } +func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" } + +// cloneTLSClientConfig is like cloneTLSConfig but omits +// the fields SessionTicketsDisabled and SessionTicketKey. +// This makes it safe to call cloneTLSClientConfig on a config +// in active use by a server. +func cloneTLSClientConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return &tls.Config{ + Rand: cfg.Rand, + Time: cfg.Time, + Certificates: cfg.Certificates, + NameToCertificate: cfg.NameToCertificate, + GetCertificate: cfg.GetCertificate, + RootCAs: cfg.RootCAs, + NextProtos: cfg.NextProtos, + ServerName: cfg.ServerName, + ClientAuth: cfg.ClientAuth, + ClientCAs: cfg.ClientCAs, + InsecureSkipVerify: cfg.InsecureSkipVerify, + CipherSuites: cfg.CipherSuites, + PreferServerCipherSuites: cfg.PreferServerCipherSuites, + ClientSessionCache: cfg.ClientSessionCache, + MinVersion: cfg.MinVersion, + MaxVersion: cfg.MaxVersion, + CurvePreferences: cfg.CurvePreferences, + DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled, + Renegotiation: cfg.Renegotiation, } - return hjTransport } func requestIsWebsocket(req *http.Request) bool { - return !(strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")) + return strings.ToLower(req.Header.Get("Upgrade")) == "websocket" && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") } type writeFlusher interface { |