summaryrefslogtreecommitdiffhomepage
path: root/listeners.go
diff options
context:
space:
mode:
authorMohammed Al Sahaf <[email protected]>2019-11-12 01:33:38 +0300
committerMatt Holt <[email protected]>2019-11-11 15:33:38 -0700
commit93bc1b72e3cd566e6447ad7a1f832474aad5dfcc (patch)
tree05ddeb324261d7058925948baa0077752fd5e453 /listeners.go
parenta19da07b72d84432341990bcedce511fe2f980da (diff)
downloadcaddy-93bc1b72e3cd566e6447ad7a1f832474aad5dfcc.tar.gz
caddy-93bc1b72e3cd566e6447ad7a1f832474aad5dfcc.zip
core: Use port ranges to avoid OOM with bad inputs (#2859)
* fix OOM issue caught by fuzzing * use ParsedAddress as the struct name for the result of ParseNetworkAddress * simplify code using the ParsedAddress type * minor cleanups
Diffstat (limited to 'listeners.go')
-rw-r--r--listeners.go101
1 files changed, 75 insertions, 26 deletions
diff --git a/listeners.go b/listeners.go
index 4464b7873..37b4c299f 100644
--- a/listeners.go
+++ b/listeners.go
@@ -257,52 +257,94 @@ type globalListener struct {
pc net.PacketConn
}
-var (
- listeners = make(map[string]*globalListener)
- listenersMu sync.Mutex
-)
+// ParsedAddress contains the individual components
+// for a parsed network address of the form accepted
+// by ParseNetworkAddress(). Network should be a
+// network value accepted by Go's net package. Port
+// ranges are given by [StartPort, EndPort].
+type ParsedAddress struct {
+ Network string
+ Host string
+ StartPort uint
+ EndPort uint
+}
+
+// JoinHostPort is like net.JoinHostPort, but where the port
+// is StartPort + offset.
+func (l ParsedAddress) JoinHostPort(offset uint) string {
+ return net.JoinHostPort(l.Host, strconv.Itoa(int(l.StartPort+offset)))
+}
-// ParseNetworkAddress parses addr, a string of the form "network/host:port"
-// (with any part optional) into its component parts. Because a port can
-// also be a port range, there may be multiple addresses returned.
-func ParseNetworkAddress(addr string) (network string, addrs []string, err error) {
+// PortRangeSize returns how many ports are in
+// pa's port range. Port ranges are inclusive,
+// so the size is the difference of start and
+// end ports plus one.
+func (pa ParsedAddress) PortRangeSize() uint {
+ return (pa.EndPort - pa.StartPort) + 1
+}
+
+// String reconstructs the address string to the form expected
+// by ParseNetworkAddress().
+func (pa ParsedAddress) String() string {
+ port := strconv.FormatUint(uint64(pa.StartPort), 10)
+ if pa.StartPort != pa.EndPort {
+ port += "-" + strconv.FormatUint(uint64(pa.EndPort), 10)
+ }
+ return JoinNetworkAddress(pa.Network, pa.Host, port)
+}
+
+// ParseNetworkAddress parses addr into its individual
+// components. The input string is expected to be of
+// the form "network/host:port-range" where any part is
+// optional. The default network, if unspecified, is tcp.
+// Port ranges are inclusive.
+//
+// Network addresses are distinct from URLs and do not
+// use URL syntax.
+func ParseNetworkAddress(addr string) (ParsedAddress, error) {
var host, port string
- network, host, port, err = SplitNetworkAddress(addr)
+ network, host, port, err := SplitNetworkAddress(addr)
if network == "" {
network = "tcp"
}
if err != nil {
- return
+ return ParsedAddress{}, err
}
if network == "unix" || network == "unixgram" || network == "unixpacket" {
- addrs = []string{host}
- return
+ return ParsedAddress{
+ Network: network,
+ Host: host,
+ }, nil
}
ports := strings.SplitN(port, "-", 2)
if len(ports) == 1 {
ports = append(ports, ports[0])
}
- var start, end int
- start, err = strconv.Atoi(ports[0])
+ var start, end uint64
+ start, err = strconv.ParseUint(ports[0], 10, 16)
if err != nil {
- return
+ return ParsedAddress{}, fmt.Errorf("invalid start port: %v", err)
}
- end, err = strconv.Atoi(ports[1])
+ end, err = strconv.ParseUint(ports[1], 10, 16)
if err != nil {
- return
+ return ParsedAddress{}, fmt.Errorf("invalid end port: %v", err)
}
if end < start {
- err = fmt.Errorf("end port must be greater than start port")
- return
+ return ParsedAddress{}, fmt.Errorf("end port must not be less than start port")
}
- for p := start; p <= end; p++ {
- addrs = append(addrs, net.JoinHostPort(host, fmt.Sprintf("%d", p)))
+ if (end - start) > maxPortSpan {
+ return ParsedAddress{}, fmt.Errorf("port range exceeds %d ports", maxPortSpan)
}
- return
+ return ParsedAddress{
+ Network: network,
+ Host: host,
+ StartPort: uint(start),
+ EndPort: uint(end),
+ }, nil
}
// SplitNetworkAddress splits a into its network, host, and port components.
-// Note that port may be a port range, or omitted for unix sockets.
+// Note that port may be a port range (:X-Y), or omitted for unix sockets.
func SplitNetworkAddress(a string) (network, host, port string, err error) {
if idx := strings.Index(a, "/"); idx >= 0 {
network = strings.ToLower(strings.TrimSpace(a[:idx]))
@@ -317,9 +359,9 @@ func SplitNetworkAddress(a string) (network, host, port string, err error) {
}
// JoinNetworkAddress combines network, host, and port into a single
-// address string of the form "network/host:port". Port may be a
-// port range. For unix sockets, the network should be "unix" and
-// the path to the socket should be given in the host argument.
+// address string of the form accepted by ParseNetworkAddress(). For unix sockets, the network
+// should be "unix" and the path to the socket should be given as the
+// host parameter.
func JoinNetworkAddress(network, host, port string) string {
var a string
if network != "" {
@@ -332,3 +374,10 @@ func JoinNetworkAddress(network, host, port string) string {
}
return a
}
+
+var (
+ listeners = make(map[string]*globalListener)
+ listenersMu sync.Mutex
+)
+
+const maxPortSpan = 65535