diff options
Diffstat (limited to 'modules/caddyhttp/reverseproxy/streaming.go')
-rw-r--r-- | modules/caddyhttp/reverseproxy/streaming.go | 84 |
1 files changed, 70 insertions, 14 deletions
diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go index 3fde10b35..d697eb402 100644 --- a/modules/caddyhttp/reverseproxy/streaming.go +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -19,6 +19,7 @@ package reverseproxy import ( + "bufio" "context" "errors" "fmt" @@ -33,8 +34,29 @@ import ( "go.uber.org/zap" "go.uber.org/zap/zapcore" "golang.org/x/net/http/httpguts" + + "github.com/caddyserver/caddy/v2/modules/caddyhttp" ) +type h2ReadWriteCloser struct { + io.ReadCloser + http.ResponseWriter +} + +func (rwc h2ReadWriteCloser) Write(p []byte) (n int, err error) { + n, err = rwc.ResponseWriter.Write(p) + if err != nil { + return 0, err + } + + //nolint:bodyclose + err = http.NewResponseController(rwc.ResponseWriter).Flush() + if err != nil { + return 0, err + } + return n, nil +} + func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, rw http.ResponseWriter, req *http.Request, res *http.Response) { reqUpType := upgradeType(req.Header) resUpType := upgradeType(res.Header) @@ -67,24 +89,58 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, // like the rest of handler chain. copyHeader(rw.Header(), res.Header) normalizeWebsocketHeaders(rw.Header()) - rw.WriteHeader(res.StatusCode) - logger.Debug("upgrading connection") + var ( + conn io.ReadWriteCloser + brw *bufio.ReadWriter + ) + // websocket over http2, assuming backend doesn't support this, the request will be modified to http1.1 upgrade + // TODO: once we can reliably detect backend support this, it can be removed for those backends + if body, ok := caddyhttp.GetVar(req.Context(), "h2_websocket_body").(io.ReadCloser); ok { + req.Body = body + rw.Header().Del("Upgrade") + rw.Header().Del("Connection") + delete(rw.Header(), "Sec-WebSocket-Accept") + rw.WriteHeader(http.StatusOK) + + if c := logger.Check(zap.DebugLevel, "upgrading connection"); c != nil { + c.Write(zap.Int("http_version", 2)) + } - //nolint:bodyclose - conn, brw, hijackErr := http.NewResponseController(rw).Hijack() - if errors.Is(hijackErr, http.ErrNotSupported) { - if c := logger.Check(zapcore.ErrorLevel, "can't switch protocols using non-Hijacker ResponseWriter"); c != nil { - c.Write(zap.String("type", fmt.Sprintf("%T", rw))) + //nolint:bodyclose + flushErr := http.NewResponseController(rw).Flush() + if flushErr != nil { + if c := h.logger.Check(zap.ErrorLevel, "failed to flush http2 websocket response"); c != nil { + c.Write(zap.Error(flushErr)) + } + return } - return - } + conn = h2ReadWriteCloser{req.Body, rw} + // bufio is not needed, use minimal buffer + brw = bufio.NewReadWriter(bufio.NewReaderSize(conn, 1), bufio.NewWriterSize(conn, 1)) + } else { + rw.WriteHeader(res.StatusCode) - if hijackErr != nil { - if c := logger.Check(zapcore.ErrorLevel, "hijack failed on protocol switch"); c != nil { - c.Write(zap.Error(hijackErr)) + if c := logger.Check(zap.DebugLevel, "upgrading connection"); c != nil { + c.Write(zap.Int("http_version", req.ProtoMajor)) + } + + var hijackErr error + //nolint:bodyclose + conn, brw, hijackErr = http.NewResponseController(rw).Hijack() + if errors.Is(hijackErr, http.ErrNotSupported) { + if c := h.logger.Check(zap.ErrorLevel, "can't switch protocols using non-Hijacker ResponseWriter"); c != nil { + c.Write(zap.String("type", fmt.Sprintf("%T", rw))) + } + return + } + + if hijackErr != nil { + if c := h.logger.Check(zap.ErrorLevel, "hijack failed on protocol switch"); c != nil { + c.Write(zap.Error(hijackErr)) + } + return } - return } // adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5 @@ -103,7 +159,7 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, start := time.Now() defer func() { conn.Close() - if c := logger.Check(zapcore.DebugLevel, "hijack failed on protocol switch"); c != nil { + if c := logger.Check(zapcore.DebugLevel, "connection closed"); c != nil { c.Write(zap.Duration("duration", time.Since(start))) } }() |