diff options
Diffstat (limited to 'modules/caddyhttp/responsewriter.go')
-rw-r--r-- | modules/caddyhttp/responsewriter.go | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/modules/caddyhttp/responsewriter.go b/modules/caddyhttp/responsewriter.go index 51f672eee..12627d45c 100644 --- a/modules/caddyhttp/responsewriter.go +++ b/modules/caddyhttp/responsewriter.go @@ -66,6 +66,8 @@ type responseRecorder struct { size int wroteHeader bool stream bool + + readSize *int } // NewResponseRecorder returns a new ResponseRecorder that can be @@ -240,6 +242,12 @@ func (rr *responseRecorder) FlushError() error { return nil } +// Private interface so it can only be used in this package +// #TODO: maybe export it later +func (rr *responseRecorder) setReadSize(size *int) { + rr.readSize = size +} + func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { //nolint:bodyclose conn, brw, err := http.NewResponseController(rr.ResponseWriterWrapper).Hijack() @@ -249,6 +257,15 @@ func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { // Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not conn = &hijackedConn{conn, rr} brw.Writer.Reset(conn) + + buffered := brw.Reader.Buffered() + if buffered != 0 { + conn.(*hijackedConn).updateReadSize(buffered) + data, _ := brw.Peek(buffered) + brw.Reader.Reset(io.MultiReader(bytes.NewReader(data), conn)) + } else { + brw.Reader.Reset(conn) + } return conn, brw, nil } @@ -258,6 +275,24 @@ type hijackedConn struct { rr *responseRecorder } +func (hc *hijackedConn) updateReadSize(n int) { + if hc.rr.readSize != nil { + *hc.rr.readSize += n + } +} + +func (hc *hijackedConn) Read(p []byte) (int, error) { + n, err := hc.Conn.Read(p) + hc.updateReadSize(n) + return n, err +} + +func (hc *hijackedConn) WriteTo(w io.Writer) (int64, error) { + n, err := io.Copy(w, hc.Conn) + hc.updateReadSize(int(n)) + return n, err +} + func (hc *hijackedConn) Write(p []byte) (int, error) { n, err := hc.Conn.Write(p) hc.rr.size += n @@ -298,4 +333,6 @@ var ( _ io.ReaderFrom = (*ResponseWriterWrapper)(nil) _ io.ReaderFrom = (*responseRecorder)(nil) _ io.ReaderFrom = (*hijackedConn)(nil) + + _ io.WriterTo = (*hijackedConn)(nil) ) |