aboutsummaryrefslogtreecommitdiffhomepage
path: root/modules/caddyhttp/responsewriter.go
diff options
context:
space:
mode:
Diffstat (limited to 'modules/caddyhttp/responsewriter.go')
-rw-r--r--modules/caddyhttp/responsewriter.go37
1 files changed, 37 insertions, 0 deletions
diff --git a/modules/caddyhttp/responsewriter.go b/modules/caddyhttp/responsewriter.go
index 51f672eee..12627d45c 100644
--- a/modules/caddyhttp/responsewriter.go
+++ b/modules/caddyhttp/responsewriter.go
@@ -66,6 +66,8 @@ type responseRecorder struct {
size int
wroteHeader bool
stream bool
+
+ readSize *int
}
// NewResponseRecorder returns a new ResponseRecorder that can be
@@ -240,6 +242,12 @@ func (rr *responseRecorder) FlushError() error {
return nil
}
+// Private interface so it can only be used in this package
+// #TODO: maybe export it later
+func (rr *responseRecorder) setReadSize(size *int) {
+ rr.readSize = size
+}
+
func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
//nolint:bodyclose
conn, brw, err := http.NewResponseController(rr.ResponseWriterWrapper).Hijack()
@@ -249,6 +257,15 @@ func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
// Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not
conn = &hijackedConn{conn, rr}
brw.Writer.Reset(conn)
+
+ buffered := brw.Reader.Buffered()
+ if buffered != 0 {
+ conn.(*hijackedConn).updateReadSize(buffered)
+ data, _ := brw.Peek(buffered)
+ brw.Reader.Reset(io.MultiReader(bytes.NewReader(data), conn))
+ } else {
+ brw.Reader.Reset(conn)
+ }
return conn, brw, nil
}
@@ -258,6 +275,24 @@ type hijackedConn struct {
rr *responseRecorder
}
+func (hc *hijackedConn) updateReadSize(n int) {
+ if hc.rr.readSize != nil {
+ *hc.rr.readSize += n
+ }
+}
+
+func (hc *hijackedConn) Read(p []byte) (int, error) {
+ n, err := hc.Conn.Read(p)
+ hc.updateReadSize(n)
+ return n, err
+}
+
+func (hc *hijackedConn) WriteTo(w io.Writer) (int64, error) {
+ n, err := io.Copy(w, hc.Conn)
+ hc.updateReadSize(int(n))
+ return n, err
+}
+
func (hc *hijackedConn) Write(p []byte) (int, error) {
n, err := hc.Conn.Write(p)
hc.rr.size += n
@@ -298,4 +333,6 @@ var (
_ io.ReaderFrom = (*ResponseWriterWrapper)(nil)
_ io.ReaderFrom = (*responseRecorder)(nil)
_ io.ReaderFrom = (*hijackedConn)(nil)
+
+ _ io.WriterTo = (*hijackedConn)(nil)
)