aboutsummaryrefslogtreecommitdiffhomepage
path: root/modules/caddyhttp/reverseproxy/streaming.go
diff options
context:
space:
mode:
Diffstat (limited to 'modules/caddyhttp/reverseproxy/streaming.go')
-rw-r--r--modules/caddyhttp/reverseproxy/streaming.go84
1 files changed, 70 insertions, 14 deletions
diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go
index 3fde10b35..d697eb402 100644
--- a/modules/caddyhttp/reverseproxy/streaming.go
+++ b/modules/caddyhttp/reverseproxy/streaming.go
@@ -19,6 +19,7 @@
package reverseproxy
import (
+ "bufio"
"context"
"errors"
"fmt"
@@ -33,8 +34,29 @@ import (
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"golang.org/x/net/http/httpguts"
+
+ "github.com/caddyserver/caddy/v2/modules/caddyhttp"
)
+type h2ReadWriteCloser struct {
+ io.ReadCloser
+ http.ResponseWriter
+}
+
+func (rwc h2ReadWriteCloser) Write(p []byte) (n int, err error) {
+ n, err = rwc.ResponseWriter.Write(p)
+ if err != nil {
+ return 0, err
+ }
+
+ //nolint:bodyclose
+ err = http.NewResponseController(rwc.ResponseWriter).Flush()
+ if err != nil {
+ return 0, err
+ }
+ return n, nil
+}
+
func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, rw http.ResponseWriter, req *http.Request, res *http.Response) {
reqUpType := upgradeType(req.Header)
resUpType := upgradeType(res.Header)
@@ -67,24 +89,58 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
// like the rest of handler chain.
copyHeader(rw.Header(), res.Header)
normalizeWebsocketHeaders(rw.Header())
- rw.WriteHeader(res.StatusCode)
- logger.Debug("upgrading connection")
+ var (
+ conn io.ReadWriteCloser
+ brw *bufio.ReadWriter
+ )
+ // websocket over http2, assuming backend doesn't support this, the request will be modified to http1.1 upgrade
+ // TODO: once we can reliably detect backend support this, it can be removed for those backends
+ if body, ok := caddyhttp.GetVar(req.Context(), "h2_websocket_body").(io.ReadCloser); ok {
+ req.Body = body
+ rw.Header().Del("Upgrade")
+ rw.Header().Del("Connection")
+ delete(rw.Header(), "Sec-WebSocket-Accept")
+ rw.WriteHeader(http.StatusOK)
+
+ if c := logger.Check(zap.DebugLevel, "upgrading connection"); c != nil {
+ c.Write(zap.Int("http_version", 2))
+ }
- //nolint:bodyclose
- conn, brw, hijackErr := http.NewResponseController(rw).Hijack()
- if errors.Is(hijackErr, http.ErrNotSupported) {
- if c := logger.Check(zapcore.ErrorLevel, "can't switch protocols using non-Hijacker ResponseWriter"); c != nil {
- c.Write(zap.String("type", fmt.Sprintf("%T", rw)))
+ //nolint:bodyclose
+ flushErr := http.NewResponseController(rw).Flush()
+ if flushErr != nil {
+ if c := h.logger.Check(zap.ErrorLevel, "failed to flush http2 websocket response"); c != nil {
+ c.Write(zap.Error(flushErr))
+ }
+ return
}
- return
- }
+ conn = h2ReadWriteCloser{req.Body, rw}
+ // bufio is not needed, use minimal buffer
+ brw = bufio.NewReadWriter(bufio.NewReaderSize(conn, 1), bufio.NewWriterSize(conn, 1))
+ } else {
+ rw.WriteHeader(res.StatusCode)
- if hijackErr != nil {
- if c := logger.Check(zapcore.ErrorLevel, "hijack failed on protocol switch"); c != nil {
- c.Write(zap.Error(hijackErr))
+ if c := logger.Check(zap.DebugLevel, "upgrading connection"); c != nil {
+ c.Write(zap.Int("http_version", req.ProtoMajor))
+ }
+
+ var hijackErr error
+ //nolint:bodyclose
+ conn, brw, hijackErr = http.NewResponseController(rw).Hijack()
+ if errors.Is(hijackErr, http.ErrNotSupported) {
+ if c := h.logger.Check(zap.ErrorLevel, "can't switch protocols using non-Hijacker ResponseWriter"); c != nil {
+ c.Write(zap.String("type", fmt.Sprintf("%T", rw)))
+ }
+ return
+ }
+
+ if hijackErr != nil {
+ if c := h.logger.Check(zap.ErrorLevel, "hijack failed on protocol switch"); c != nil {
+ c.Write(zap.Error(hijackErr))
+ }
+ return
}
- return
}
// adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5
@@ -103,7 +159,7 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
start := time.Now()
defer func() {
conn.Close()
- if c := logger.Check(zapcore.DebugLevel, "hijack failed on protocol switch"); c != nil {
+ if c := logger.Check(zapcore.DebugLevel, "connection closed"); c != nil {
c.Write(zap.Duration("duration", time.Since(start)))
}
}()