diff options
author | Matthew Holt <[email protected]> | 2019-05-20 23:48:43 -0600 |
---|---|---|
committer | Matthew Holt <[email protected]> | 2019-05-20 23:48:43 -0600 |
commit | 65195a726d9ceff4bbf870b7baa7eff20cf35381 (patch) | |
tree | 6b6f19517e4874831197b535395cdc891d11dfbd | |
parent | b84cb058484e7900d60324da1289d319c77f5447 (diff) | |
download | caddy-65195a726d9ceff4bbf870b7baa7eff20cf35381.tar.gz caddy-65195a726d9ceff4bbf870b7baa7eff20cf35381.zip |
Implement rewrite middleware; fix middleware stack bugs
-rw-r--r-- | cmd/caddy2/main.go | 1 | ||||
-rw-r--r-- | modules/caddyhttp/fileserver/staticfiles.go | 2 | ||||
-rw-r--r-- | modules/caddyhttp/headers/headers.go | 20 | ||||
-rw-r--r-- | modules/caddyhttp/matchers_test.go | 6 | ||||
-rw-r--r-- | modules/caddyhttp/replacer.go | 6 | ||||
-rw-r--r-- | modules/caddyhttp/rewrite/rewrite.go | 71 | ||||
-rw-r--r-- | modules/caddyhttp/routes.go | 49 | ||||
-rw-r--r-- | modules/caddyhttp/server.go | 12 |
8 files changed, 133 insertions, 34 deletions
diff --git a/cmd/caddy2/main.go b/cmd/caddy2/main.go index 4276d7d08..59239a2f1 100644 --- a/cmd/caddy2/main.go +++ b/cmd/caddy2/main.go @@ -9,6 +9,7 @@ import ( _ "bitbucket.org/lightcodelabs/caddy2/modules/caddyhttp/fileserver" _ "bitbucket.org/lightcodelabs/caddy2/modules/caddyhttp/headers" _ "bitbucket.org/lightcodelabs/caddy2/modules/caddyhttp/reverseproxy" + _ "bitbucket.org/lightcodelabs/caddy2/modules/caddyhttp/rewrite" _ "bitbucket.org/lightcodelabs/caddy2/modules/caddytls" ) diff --git a/modules/caddyhttp/fileserver/staticfiles.go b/modules/caddyhttp/fileserver/staticfiles.go index e859abec9..873c5fb81 100644 --- a/modules/caddyhttp/fileserver/staticfiles.go +++ b/modules/caddyhttp/fileserver/staticfiles.go @@ -104,7 +104,7 @@ func (fsrv *FileServer) ServeHTTP(w http.ResponseWriter, r *http.Request) error if filename == "" { // no files worked, so resort to fallback if fsrv.Fallback != nil { - fallback := fsrv.Fallback.BuildCompositeRoute(w, r) + fallback, w := fsrv.Fallback.BuildCompositeRoute(w, r) return fallback.ServeHTTP(w, r) } return caddyhttp.Error(http.StatusNotFound, nil) diff --git a/modules/caddyhttp/headers/headers.go b/modules/caddyhttp/headers/headers.go index 37efc571c..b56bbf9ca 100644 --- a/modules/caddyhttp/headers/headers.go +++ b/modules/caddyhttp/headers/headers.go @@ -37,29 +37,36 @@ type RespHeaderOps struct { } func (h Headers) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { - apply(h.Request, r.Header) + repl := r.Context().Value(caddy2.ReplacerCtxKey).(caddy2.Replacer) + apply(h.Request, r.Header, repl) if h.Response.Deferred { w = &responseWriterWrapper{ ResponseWriterWrapper: &caddyhttp.ResponseWriterWrapper{ResponseWriter: w}, + replacer: repl, headerOps: h.Response.HeaderOps, } } else { - apply(h.Response.HeaderOps, w.Header()) + apply(h.Response.HeaderOps, w.Header(), repl) } return next.ServeHTTP(w, r) } -func apply(ops HeaderOps, hdr http.Header) { +func apply(ops HeaderOps, hdr http.Header, repl caddy2.Replacer) { for fieldName, vals := range ops.Add { + fieldName = repl.ReplaceAll(fieldName, "") for _, v := range vals { - hdr.Add(fieldName, v) + hdr.Add(fieldName, repl.ReplaceAll(v, "")) } } for fieldName, vals := range ops.Set { + fieldName = repl.ReplaceAll(fieldName, "") + for i := range vals { + vals[i] = repl.ReplaceAll(vals[i], "") + } hdr.Set(fieldName, strings.Join(vals, ",")) } for _, fieldName := range ops.Delete { - hdr.Del(fieldName) + hdr.Del(repl.ReplaceAll(fieldName, "")) } } @@ -67,6 +74,7 @@ func apply(ops HeaderOps, hdr http.Header) { // operations until WriteHeader is called. type responseWriterWrapper struct { *caddyhttp.ResponseWriterWrapper + replacer caddy2.Replacer headerOps HeaderOps wroteHeader bool } @@ -83,7 +91,7 @@ func (rww *responseWriterWrapper) WriteHeader(status int) { return } rww.wroteHeader = true - apply(rww.headerOps, rww.ResponseWriterWrapper.Header()) + apply(rww.headerOps, rww.ResponseWriterWrapper.Header(), rww.replacer) rww.ResponseWriterWrapper.WriteHeader(status) } diff --git a/modules/caddyhttp/matchers_test.go b/modules/caddyhttp/matchers_test.go index c279bad9f..c4c7845b6 100644 --- a/modules/caddyhttp/matchers_test.go +++ b/modules/caddyhttp/matchers_test.go @@ -227,9 +227,10 @@ func TestPathREMatcher(t *testing.T) { // set up the fake request and its Replacer req := &http.Request{URL: &url.URL{Path: tc.input}} - repl := newReplacer(req, httptest.NewRecorder()) + repl := caddy2.NewReplacer() ctx := context.WithValue(req.Context(), caddy2.ReplacerCtxKey, repl) req = req.WithContext(ctx) + addHTTPVarsToReplacer(repl, req, httptest.NewRecorder()) actual := tc.match.Match(req) if actual != tc.expect { @@ -344,9 +345,10 @@ func TestHeaderREMatcher(t *testing.T) { // set up the fake request and its Replacer req := &http.Request{Header: tc.input, URL: new(url.URL)} - repl := newReplacer(req, httptest.NewRecorder()) + repl := caddy2.NewReplacer() ctx := context.WithValue(req.Context(), caddy2.ReplacerCtxKey, repl) req = req.WithContext(ctx) + addHTTPVarsToReplacer(repl, req, httptest.NewRecorder()) actual := tc.match.Match(req) if actual != tc.expect { diff --git a/modules/caddyhttp/replacer.go b/modules/caddyhttp/replacer.go index 6feb1437c..16cc1fe99 100644 --- a/modules/caddyhttp/replacer.go +++ b/modules/caddyhttp/replacer.go @@ -13,9 +13,7 @@ import ( // TODO: A simple way to format or escape or encode each value would be nice // ... TODO: Should we just use templates? :-/ yeesh... -func newReplacer(req *http.Request, w http.ResponseWriter) caddy2.Replacer { - repl := caddy2.NewReplacer() - +func addHTTPVarsToReplacer(repl caddy2.Replacer, req *http.Request, w http.ResponseWriter) { httpVars := func() map[string]string { m := make(map[string]string) if req != nil { @@ -78,6 +76,4 @@ func newReplacer(req *http.Request, w http.ResponseWriter) caddy2.Replacer { } repl.Map(httpVars) - - return repl } diff --git a/modules/caddyhttp/rewrite/rewrite.go b/modules/caddyhttp/rewrite/rewrite.go new file mode 100644 index 000000000..1afb8a421 --- /dev/null +++ b/modules/caddyhttp/rewrite/rewrite.go @@ -0,0 +1,71 @@ +package headers + +import ( + "net/http" + "net/url" + "strings" + + "bitbucket.org/lightcodelabs/caddy2" + "bitbucket.org/lightcodelabs/caddy2/modules/caddyhttp" +) + +func init() { + caddy2.RegisterModule(caddy2.Module{ + Name: "http.middleware.rewrite", + New: func() (interface{}, error) { return new(Rewrite), nil }, + }) +} + +// Rewrite is a middleware which can rewrite HTTP requests. +type Rewrite struct { + Method string `json:"method"` + URI string `json:"uri"` + Rehandle bool `json:"rehandle"` +} + +func (rewr Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { + repl := r.Context().Value(caddy2.ReplacerCtxKey).(caddy2.Replacer) + var rehandleNeeded bool + + if rewr.Method != "" { + method := r.Method + r.Method = strings.ToUpper(repl.ReplaceAll(rewr.Method, "")) + if r.Method != method { + rehandleNeeded = true + } + } + + if rewr.URI != "" { + // TODO: clean this all up, I don't think it's right + + oldURI := r.RequestURI + newURI := repl.ReplaceAll(rewr.URI, "") + u, err := url.Parse(newURI) + if err != nil { + return caddyhttp.Error(http.StatusInternalServerError, err) + } + + r.RequestURI = newURI + + r.URL.Path = u.Path + if u.RawQuery != "" { + r.URL.RawQuery = u.RawQuery + } + if u.Fragment != "" { + r.URL.Fragment = u.Fragment + } + + if newURI != oldURI { + rehandleNeeded = true + } + } + + if rehandleNeeded && rewr.Rehandle { + return caddyhttp.ErrRehandle + } + + return next.ServeHTTP(w, r) +} + +// Interface guard +var _ caddyhttp.MiddlewareHandler = (*Rewrite)(nil) diff --git a/modules/caddyhttp/routes.go b/modules/caddyhttp/routes.go index daae08079..92aa3e87b 100644 --- a/modules/caddyhttp/routes.go +++ b/modules/caddyhttp/routes.go @@ -65,23 +65,24 @@ func (routes RouteList) Provision(ctx caddy2.Context) error { return nil } -// BuildCompositeRoute creates a chain of handlers by -// applying all the matching routes. -func (routes RouteList) BuildCompositeRoute(w http.ResponseWriter, r *http.Request) Handler { +// BuildCompositeRoute creates a chain of handlers by applying all the matching +// routes. The returned ResponseWriter should be used instead of rw. +func (routes RouteList) BuildCompositeRoute(rw http.ResponseWriter, req *http.Request) (Handler, http.ResponseWriter) { + mrw := &middlewareResponseWriter{ResponseWriterWrapper: &ResponseWriterWrapper{rw}} + if len(routes) == 0 { - return emptyHandler + return emptyHandler, mrw } var mid []Middleware var responder Handler - mrw := &middlewareResponseWriter{ResponseWriterWrapper: &ResponseWriterWrapper{w}} groups := make(map[string]struct{}) routeLoop: for _, route := range routes { // see if route matches for _, m := range route.matchers { - if !m.Match(r) { + if !m.Match(req) { continue routeLoop } } @@ -102,15 +103,13 @@ routeLoop: // apply the rest of the route for _, m := range route.middleware { - mid = append(mid, func(next HandlerFunc) HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) error { - // TODO: This is where request tracing could be implemented; also - // see below to trace the responder as well - // TODO: Trace a diff of the request, would be cool too! see what changed since the last middleware (host, headers, URI...) - // TODO: see what the std lib gives us in terms of stack trracing too - return m.ServeHTTP(mrw, r, next) - } - }) + // we have to be sure to wrap m outside + // of our current scope so that the + // reference to this m isn't overwritten + // on the next iteration, leaving only + // the last middleware in the chain as + // the ONLY middleware in the chain! + mid = append(mid, wrapMiddleware(m)) } if responder == nil { responder = route.responder @@ -132,7 +131,25 @@ routeLoop: stack = mid[i](stack) } - return stack + return stack, mrw +} + +// wrapMiddleware wraps m such that it can be correctly +// appended to a list of middleware. This is necessary +// so that only the last middleware in a loop does not +// become the only middleware of the stack, repeatedly +// executed (i.e. it is necessary to keep a reference +// to this m outside of the scope of a loop)! +func wrapMiddleware(m MiddlewareHandler) Middleware { + return func(next HandlerFunc) HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { + // TODO: This is where request tracing could be implemented; also + // see below to trace the responder as well + // TODO: Trace a diff of the request, would be cool too! see what changed since the last middleware (host, headers, URI...) + // TODO: see what the std lib gives us in terms of stack tracing too + return m.ServeHTTP(w, r, next) + } + } } type middlewareResponseWriter struct { diff --git a/modules/caddyhttp/server.go b/modules/caddyhttp/server.go index fbbdae4fa..dea34fb87 100644 --- a/modules/caddyhttp/server.go +++ b/modules/caddyhttp/server.go @@ -32,14 +32,18 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // set up the replacer - repl := newReplacer(r, w) + // set up the context for the request + repl := caddy2.NewReplacer() ctx := context.WithValue(r.Context(), caddy2.ReplacerCtxKey, repl) ctx = context.WithValue(ctx, TableCtxKey, make(map[string]interface{})) // TODO: Implement this r = r.WithContext(ctx) + // once the pointer to the request won't change + // anymore, finish setting up the replacer + addHTTPVarsToReplacer(repl, r, w) + // build and execute the main handler chain - stack := s.Routes.BuildCompositeRoute(w, r) + stack, w := s.Routes.BuildCompositeRoute(w, r) err := s.executeCompositeRoute(w, r, stack) if err != nil { // add the raw error value to the request context @@ -64,7 +68,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.WriteHeader(handlerErr.StatusCode) } } else { - errStack := s.Errors.Routes.BuildCompositeRoute(w, r) + errStack, w := s.Errors.Routes.BuildCompositeRoute(w, r) err := s.executeCompositeRoute(w, r, errStack) if err != nil { // TODO: what should we do if the error handler has an error? |