summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--caddyhttp/proxy/reverseproxy.go184
1 files changed, 140 insertions, 44 deletions
diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go
index 49df16036..c980e9051 100644
--- a/caddyhttp/proxy/reverseproxy.go
+++ b/caddyhttp/proxy/reverseproxy.go
@@ -27,6 +27,11 @@ import (
"github.com/mholt/caddy/caddyhttp/httpserver"
)
+var defaultDialer = &net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 30 * time.Second,
+}
+
var bufferPool = sync.Pool{New: createBuffer}
func createBuffer() interface{} {
@@ -135,11 +140,8 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
// just use default transport, to avoid creating
// a brand new transport
transport := &http.Transport{
- Proxy: http.ProxyFromEnvironment,
- Dial: (&net.Dialer{
- Timeout: 30 * time.Second,
- KeepAlive: 30 * time.Second,
- }).Dial,
+ Proxy: http.ProxyFromEnvironment,
+ Dial: defaultDialer.Dial,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
@@ -162,11 +164,8 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
func (rp *ReverseProxy) UseInsecureTransport() {
if rp.Transport == nil {
transport := &http.Transport{
- Proxy: http.ProxyFromEnvironment,
- Dial: (&net.Dialer{
- Timeout: 30 * time.Second,
- KeepAlive: 30 * time.Second,
- }).Dial,
+ Proxy: http.ProxyFromEnvironment,
+ Dial: defaultDialer.Dial,
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
@@ -341,51 +340,148 @@ type connHijackerTransport struct {
}
func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
- transport := &http.Transport{
- Proxy: http.ProxyFromEnvironment,
- Dial: (&net.Dialer{
- Timeout: 30 * time.Second,
- KeepAlive: 30 * time.Second,
- }).Dial,
- TLSHandshakeTimeout: 10 * time.Second,
+ t := &http.Transport{
MaxIdleConnsPerHost: -1,
}
- if base != nil {
- if baseTransport, ok := base.(*http.Transport); ok {
- transport.Proxy = baseTransport.Proxy
- transport.TLSClientConfig = baseTransport.TLSClientConfig
- transport.TLSHandshakeTimeout = baseTransport.TLSHandshakeTimeout
- transport.Dial = baseTransport.Dial
- transport.DialTLS = baseTransport.DialTLS
- transport.MaxIdleConnsPerHost = -1
+ if b, _ := base.(*http.Transport); b != nil {
+ t.Proxy = b.Proxy
+ t.TLSClientConfig = cloneTLSClientConfig(b.TLSClientConfig)
+ t.TLSClientConfig.NextProtos = nil
+ t.TLSHandshakeTimeout = b.TLSHandshakeTimeout
+ t.Dial = b.Dial
+ t.DialTLS = b.DialTLS
+ } else {
+ t.Proxy = http.ProxyFromEnvironment
+ t.TLSHandshakeTimeout = 10 * time.Second
+ }
+ hj := &connHijackerTransport{t, nil, bufferPool.Get().([]byte)[:0]}
+
+ dial := getTransportDial(t)
+ dialTLS := getTransportDialTLS(t)
+
+ t.Dial = func(network, addr string) (net.Conn, error) {
+ c, err := dial(network, addr)
+ hj.Conn = c
+ return &hijackedConn{c, hj}, err
+ }
+
+ if dialTLS != nil {
+ t.DialTLS = func(network, addr string) (net.Conn, error) {
+ c, err := dialTLS(network, addr)
+ hj.Conn = c
+ return &hijackedConn{c, hj}, err
}
}
- hjTransport := &connHijackerTransport{transport, nil, bufferPool.Get().([]byte)[:0]}
- oldDial := transport.Dial
- oldDialTLS := transport.DialTLS
- if oldDial == nil {
- oldDial = (&net.Dialer{
- Timeout: 30 * time.Second,
- KeepAlive: 30 * time.Second,
- }).Dial
+
+ return hj
+}
+
+// getTransportDial always returns a plain Dialer
+// and defaults to the existing t.Dial.
+func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, error) {
+ if t.Dial != nil {
+ return t.Dial
+ }
+ return defaultDialer.Dial
+}
+
+// getTransportDial returns a TLS Dialer if TLSClientConfig is non-nil
+// and defaults to the existing t.DialTLS.
+func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) {
+ if t.DialTLS != nil {
+ return t.DialTLS
}
- hjTransport.Dial = func(network, addr string) (net.Conn, error) {
- c, err := oldDial(network, addr)
- hjTransport.Conn = c
- return &hijackedConn{c, hjTransport}, err
+ if t.TLSClientConfig == nil {
+ return nil
}
- if oldDialTLS != nil {
- hjTransport.DialTLS = func(network, addr string) (net.Conn, error) {
- c, err := oldDialTLS(network, addr)
- hjTransport.Conn = c
- return &hijackedConn{c, hjTransport}, err
+
+ // newConnHijackerTransport will modify t.Dial after calling this method
+ // => Create a backup reference.
+ plainDial := getTransportDial(t)
+
+ return func(network, addr string) (net.Conn, error) {
+ plainConn, err := plainDial(network, addr)
+ if err != nil {
+ return nil, err
}
+
+ tlsConn := tls.Client(plainConn, t.TLSClientConfig)
+ errc := make(chan error, 2)
+ var timer *time.Timer
+ if d := t.TLSHandshakeTimeout; d != 0 {
+ timer = time.AfterFunc(d, func() {
+ errc <- tlsHandshakeTimeoutError{}
+ })
+ }
+ go func() {
+ err := tlsConn.Handshake()
+ if timer != nil {
+ timer.Stop()
+ }
+ errc <- err
+ }()
+ if err := <-errc; err != nil {
+ plainConn.Close()
+ return nil, err
+ }
+ if !t.TLSClientConfig.InsecureSkipVerify {
+ serverName := t.TLSClientConfig.ServerName
+ if serverName == "" {
+ serverName = addr
+ idx := strings.LastIndex(serverName, ":")
+ if idx != -1 {
+ serverName = serverName[:idx]
+ }
+ }
+ if err := tlsConn.VerifyHostname(serverName); err != nil {
+ plainConn.Close()
+ return nil, err
+ }
+ }
+
+ return tlsConn, nil
+ }
+}
+
+type tlsHandshakeTimeoutError struct{}
+
+func (tlsHandshakeTimeoutError) Timeout() bool { return true }
+func (tlsHandshakeTimeoutError) Temporary() bool { return true }
+func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }
+
+// cloneTLSClientConfig is like cloneTLSConfig but omits
+// the fields SessionTicketsDisabled and SessionTicketKey.
+// This makes it safe to call cloneTLSClientConfig on a config
+// in active use by a server.
+func cloneTLSClientConfig(cfg *tls.Config) *tls.Config {
+ if cfg == nil {
+ return &tls.Config{}
+ }
+ return &tls.Config{
+ Rand: cfg.Rand,
+ Time: cfg.Time,
+ Certificates: cfg.Certificates,
+ NameToCertificate: cfg.NameToCertificate,
+ GetCertificate: cfg.GetCertificate,
+ RootCAs: cfg.RootCAs,
+ NextProtos: cfg.NextProtos,
+ ServerName: cfg.ServerName,
+ ClientAuth: cfg.ClientAuth,
+ ClientCAs: cfg.ClientCAs,
+ InsecureSkipVerify: cfg.InsecureSkipVerify,
+ CipherSuites: cfg.CipherSuites,
+ PreferServerCipherSuites: cfg.PreferServerCipherSuites,
+ ClientSessionCache: cfg.ClientSessionCache,
+ MinVersion: cfg.MinVersion,
+ MaxVersion: cfg.MaxVersion,
+ CurvePreferences: cfg.CurvePreferences,
+ DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
+ Renegotiation: cfg.Renegotiation,
}
- return hjTransport
}
func requestIsWebsocket(req *http.Request) bool {
- return !(strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade"))
+ return strings.ToLower(req.Header.Get("Upgrade")) == "websocket" && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")
}
type writeFlusher interface {