diff options
Diffstat (limited to 'modules/caddytls/matchers.go')
-rw-r--r-- | modules/caddytls/matchers.go | 78 |
1 files changed, 78 insertions, 0 deletions
diff --git a/modules/caddytls/matchers.go b/modules/caddytls/matchers.go index af1f898bb..17bfe2e4c 100644 --- a/modules/caddytls/matchers.go +++ b/modules/caddytls/matchers.go @@ -30,6 +30,7 @@ import ( func init() { caddy.RegisterModule(MatchServerName{}) caddy.RegisterModule(MatchRemoteIP{}) + caddy.RegisterModule(MatchLocalIP{}) } // MatchServerName matches based on SNI. Names in @@ -144,8 +145,85 @@ func (MatchRemoteIP) matches(ip netip.Addr, ranges []netip.Prefix) bool { return false } +// MatchLocalIP matches based on the IP address of the interface +// receiving the connection. Specific IPs or CIDR ranges can be specified. +type MatchLocalIP struct { + // The IPs or CIDR ranges to match. + Ranges []string `json:"ranges,omitempty"` + + cidrs []netip.Prefix + logger *zap.Logger +} + +// CaddyModule returns the Caddy module information. +func (MatchLocalIP) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + ID: "tls.handshake_match.local_ip", + New: func() caddy.Module { return new(MatchLocalIP) }, + } +} + +// Provision parses m's IP ranges, either from IP or CIDR expressions. +func (m *MatchLocalIP) Provision(ctx caddy.Context) error { + m.logger = ctx.Logger() + for _, str := range m.Ranges { + cidrs, err := m.parseIPRange(str) + if err != nil { + return err + } + m.cidrs = append(m.cidrs, cidrs...) + } + return nil +} + +// Match matches hello based on the connection's remote IP. +func (m MatchLocalIP) Match(hello *tls.ClientHelloInfo) bool { + localAddr := hello.Conn.LocalAddr().String() + ipStr, _, err := net.SplitHostPort(localAddr) + if err != nil { + ipStr = localAddr // weird; maybe no port? + } + ipAddr, err := netip.ParseAddr(ipStr) + if err != nil { + m.logger.Error("invalid local IP addresss", zap.String("ip", ipStr)) + return false + } + return (len(m.cidrs) == 0 || m.matches(ipAddr, m.cidrs)) +} + +func (MatchLocalIP) parseIPRange(str string) ([]netip.Prefix, error) { + var cidrs []netip.Prefix + if strings.Contains(str, "/") { + ipNet, err := netip.ParsePrefix(str) + if err != nil { + return nil, fmt.Errorf("parsing CIDR expression: %v", err) + } + cidrs = append(cidrs, ipNet) + } else { + ipAddr, err := netip.ParseAddr(str) + if err != nil { + return nil, fmt.Errorf("invalid IP address: '%s': %v", str, err) + } + ip := netip.PrefixFrom(ipAddr, ipAddr.BitLen()) + cidrs = append(cidrs, ip) + } + return cidrs, nil +} + +func (MatchLocalIP) matches(ip netip.Addr, ranges []netip.Prefix) bool { + for _, ipRange := range ranges { + if ipRange.Contains(ip) { + return true + } + } + return false +} + // Interface guards var ( _ ConnectionMatcher = (*MatchServerName)(nil) _ ConnectionMatcher = (*MatchRemoteIP)(nil) + + _ caddy.Provisioner = (*MatchLocalIP)(nil) + _ ConnectionMatcher = (*MatchLocalIP)(nil) ) |