summaryrefslogtreecommitdiffhomepage
path: root/caddyhttp/limits/handler.go
blob: 52fe60ab198dd72ff358dccd3e7228e076e73ca4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
package limits

import (
	"io"
	"net/http"

	"github.com/mholt/caddy/caddyhttp/httpserver"
)

// Limit is a middleware that caps the size of incoming request bodies.
// Each entry in BodyLimits pairs a request path with the maximum number of
// body bytes allowed for requests matching that path.
type Limit struct {
	// Next is the handler that receives the request after any body limit
	// has been applied.
	Next httpserver.Handler
	// BodyLimits holds the per-path limits. They are checked in order and
	// only the first entry whose path matches the request is applied.
	BodyLimits []httpserver.PathLimit
}

// ServeHTTP handles the request in two steps:
//  1. If the request has a body, it wraps the body in a size-limited reader
//     (see MaxBytesReader). The limit comes from the first BodyLimits entry
//     whose path matches the request path, so the slice order determines
//     precedence.
//  2. It delegates to the next handler.
//
// Requests without a body, or with no matching entry, are passed through
// unchanged.
func (l Limit) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
	if r.Body != nil {
		reqPath := httpserver.Path(r.URL.Path)
		for i := range l.BodyLimits {
			if !reqPath.Matches(l.BodyLimits[i].Path) {
				continue
			}
			// First match wins; later entries are ignored.
			r.Body = MaxBytesReader(w, r.Body, l.BodyLimits[i].Limit)
			break
		}
	}
	return l.Next.ServeHTTP(w, r)
}

// MaxBytesReader and its associated methods are borrowed from the
// Go standard library (comments intact). The only difference is that
// it returns a MaxBytesExceeded error (httpserver.MaxBytesExceededErr)
// instead of a generic error message when the request body has exceeded
// the requested limit.
//
// The returned reader allows at most n bytes to be read from r. A negative
// n is treated as 0, as current versions of net/http do. Without this, a
// negative limit would make Read slice its buffer with a negative bound
// and panic, or report a negative byte count.
func MaxBytesReader(w http.ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser {
	if n < 0 { // Treat negative limits as equivalent to 0.
		n = 0
	}
	return &maxBytesReader{w: w, r: r, n: n}
}

// maxBytesReader is the io.ReadCloser produced by MaxBytesReader. It counts
// down the bytes that may still be read from the wrapped body and keeps the
// error from the last failed read so it can be returned again.
type maxBytesReader struct {
	// w is the response writer of the request whose body is being read.
	// Read may notify it when the limit is exceeded.
	w http.ResponseWriter
	// r is the wrapped (underlying) reader.
	r io.ReadCloser
	// n is the number of bytes still allowed before the limit is hit.
	n int64
	// err is the sticky error; once set, every Read returns it.
	err error
}

// Read implements io.Reader. It reads from the wrapped body while allowing
// at most l.n more bytes in total. If the underlying reader yields more than
// l.n bytes, Read returns only the bytes that still fit, together with
// httpserver.MaxBytesExceededErr. Any non-nil error returned (including one
// from the underlying reader) is stored in l.err and returned by every
// subsequent call.
func (l *maxBytesReader) Read(p []byte) (n int, err error) {
	if l.err != nil {
		return 0, l.err
	}
	if len(p) == 0 {
		return 0, nil
	}
	// If they asked for a 32KB byte read but only 5 bytes are
	// remaining, no need to read 32KB. 6 bytes will answer the
	// question of whether we hit the limit or go past it.
	//
	// This is written as len(p)-1 > l.n rather than len(p) > l.n+1.
	// Since 0 < len(p) < 2^63, len(p)-1 cannot overflow, whereas l.n+1
	// wraps negative when l.n == math.MaxInt64, and the slice expression
	// would then panic. Inside the branch l.n < len(p)-1, so l.n+1 is safe.
	if int64(len(p))-1 > l.n {
		p = p[:l.n+1]
	}
	n, err = l.r.Read(p)

	// Either we read <= l.n bytes, or we read l.n+1 bytes and hit the limit.
	if int64(n) <= l.n {
		l.n -= int64(n)
		l.err = err
		return n, err
	}

	n = int(l.n)
	l.n = 0

	// The server code and client code both use
	// maxBytesReader. This "requestTooLarge" check is
	// only used by the server code. To prevent binaries
	// which only use the HTTP Client code (such as
	// cmd/go) from also linking in the HTTP server, don't
	// use a static type assertion to the server
	// "*response" type. Check this interface instead:
	//
	// NOTE: requestTooLarge is an unexported method name, and unexported
	// method names from different packages are always distinct in Go.
	// In this copy (outside net/http), the interface therefore cannot be
	// satisfied by net/http's *response. It can only match a type that
	// declares requestTooLarge in this package.
	type requestTooLarger interface {
		requestTooLarge()
	}
	if res, ok := l.w.(requestTooLarger); ok {
		res.requestTooLarge()
	}
	l.err = httpserver.MaxBytesExceededErr
	return n, l.err
}

// Close implements io.Closer by closing the wrapped body and returning its
// error. It does this regardless of whether the size limit was exceeded.
func (l *maxBytesReader) Close() error {
	underlying := l.r
	return underlying.Close()
}