aboutsummaryrefslogtreecommitdiff
path: root/src/util.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/util.rs')
-rw-r--r--src/util.rs256
1 files changed, 248 insertions, 8 deletions
diff --git a/src/util.rs b/src/util.rs
index 8aae4bd1..a1965226 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -4,6 +4,7 @@
use std::{collections::HashMap, io::Cursor, ops::Deref, path::Path};
use num_traits::ToPrimitive;
+use once_cell::sync::Lazy;
use rocket::{
fairing::{Fairing, Info, Kind},
http::{ContentType, Header, HeaderMap, Method, Status},
@@ -701,14 +702,9 @@ where
use reqwest::{header, Client, ClientBuilder};
-pub fn get_reqwest_client() -> Client {
- match get_reqwest_client_builder().build() {
- Ok(client) => client,
- Err(e) => {
- error!("Possible trust-dns error, trying with trust-dns disabled: '{e}'");
- get_reqwest_client_builder().hickory_dns(false).build().expect("Failed to build client")
- }
- }
+pub fn get_reqwest_client() -> &'static Client {
+ static INSTANCE: Lazy<Client> = Lazy::new(|| get_reqwest_client_builder().build().expect("Failed to build client"));
+ &INSTANCE
}
pub fn get_reqwest_client_builder() -> ClientBuilder {
@@ -767,3 +763,247 @@ pub fn parse_experimental_client_feature_flags(experimental_client_feature_flags
feature_states
}
+mod dns_resolver {
+ use std::{
+ fmt,
+ net::{IpAddr, SocketAddr},
+ sync::Arc,
+ };
+
+ use hickory_resolver::{system_conf::read_system_conf, TokioAsyncResolver};
+ use once_cell::sync::Lazy;
+ use reqwest::dns::{Name, Resolve, Resolving};
+
+ use crate::{util::is_global, CONFIG};
+
+ #[derive(Debug, Clone)]
+ pub enum CustomResolverError {
+ Blacklist {
+ domain: String,
+ },
+ NonGlobalIp {
+ domain: String,
+ ip: IpAddr,
+ },
+ }
+
+ impl CustomResolverError {
+ pub fn downcast_ref(e: &dyn std::error::Error) -> Option<&Self> {
+ let mut source = e.source();
+
+ while let Some(err) = source {
+ source = err.source();
+ if let Some(err) = err.downcast_ref::<CustomResolverError>() {
+ return Some(err);
+ }
+ }
+ None
+ }
+ }
+
+ impl fmt::Display for CustomResolverError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ Self::Blacklist {
+ domain,
+ } => write!(f, "Blacklisted domain: {domain} matched ICON_BLACKLIST_REGEX"),
+ Self::NonGlobalIp {
+ domain,
+ ip,
+ } => write!(f, "IP {ip} for domain '{domain}' is not a global IP!"),
+ }
+ }
+ }
+
+ impl std::error::Error for CustomResolverError {}
+
+ #[derive(Debug, Clone)]
+ pub enum CustomDnsResolver {
+ Default(),
+ Hickory(Arc<TokioAsyncResolver>),
+ }
+ type BoxError = Box<dyn std::error::Error + Send + Sync>;
+
+ impl CustomDnsResolver {
+ pub fn instance() -> Arc<Self> {
+ static INSTANCE: Lazy<Arc<CustomDnsResolver>> = Lazy::new(CustomDnsResolver::new);
+ Arc::clone(&*INSTANCE)
+ }
+
+ fn new() -> Arc<Self> {
+ match read_system_conf() {
+ Ok((config, opts)) => {
+ let resolver = TokioAsyncResolver::tokio(config.clone(), opts.clone());
+ Arc::new(Self::Hickory(Arc::new(resolver)))
+ }
+ Err(e) => {
+ warn!("Error creating Hickory resolver, falling back to default: {e:?}");
+ Arc::new(Self::Default())
+ }
+ }
+ }
+
+ // Note that we get an iterator of addresses, but we only grab the first one for convenience
+ async fn resolve_domain(&self, name: &str) -> Result<Option<SocketAddr>, BoxError> {
+ pre_resolve(name)?;
+
+ let result = match self {
+ Self::Default() => tokio::net::lookup_host(name).await?.next(),
+ Self::Hickory(r) => r.lookup_ip(name).await?.iter().next().map(|a| SocketAddr::new(a, 0)),
+ };
+
+ if let Some(addr) = &result {
+ post_resolve(name, addr.ip())?;
+ }
+
+ Ok(result)
+ }
+ }
+
+ fn pre_resolve(name: &str) -> Result<(), CustomResolverError> {
+ if crate::api::is_domain_blacklisted(name) {
+ return Err(CustomResolverError::Blacklist {
+ domain: name.to_string(),
+ });
+ }
+
+ Ok(())
+ }
+
+ fn post_resolve(name: &str, ip: IpAddr) -> Result<(), CustomResolverError> {
+ if CONFIG.icon_blacklist_non_global_ips() && !is_global(ip) {
+ Err(CustomResolverError::NonGlobalIp {
+ domain: name.to_string(),
+ ip,
+ })
+ } else {
+ Ok(())
+ }
+ }
+
+ impl Resolve for CustomDnsResolver {
+ fn resolve(&self, name: Name) -> Resolving {
+ let this = self.clone();
+ Box::pin(async move {
+ let name = name.as_str();
+ let result = this.resolve_domain(name).await?;
+ Ok::<reqwest::dns::Addrs, _>(Box::new(result.into_iter()))
+ })
+ }
+ }
+}
+
+pub use dns_resolver::{CustomDnsResolver, CustomResolverError};
+
+/// TODO: This is extracted from IpAddr::is_global, which is unstable:
+/// https://doc.rust-lang.org/nightly/std/net/enum.IpAddr.html#method.is_global
+/// Remove once https://github.com/rust-lang/rust/issues/27709 is merged
+#[allow(clippy::nonminimal_bool)]
+#[cfg(any(not(feature = "unstable"), test))]
+pub fn is_global_hardcoded(ip: std::net::IpAddr) -> bool {
+ match ip {
+ std::net::IpAddr::V4(ip) => {
+ !(ip.octets()[0] == 0 // "This network"
+ || ip.is_private()
+ || (ip.octets()[0] == 100 && (ip.octets()[1] & 0b1100_0000 == 0b0100_0000)) //ip.is_shared()
+ || ip.is_loopback()
+ || ip.is_link_local()
+ // addresses reserved for future protocols (`192.0.0.0/24`)
+ ||(ip.octets()[0] == 192 && ip.octets()[1] == 0 && ip.octets()[2] == 0)
+ || ip.is_documentation()
+ || (ip.octets()[0] == 198 && (ip.octets()[1] & 0xfe) == 18) // ip.is_benchmarking()
+ || (ip.octets()[0] & 240 == 240 && !ip.is_broadcast()) //ip.is_reserved()
+ || ip.is_broadcast())
+ }
+ std::net::IpAddr::V6(ip) => {
+ !(ip.is_unspecified()
+ || ip.is_loopback()
+ // IPv4-mapped Address (`::ffff:0:0/96`)
+ || matches!(ip.segments(), [0, 0, 0, 0, 0, 0xffff, _, _])
+ // IPv4-IPv6 Translat. (`64:ff9b:1::/48`)
+ || matches!(ip.segments(), [0x64, 0xff9b, 1, _, _, _, _, _])
+ // Discard-Only Address Block (`100::/64`)
+ || matches!(ip.segments(), [0x100, 0, 0, 0, _, _, _, _])
+ // IETF Protocol Assignments (`2001::/23`)
+ || (matches!(ip.segments(), [0x2001, b, _, _, _, _, _, _] if b < 0x200)
+ && !(
+ // Port Control Protocol Anycast (`2001:1::1`)
+ u128::from_be_bytes(ip.octets()) == 0x2001_0001_0000_0000_0000_0000_0000_0001
+ // Traversal Using Relays around NAT Anycast (`2001:1::2`)
+ || u128::from_be_bytes(ip.octets()) == 0x2001_0001_0000_0000_0000_0000_0000_0002
+ // AMT (`2001:3::/32`)
+ || matches!(ip.segments(), [0x2001, 3, _, _, _, _, _, _])
+ // AS112-v6 (`2001:4:112::/48`)
+ || matches!(ip.segments(), [0x2001, 4, 0x112, _, _, _, _, _])
+ // ORCHIDv2 (`2001:20::/28`)
+ || matches!(ip.segments(), [0x2001, b, _, _, _, _, _, _] if (0x20..=0x2F).contains(&b))
+ ))
+ || ((ip.segments()[0] == 0x2001) && (ip.segments()[1] == 0xdb8)) // ip.is_documentation()
+ || ((ip.segments()[0] & 0xfe00) == 0xfc00) //ip.is_unique_local()
+ || ((ip.segments()[0] & 0xffc0) == 0xfe80)) //ip.is_unicast_link_local()
+ }
+ }
+}
+
+#[cfg(not(feature = "unstable"))]
+pub use is_global_hardcoded as is_global;
+
+#[cfg(feature = "unstable")]
+#[inline(always)]
+pub fn is_global(ip: std::net::IpAddr) -> bool {
+ ip.is_global()
+}
+
+/// These are some tests to check that the implementations match
+/// The IPv4 can be all checked in 30 seconds or so and they are correct as of nightly 2023-07-17
+/// The IPV6 can't be checked in a reasonable time, so we check over a hundred billion random ones, so far correct
+/// Note that the is_global implementation is subject to change as new IP RFCs are created
+///
+/// To run while showing progress output:
+/// cargo +nightly test --release --features sqlite,unstable -- --nocapture --ignored
+#[cfg(test)]
+#[cfg(feature = "unstable")]
+mod tests {
+ use super::*;
+ use std::net::IpAddr;
+
+ #[test]
+ #[ignore]
+ fn test_ipv4_global() {
+ for a in 0..u8::MAX {
+ println!("Iter: {}/255", a);
+ for b in 0..u8::MAX {
+ for c in 0..u8::MAX {
+ for d in 0..u8::MAX {
+ let ip = IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d));
+ assert_eq!(ip.is_global(), is_global_hardcoded(ip), "IP mismatch: {}", ip)
+ }
+ }
+ }
+ }
+ }
+
+ #[test]
+ #[ignore]
+ fn test_ipv6_global() {
+ use rand::Rng;
+
+ std::thread::scope(|s| {
+ for t in 0..16 {
+ let handle = s.spawn(move || {
+ let mut v = [0u8; 16];
+ let mut rng = rand::thread_rng();
+
+ for i in 0..20 {
+ println!("Thread {t} Iter: {i}/50");
+ for _ in 0..500_000_000 {
+ rng.fill(&mut v);
+ let ip = IpAddr::V6(std::net::Ipv6Addr::from(v));
+ assert_eq!(ip.is_global(), is_global_hardcoded(ip), "IP mismatch: {ip}");
+ }
+ }
+ });
+ }
+ });
+ }
+}