diff options
author | Daniel GarcĂa <[email protected]> | 2024-03-17 14:00:17 +0100 |
---|---|---|
committer | BlackDex <[email protected]> | 2024-04-06 22:17:55 +0200 |
commit | 7bd2a2b74c5c384a9aee36ed27a45ad6621a84ca (patch) | |
tree | cc069d85533c9ec97de34e2dc34f9d8a00f37e61 /src/util.rs | |
parent | e1a8df96dbadfbf5ad36ce9aa2f31f34396166c2 (diff) | |
download | vaultwarden-icons_dns.tar.gz vaultwarden-icons_dns.zip |
Implement custom DNS resolvericons_dns
Diffstat (limited to 'src/util.rs')
-rw-r--r-- | src/util.rs | 256 |
1 files changed, 248 insertions, 8 deletions
diff --git a/src/util.rs b/src/util.rs index 8aae4bd1..a1965226 100644 --- a/src/util.rs +++ b/src/util.rs @@ -4,6 +4,7 @@ use std::{collections::HashMap, io::Cursor, ops::Deref, path::Path}; use num_traits::ToPrimitive; +use once_cell::sync::Lazy; use rocket::{ fairing::{Fairing, Info, Kind}, http::{ContentType, Header, HeaderMap, Method, Status}, @@ -701,14 +702,9 @@ where use reqwest::{header, Client, ClientBuilder}; -pub fn get_reqwest_client() -> Client { - match get_reqwest_client_builder().build() { - Ok(client) => client, - Err(e) => { - error!("Possible trust-dns error, trying with trust-dns disabled: '{e}'"); - get_reqwest_client_builder().hickory_dns(false).build().expect("Failed to build client") - } - } +pub fn get_reqwest_client() -> &'static Client { + static INSTANCE: Lazy<Client> = Lazy::new(|| get_reqwest_client_builder().build().expect("Failed to build client")); + &INSTANCE } pub fn get_reqwest_client_builder() -> ClientBuilder { @@ -767,3 +763,247 @@ pub fn parse_experimental_client_feature_flags(experimental_client_feature_flags feature_states } +mod dns_resolver { + use std::{ + fmt, + net::{IpAddr, SocketAddr}, + sync::Arc, + }; + + use hickory_resolver::{system_conf::read_system_conf, TokioAsyncResolver}; + use once_cell::sync::Lazy; + use reqwest::dns::{Name, Resolve, Resolving}; + + use crate::{util::is_global, CONFIG}; + + #[derive(Debug, Clone)] + pub enum CustomResolverError { + Blacklist { + domain: String, + }, + NonGlobalIp { + domain: String, + ip: IpAddr, + }, + } + + impl CustomResolverError { + pub fn downcast_ref(e: &dyn std::error::Error) -> Option<&Self> { + let mut source = e.source(); + + while let Some(err) = source { + source = err.source(); + if let Some(err) = err.downcast_ref::<CustomResolverError>() { + return Some(err); + } + } + None + } + } + + impl fmt::Display for CustomResolverError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Blacklist { + domain, + } => write!(f, "Blacklisted domain: {domain} matched ICON_BLACKLIST_REGEX"), + Self::NonGlobalIp { + domain, + ip, + } => write!(f, "IP {ip} for domain '{domain}' is not a global IP!"), + } + } + } + + impl std::error::Error for CustomResolverError {} + + #[derive(Debug, Clone)] + pub enum CustomDnsResolver { + Default(), + Hickory(Arc<TokioAsyncResolver>), + } + type BoxError = Box<dyn std::error::Error + Send + Sync>; + + impl CustomDnsResolver { + pub fn instance() -> Arc<Self> { + static INSTANCE: Lazy<Arc<CustomDnsResolver>> = Lazy::new(CustomDnsResolver::new); + Arc::clone(&*INSTANCE) + } + + fn new() -> Arc<Self> { + match read_system_conf() { + Ok((config, opts)) => { + let resolver = TokioAsyncResolver::tokio(config.clone(), opts.clone()); + Arc::new(Self::Hickory(Arc::new(resolver))) + } + Err(e) => { + warn!("Error creating Hickory resolver, falling back to default: {e:?}"); + Arc::new(Self::Default()) + } + } + } + + // Note that we get an iterator of addresses, but we only grab the first one for convenience + async fn resolve_domain(&self, name: &str) -> Result<Option<SocketAddr>, BoxError> { + pre_resolve(name)?; + + let result = match self { + Self::Default() => tokio::net::lookup_host(name).await?.next(), + Self::Hickory(r) => r.lookup_ip(name).await?.iter().next().map(|a| SocketAddr::new(a, 0)), + }; + + if let Some(addr) = &result { + post_resolve(name, addr.ip())?; + } + + Ok(result) + } + } + + fn pre_resolve(name: &str) -> Result<(), CustomResolverError> { + if crate::api::is_domain_blacklisted(name) { + return Err(CustomResolverError::Blacklist { + domain: name.to_string(), + }); + } + + Ok(()) + } + + fn post_resolve(name: &str, ip: IpAddr) -> Result<(), CustomResolverError> { + if CONFIG.icon_blacklist_non_global_ips() && !is_global(ip) { + Err(CustomResolverError::NonGlobalIp { + domain: name.to_string(), + ip, + }) + } else { + Ok(()) + } + } + + impl Resolve for CustomDnsResolver { + fn resolve(&self, name: Name) -> Resolving { + let this = self.clone(); + Box::pin(async move { + let name = name.as_str(); + let result = this.resolve_domain(name).await?; + Ok::<reqwest::dns::Addrs, _>(Box::new(result.into_iter())) + }) + } + } +} + +pub use dns_resolver::{CustomDnsResolver, CustomResolverError}; + +/// TODO: This is extracted from IpAddr::is_global, which is unstable: +/// https://doc.rust-lang.org/nightly/std/net/enum.IpAddr.html#method.is_global +/// Remove once https://github.com/rust-lang/rust/issues/27709 is merged +#[allow(clippy::nonminimal_bool)] +#[cfg(any(not(feature = "unstable"), test))] +pub fn is_global_hardcoded(ip: std::net::IpAddr) -> bool { + match ip { + std::net::IpAddr::V4(ip) => { + !(ip.octets()[0] == 0 // "This network" + || ip.is_private() + || (ip.octets()[0] == 100 && (ip.octets()[1] & 0b1100_0000 == 0b0100_0000)) //ip.is_shared() + || ip.is_loopback() + || ip.is_link_local() + // addresses reserved for future protocols (`192.0.0.0/24`) + ||(ip.octets()[0] == 192 && ip.octets()[1] == 0 && ip.octets()[2] == 0) + || ip.is_documentation() + || (ip.octets()[0] == 198 && (ip.octets()[1] & 0xfe) == 18) // ip.is_benchmarking() + || (ip.octets()[0] & 240 == 240 && !ip.is_broadcast()) //ip.is_reserved() + || ip.is_broadcast()) + } + std::net::IpAddr::V6(ip) => { + !(ip.is_unspecified() + || ip.is_loopback() + // IPv4-mapped Address (`::ffff:0:0/96`) + || matches!(ip.segments(), [0, 0, 0, 0, 0, 0xffff, _, _]) + // IPv4-IPv6 Translat. (`64:ff9b:1::/48`) + || matches!(ip.segments(), [0x64, 0xff9b, 1, _, _, _, _, _]) + // Discard-Only Address Block (`100::/64`) + || matches!(ip.segments(), [0x100, 0, 0, 0, _, _, _, _]) + // IETF Protocol Assignments (`2001::/23`) + || (matches!(ip.segments(), [0x2001, b, _, _, _, _, _, _] if b < 0x200) + && !( + // Port Control Protocol Anycast (`2001:1::1`) + u128::from_be_bytes(ip.octets()) == 0x2001_0001_0000_0000_0000_0000_0000_0001 + // Traversal Using Relays around NAT Anycast (`2001:1::2`) + || u128::from_be_bytes(ip.octets()) == 0x2001_0001_0000_0000_0000_0000_0000_0002 + // AMT (`2001:3::/32`) + || matches!(ip.segments(), [0x2001, 3, _, _, _, _, _, _]) + // AS112-v6 (`2001:4:112::/48`) + || matches!(ip.segments(), [0x2001, 4, 0x112, _, _, _, _, _]) + // ORCHIDv2 (`2001:20::/28`) + || matches!(ip.segments(), [0x2001, b, _, _, _, _, _, _] if (0x20..=0x2F).contains(&b)) + )) + || ((ip.segments()[0] == 0x2001) && (ip.segments()[1] == 0xdb8)) // ip.is_documentation() + || ((ip.segments()[0] & 0xfe00) == 0xfc00) //ip.is_unique_local() + || ((ip.segments()[0] & 0xffc0) == 0xfe80)) //ip.is_unicast_link_local() + } + } +} + +#[cfg(not(feature = "unstable"))] +pub use is_global_hardcoded as is_global; + +#[cfg(feature = "unstable")] +#[inline(always)] +pub fn is_global(ip: std::net::IpAddr) -> bool { + ip.is_global() +} + +/// These are some tests to check that the implementations match +/// The IPv4 can be all checked in 30 seconds or so and they are correct as of nightly 2023-07-17 +/// The IPV6 can't be checked in a reasonable time, so we check over a hundred billion random ones, so far correct +/// Note that the is_global implementation is subject to change as new IP RFCs are created +/// +/// To run while showing progress output: +/// cargo +nightly test --release --features sqlite,unstable -- --nocapture --ignored +#[cfg(test)] +#[cfg(feature = "unstable")] +mod tests { + use super::*; + use std::net::IpAddr; + + #[test] + #[ignore] + fn test_ipv4_global() { + for a in 0..u8::MAX { + println!("Iter: {}/255", a); + for b in 0..u8::MAX { + for c in 0..u8::MAX { + for d in 0..u8::MAX { + let ip = IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d)); + assert_eq!(ip.is_global(), is_global_hardcoded(ip), "IP mismatch: {}", ip) + } + } + } + } + } + + #[test] + #[ignore] + fn test_ipv6_global() { + use rand::Rng; + + std::thread::scope(|s| { + for t in 0..16 { + let handle = s.spawn(move || { + let mut v = [0u8; 16]; + let mut rng = rand::thread_rng(); + + for i in 0..20 { + println!("Thread {t} Iter: {i}/50"); + for _ in 0..500_000_000 { + rng.fill(&mut v); + let ip = IpAddr::V6(std::net::Ipv6Addr::from(v)); + assert_eq!(ip.is_global(), is_global_hardcoded(ip), "IP mismatch: {ip}"); + } + } + }); + } + }); + } +} |