diff options
author | Tw <[email protected]> | 2016-09-27 15:35:13 +0800 |
---|---|---|
committer | Tw <[email protected]> | 2016-09-27 15:35:13 +0800 |
commit | d0ddfc849df4c7e50487a8d0f1c2cac6d057be3f (patch) | |
tree | c1bd679ddec0ffa6f98dd5458672ac54b2488f67 /caddyhttp | |
parent | 4adbcd256583f83144a74f65ddcd8a769a840071 (diff) | |
download | caddy-d0ddfc849df4c7e50487a8d0f1c2cac6d057be3f.tar.gz caddy-d0ddfc849df4c7e50487a8d0f1c2cac6d057be3f.zip |
header: defer header operations
fix issue #1131
Signed-off-by: Tw <[email protected]>
Diffstat (limited to 'caddyhttp')
-rw-r--r-- | caddyhttp/header/header.go | 68 | ||||
-rw-r--r-- | caddyhttp/header/header_test.go | 5 |
2 files changed, 68 insertions, 5 deletions
diff --git a/caddyhttp/header/header.go b/caddyhttp/header/header.go index 08ba4bd2c..121264d48 100644 --- a/caddyhttp/header/header.go +++ b/caddyhttp/header/header.go @@ -21,22 +21,23 @@ type Headers struct { // setting headers on the response according to the configured rules. func (h Headers) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { replacer := httpserver.NewReplacer(r, nil, "") + rww := &responseWriterWrapper{w: w} for _, rule := range h.Rules { if httpserver.Path(r.URL.Path).Matches(rule.Path) { for _, header := range rule.Headers { // One can either delete a header, add multiple values to a header, or simply // set a header. if strings.HasPrefix(header.Name, "-") { - w.Header().Del(strings.TrimLeft(header.Name, "-")) + rww.delHeader(strings.TrimLeft(header.Name, "-")) } else if strings.HasPrefix(header.Name, "+") { - w.Header().Add(strings.TrimLeft(header.Name, "+"), replacer.Replace(header.Value)) + rww.addHeader(strings.TrimLeft(header.Name, "+"), replacer.Replace(header.Value)) } else { - w.Header().Set(header.Name, replacer.Replace(header.Value)) + rww.setHeader(header.Name, replacer.Replace(header.Value)) } } } } - return h.Next.ServeHTTP(w, r) + return h.Next.ServeHTTP(rww, r) } type ( @@ -53,3 +54,62 @@ type ( Value string } ) + +// headerOperation represents an operation on the header +type headerOperation func(http.Header) + +// responseWriterWrapper wraps the real ResponseWriter. +// It defers header operations until writeHeader +type responseWriterWrapper struct { + w http.ResponseWriter + ops []headerOperation + wroteHeader bool +} + +func (rww *responseWriterWrapper) Header() http.Header { + return rww.w.Header() +} + +func (rww *responseWriterWrapper) Write(d []byte) (int, error) { + if !rww.wroteHeader { + rww.WriteHeader(http.StatusOK) + } + return rww.w.Write(d) +} + +func (rww *responseWriterWrapper) WriteHeader(status int) { + if rww.wroteHeader { + return + } + rww.wroteHeader = true + // capture the original headers + h := rww.Header() + + // perform our revisions + for _, op := range rww.ops { + op(h) + } + + rww.w.WriteHeader(status) +} + +// addHeader registers a http.Header.Add operation +func (rww *responseWriterWrapper) addHeader(key, value string) { + rww.ops = append(rww.ops, func(h http.Header) { + h.Add(key, value) + }) +} + +// delHeader registers a http.Header.Del operation +func (rww *responseWriterWrapper) delHeader(key string) { + rww.ops = append(rww.ops, func(h http.Header) { + h.Del(key) + }) +} + +// setHeader registers a http.Header.Set operation +func (rww *responseWriterWrapper) setHeader(key, value string) { + rww.ops = append(rww.ops, func(h http.Header) { + h.Set(key, value) + }) +} diff --git a/caddyhttp/header/header_test.go b/caddyhttp/header/header_test.go index 787c4d7a5..0e0aaa310 100644 --- a/caddyhttp/header/header_test.go +++ b/caddyhttp/header/header_test.go @@ -1,6 +1,7 @@ package header import ( + "fmt" "net/http" "net/http/httptest" "os" @@ -30,6 +31,8 @@ func TestHeader(t *testing.T) { } { he := Headers{ Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + w.Header().Set("Bar", "Removed in /a") + fmt.Fprint(w, "This is a test") return 0, nil }), Rules: []Rule{ @@ -47,7 +50,6 @@ func TestHeader(t *testing.T) { } rec := httptest.NewRecorder() - rec.Header().Set("Bar", "Removed in /a") he.ServeHTTP(rec, req) @@ -61,6 +63,7 @@ func TestHeader(t *testing.T) { func TestMultipleHeaders(t *testing.T) { he := Headers{ Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + fmt.Fprint(w, "This is a test") return 0, nil }), Rules: []Rule{ |