summaryrefslogtreecommitdiffhomepage
path: root/caddyhttp
diff options
context:
space:
mode:
authorBenny Ng <[email protected]>2016-11-01 12:34:39 +0800
committerBenny Ng <[email protected]>2016-11-04 19:15:36 +0800
commitdd4c4d7eb649d4d0ae1b381a501f431c684ab27c (patch)
tree3af890b80ad015204b0f6ce19a1c7fadcdd2b5cc /caddyhttp
parent0cdaaba4b87da8a72c0591ee301bb50df94a6371 (diff)
downloadcaddy-dd4c4d7eb649d4d0ae1b381a501f431c684ab27c.tar.gz
caddy-dd4c4d7eb649d4d0ae1b381a501f431c684ab27c.zip
proxy: record request Body for retry (fixes #1229)
Diffstat (limited to 'caddyhttp')
-rw-r--r--caddyhttp/proxy/body.go40
-rw-r--r--caddyhttp/proxy/body_test.go69
-rw-r--r--caddyhttp/proxy/proxy.go14
-rw-r--r--caddyhttp/proxy/proxy_test.go58
4 files changed, 181 insertions, 0 deletions
diff --git a/caddyhttp/proxy/body.go b/caddyhttp/proxy/body.go
new file mode 100644
index 000000000..38d001659
--- /dev/null
+++ b/caddyhttp/proxy/body.go
@@ -0,0 +1,40 @@
+package proxy
+
+import (
+ "bytes"
+ "io"
+ "io/ioutil"
+)
+
+type bufferedBody struct {
+ *bytes.Reader
+}
+
+func (*bufferedBody) Close() error {
+ return nil
+}
+
+// rewind allows bufferedBody to be read again.
+func (b *bufferedBody) rewind() error {
+ if b == nil {
+ return nil
+ }
+ _, err := b.Seek(0, io.SeekStart)
+ return err
+}
+
+// newBufferedBody returns *bufferedBody to use in place of src. Closes src
+// and returns Read error on src. All content from src is buffered.
+func newBufferedBody(src io.ReadCloser) (*bufferedBody, error) {
+ if src == nil {
+ return nil, nil
+ }
+ b, err := ioutil.ReadAll(src)
+ src.Close()
+ if err != nil {
+ return nil, err
+ }
+ return &bufferedBody{
+ Reader: bytes.NewReader(b),
+ }, nil
+}
diff --git a/caddyhttp/proxy/body_test.go b/caddyhttp/proxy/body_test.go
new file mode 100644
index 000000000..5b72784cf
--- /dev/null
+++ b/caddyhttp/proxy/body_test.go
@@ -0,0 +1,69 @@
+package proxy
+
+import (
+ "bytes"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
+
+func TestBodyRetry(t *testing.T) {
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ io.Copy(w, r.Body)
+ r.Body.Close()
+ }))
+ defer ts.Close()
+
+ testcase := "test content"
+ req, err := http.NewRequest(http.MethodPost, ts.URL, bytes.NewBufferString(testcase))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ body, err := newBufferedBody(req.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if body != nil {
+ req.Body = body
+ }
+
+ // simulate fail request
+ host := req.URL.Host
+ req.URL.Host = "example.com"
+ body.rewind()
+ _, _ = http.DefaultTransport.RoundTrip(req)
+
+ // retry request
+ req.URL.Host = host
+ body.rewind()
+ resp, err := http.DefaultTransport.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ result, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ resp.Body.Close()
+ if string(result) != testcase {
+ t.Fatalf("result = %s, want %s", result, testcase)
+ }
+
+ // try one more time for body reuse
+ body.rewind()
+ resp, err = http.DefaultTransport.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ result, err = ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ resp.Body.Close()
+ if string(result) != testcase {
+ t.Fatalf("result = %s, want %s", result, testcase)
+ }
+}
diff --git a/caddyhttp/proxy/proxy.go b/caddyhttp/proxy/proxy.go
index 71c7476b5..11f2d5d01 100644
--- a/caddyhttp/proxy/proxy.go
+++ b/caddyhttp/proxy/proxy.go
@@ -94,6 +94,15 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
// outreq is the request that makes a roundtrip to the backend
outreq := createUpstreamRequest(r)
+ // record and replace outreq body
+ body, err := newBufferedBody(outreq.Body)
+ if err != nil {
+ return http.StatusBadRequest, errors.New("failed to read downstream request body")
+ }
+ if body != nil {
+ outreq.Body = body
+ }
+
// The keepRetrying function will return true if we should
// loop and try to select another host, or false if we
// should break and stop retrying.
@@ -164,6 +173,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer)
}
+ // rewind request body to its beginning
+ if err := body.rewind(); err != nil {
+ return http.StatusInternalServerError, errors.New("unable to rewind downstream request body")
+ }
+
// tell the proxy to serve the request
atomic.AddInt64(&host.Conns, 1)
backendErr = proxy.ServeHTTP(w, outreq, downHeaderUpdateFn)
diff --git a/caddyhttp/proxy/proxy_test.go b/caddyhttp/proxy/proxy_test.go
index af02a17c5..290cae938 100644
--- a/caddyhttp/proxy/proxy_test.go
+++ b/caddyhttp/proxy/proxy_test.go
@@ -20,6 +20,7 @@ import (
"testing"
"time"
+ "github.com/mholt/caddy/caddyfile"
"github.com/mholt/caddy/caddyhttp/httpserver"
"golang.org/x/net/websocket"
@@ -836,6 +837,63 @@ func TestProxyDirectorURL(t *testing.T) {
}
}
+func TestReverseProxyRetry(t *testing.T) {
+ log.SetOutput(ioutil.Discard)
+ defer log.SetOutput(os.Stderr)
+
+ // set up proxy
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ io.Copy(w, r.Body)
+ r.Body.Close()
+ }))
+ defer backend.Close()
+
+ su, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(`
+ proxy / localhost:65535 localhost:65534 `+backend.URL+` {
+ policy round_robin
+ fail_timeout 5s
+ max_fails 1
+ try_duration 5s
+ try_interval 250ms
+ }
+ `)))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ p := &Proxy{
+ Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
+ Upstreams: su,
+ }
+
+ // middle is required to simulate closable downstream request body
+ middle := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, err = p.ServeHTTP(w, r)
+ if err != nil {
+ t.Error(err)
+ }
+ }))
+ defer middle.Close()
+
+ testcase := "test content"
+ r, err := http.NewRequest("POST", middle.URL, bytes.NewBufferString(testcase))
+ if err != nil {
+ t.Fatal(err)
+ }
+ resp, err := http.DefaultTransport.RoundTrip(r)
+ if err != nil {
+ t.Fatal(err)
+ }
+ b, err := ioutil.ReadAll(resp.Body)
+ resp.Body.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(b) != testcase {
+ t.Fatalf("string(b) = %s, want %s", string(b), testcase)
+ }
+}
+
func newFakeUpstream(name string, insecure bool) *fakeUpstream {
uri, _ := url.Parse(name)
u := &fakeUpstream{