diff options
author | Daniel García <[email protected]> | 2024-07-12 17:51:19 +0200 |
---|---|---|
committer | Daniel García <[email protected]> | 2024-07-12 19:01:23 +0200 |
commit | bbbbd2fb482684149b573aa1acbc98f7fb1408b9 (patch) | |
tree | e37ab37d386a76d6b3b5dc6429b8003d34c8c231 /src | |
parent | a4ab014ade53e4e60bda0b9cbce3af9de7eac753 (diff) | |
download | vaultwarden-bbbbd2fb482684149b573aa1acbc98f7fb1408b9.tar.gz vaultwarden-bbbbd2fb482684149b573aa1acbc98f7fb1408b9.zip |
Improved HTTP client
Diffstat (limited to 'src')
-rw-r--r-- | src/api/admin.rs | 17 | ||||
-rw-r--r-- | src/api/core/mod.rs | 8 | ||||
-rw-r--r-- | src/api/core/two_factor/duo.rs | 7 | ||||
-rw-r--r-- | src/api/icons.rs | 46 | ||||
-rw-r--r-- | src/api/mod.rs | 2 | ||||
-rw-r--r-- | src/api/push.rs | 27 | ||||
-rw-r--r-- | src/config.rs | 25 | ||||
-rw-r--r-- | src/error.rs | 5 | ||||
-rw-r--r-- | src/http_client.rs | 247 | ||||
-rw-r--r-- | src/main.rs | 1 | ||||
-rw-r--r-- | src/util.rs | 146 |
11 files changed, 321 insertions, 210 deletions
diff --git a/src/api/admin.rs b/src/api/admin.rs index 58a056b6..1ea9aa59 100644 --- a/src/api/admin.rs +++ b/src/api/admin.rs @@ -1,4 +1,5 @@ use once_cell::sync::Lazy; +use reqwest::Method; use serde::de::DeserializeOwned; use serde_json::Value; use std::env; @@ -21,10 +22,10 @@ use crate::{ config::ConfigBuilder, db::{backup_database, get_sql_server_version, models::*, DbConn, DbConnType}, error::{Error, MapResult}, + http_client::make_http_request, mail, util::{ - container_base_image, format_naive_datetime_local, get_display_size, get_reqwest_client, - is_running_in_container, NumberOrString, + container_base_image, format_naive_datetime_local, get_display_size, is_running_in_container, NumberOrString, }, CONFIG, VERSION, }; @@ -594,15 +595,15 @@ struct TimeApi { } async fn get_json_api<T: DeserializeOwned>(url: &str) -> Result<T, Error> { - let json_api = get_reqwest_client(); - - Ok(json_api.get(url).send().await?.error_for_status()?.json::<T>().await?) + Ok(make_http_request(Method::GET, url)?.send().await?.error_for_status()?.json::<T>().await?) } async fn has_http_access() -> bool { - let http_access = get_reqwest_client(); - - match http_access.head("https://github.com/dani-garcia/vaultwarden").send().await { + let req = match make_http_request(Method::HEAD, "https://github.com/dani-garcia/vaultwarden") { + Ok(r) => r, + Err(_) => return false, + }; + match req.send().await { Ok(r) => r.status().is_success(), _ => false, } diff --git a/src/api/core/mod.rs b/src/api/core/mod.rs index 9da0e886..41bd4d6b 100644 --- a/src/api/core/mod.rs +++ b/src/api/core/mod.rs @@ -12,6 +12,7 @@ pub use accounts::purge_auth_requests; pub use ciphers::{purge_trashed_ciphers, CipherData, CipherSyncData, CipherSyncType}; pub use emergency_access::{emergency_notification_reminder_job, emergency_request_timeout_job}; pub use events::{event_cleanup_job, log_event, log_user_event}; +use reqwest::Method; pub use sends::purge_sends; pub fn routes() -> Vec<Route> { @@ -53,7 +54,8 @@ use crate::{ auth::Headers, db::DbConn, error::Error, - util::{get_reqwest_client, parse_experimental_client_feature_flags}, + http_client::make_http_request, + util::parse_experimental_client_feature_flags, }; #[derive(Debug, Serialize, Deserialize)] @@ -139,9 +141,7 @@ async fn hibp_breach(username: &str) -> JsonResult { ); if let Some(api_key) = crate::CONFIG.hibp_api_key() { - let hibp_client = get_reqwest_client(); - - let res = hibp_client.get(&url).header("hibp-api-key", api_key).send().await?; + let res = make_http_request(Method::GET, &url)?.header("hibp-api-key", api_key).send().await?; // If we get a 404, return a 404, it means no breached accounts if res.status() == 404 { diff --git a/src/api/core/two_factor/duo.rs b/src/api/core/two_factor/duo.rs index c5bfa9e5..8554999c 100644 --- a/src/api/core/two_factor/duo.rs +++ b/src/api/core/two_factor/duo.rs @@ -15,7 +15,7 @@ use crate::{ DbConn, }, error::MapResult, - util::get_reqwest_client, + http_client::make_http_request, CONFIG, }; @@ -210,10 +210,7 @@ async fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) let m = Method::from_str(method).unwrap_or_default(); - let client = get_reqwest_client(); - - client - .request(m, &url) + make_http_request(m, &url)? .basic_auth(username, Some(password)) .header(header::USER_AGENT, "vaultwarden:Duo/1.0 (Rust)") .header(header::DATE, date) diff --git a/src/api/icons.rs b/src/api/icons.rs index 94fab3f8..f8fe059b 100644 --- a/src/api/icons.rs +++ b/src/api/icons.rs @@ -1,6 +1,6 @@ use std::{ net::IpAddr, - sync::{Arc, Mutex}, + sync::Arc, time::{Duration, SystemTime}, }; @@ -22,7 +22,8 @@ use html5gum::{Emitter, HtmlString, InfallibleTokenizer, Readable, StringReader, use crate::{ error::Error, - util::{get_reqwest_client_builder, Cached, CustomDnsResolver, CustomResolverError}, + http_client::{get_reqwest_client_builder, should_block_address, CustomHttpClientError}, + util::Cached, CONFIG, }; @@ -53,7 +54,6 @@ static CLIENT: Lazy<Client> = Lazy::new(|| { .timeout(icon_download_timeout) .pool_max_idle_per_host(5) // Configure the Hyper Pool to only have max 5 idle connections .pool_idle_timeout(pool_idle_timeout) // Configure the Hyper Pool to timeout after 10 seconds - .dns_resolver(CustomDnsResolver::instance()) .default_headers(default_headers.clone()) .build() .expect("Failed to build client") @@ -69,7 +69,8 @@ fn icon_external(domain: &str) -> Option<Redirect> { return None; } - if is_domain_blacklisted(domain) { + if should_block_address(domain) { + warn!("Blocked address: {}", domain); return None; } @@ -99,6 +100,15 @@ async fn icon_internal(domain: &str) -> Cached<(ContentType, Vec<u8>)> { ); } + if should_block_address(domain) { + warn!("Blocked address: {}", domain); + return Cached::ttl( + (ContentType::new("image", "png"), FALLBACK_ICON.to_vec()), + CONFIG.icon_cache_negttl(), + true, + ); + } + match get_icon(domain).await { Some((icon, icon_type)) => { Cached::ttl((ContentType::new("image", icon_type), icon), CONFIG.icon_cache_ttl(), true) @@ -144,30 +154,6 @@ fn is_valid_domain(domain: &str) -> bool { true } -pub fn is_domain_blacklisted(domain: &str) -> bool { - let Some(config_blacklist) = CONFIG.icon_blacklist_regex() else { - return false; - }; - - // Compiled domain blacklist - static COMPILED_BLACKLIST: Mutex<Option<(String, Regex)>> = Mutex::new(None); - let mut guard = COMPILED_BLACKLIST.lock().unwrap(); - - // If the stored regex is up to date, use it - if let Some((value, regex)) = &*guard { - if value == &config_blacklist { - return regex.is_match(domain); - } - } - - // If we don't have a regex stored, or it's not up to date, recreate it - let regex = Regex::new(&config_blacklist).unwrap(); - let is_match = regex.is_match(domain); - *guard = Some((config_blacklist, regex)); - - is_match -} - async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> { let path = format!("{}/{}.png", CONFIG.icon_cache_folder(), domain); @@ -197,7 +183,7 @@ async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> { Err(e) => { // If this error comes from the custom resolver, this means this is a blacklisted domain // or non global IP, don't save the miss file in this case to avoid leaking it - if let Some(error) = CustomResolverError::downcast_ref(&e) { + if let Some(error) = CustomHttpClientError::downcast_ref(&e) { warn!("{error}"); return None; } @@ -353,7 +339,7 @@ async fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> { // First check the domain as given during the request for HTTPS. let resp = match get_page(&ssldomain).await { - Err(e) if CustomResolverError::downcast_ref(&e).is_none() => { + Err(e) if CustomHttpClientError::downcast_ref(&e).is_none() => { // If we get an error that is not caused by the blacklist, we retry with HTTP match get_page(&httpdomain).await { mut sub_resp @ Err(_) => { diff --git a/src/api/mod.rs b/src/api/mod.rs index d5281bda..27a3775f 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -20,7 +20,7 @@ pub use crate::api::{ core::two_factor::send_incomplete_2fa_notifications, core::{emergency_notification_reminder_job, emergency_request_timeout_job}, core::{event_cleanup_job, events_routes as core_events_routes}, - icons::{is_domain_blacklisted, routes as icons_routes}, + icons::routes as icons_routes, identity::routes as identity_routes, notifications::routes as notifications_routes, notifications::{AnonymousNotify, Notify, UpdateType, WS_ANONYMOUS_SUBSCRIPTIONS, WS_USERS}, diff --git a/src/api/push.rs b/src/api/push.rs index 607fb7ea..eaf304f9 100644 --- a/src/api/push.rs +++ b/src/api/push.rs @@ -1,11 +1,14 @@ -use reqwest::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE}; +use reqwest::{ + header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE}, + Method, +}; use serde_json::Value; use tokio::sync::RwLock; use crate::{ api::{ApiResult, EmptyResult, UpdateType}, db::models::{Cipher, Device, Folder, Send, User}, - util::get_reqwest_client, + http_client::make_http_request, CONFIG, }; @@ -50,8 +53,7 @@ async fn get_auth_push_token() -> ApiResult<String> { ("client_secret", &client_secret), ]; - let res = match get_reqwest_client() - .post(&format!("{}/connect/token", CONFIG.push_identity_uri())) + let res = match make_http_request(Method::POST, &format!("{}/connect/token", CONFIG.push_identity_uri()))? .form(¶ms) .send() .await @@ -104,8 +106,7 @@ pub async fn register_push_device(device: &mut Device, conn: &mut crate::db::DbC let auth_push_token = get_auth_push_token().await?; let auth_header = format!("Bearer {}", &auth_push_token); - if let Err(e) = get_reqwest_client() - .post(CONFIG.push_relay_uri() + "/push/register") + if let Err(e) = make_http_request(Method::POST, &(CONFIG.push_relay_uri() + "/push/register"))? .header(CONTENT_TYPE, "application/json") .header(ACCEPT, "application/json") .header(AUTHORIZATION, auth_header) @@ -132,8 +133,7 @@ pub async fn unregister_push_device(push_uuid: Option<String>) -> EmptyResult { let auth_header = format!("Bearer {}", &auth_push_token); - match get_reqwest_client() - .delete(CONFIG.push_relay_uri() + "/push/" + &push_uuid.unwrap()) + match make_http_request(Method::DELETE, &(CONFIG.push_relay_uri() + "/push/" + &push_uuid.unwrap()))? .header(AUTHORIZATION, auth_header) .send() .await @@ -266,8 +266,15 @@ async fn send_to_push_relay(notification_data: Value) { let auth_header = format!("Bearer {}", &auth_push_token); - if let Err(e) = get_reqwest_client() - .post(CONFIG.push_relay_uri() + "/push/send") + let req = match make_http_request(Method::POST, &(CONFIG.push_relay_uri() + "/push/send")) { + Ok(r) => r, + Err(e) => { + error!("An error occurred while sending a send update to the push relay: {}", e); + return; + } + }; + + if let Err(e) = req .header(ACCEPT, "application/json") .header(CONTENT_TYPE, "application/json") .header(AUTHORIZATION, &auth_header) diff --git a/src/config.rs b/src/config.rs index 489a229d..7ce2a255 100644 --- a/src/config.rs +++ b/src/config.rs @@ -146,6 +146,13 @@ macro_rules! make_config { config.signups_domains_whitelist = config.signups_domains_whitelist.trim().to_lowercase(); config.org_creation_users = config.org_creation_users.trim().to_lowercase(); + + // Copy the values from the deprecated flags to the new ones + if config.http_request_blacklist_regex.is_none() { + config.http_request_blacklist_non_global_ips = config.icon_blacklist_non_global_ips; + config.http_request_blacklist_regex = config.icon_blacklist_regex.clone(); + } + config } } @@ -531,12 +538,18 @@ make_config! { icon_cache_negttl: u64, true, def, 259_200; /// Icon download timeout |> Number of seconds when to stop attempting to download an icon. icon_download_timeout: u64, true, def, 10; - /// Icon blacklist Regex |> Any domains or IPs that match this regex won't be fetched by the icon service. + + /// [Deprecated] Icon blacklist Regex |> Use `icon_blacklist_regex` instead + icon_blacklist_regex: String, false, option; + /// [Deprecated] Icon blacklist non global IPs |> Use `http_request_blacklist_non_global_ips` instead + icon_blacklist_non_global_ips: bool, false, def, true; + + /// HTTP blacklist Regex |> Any domains or IPs that match this regex won't be fetched by the internal HTTP client. /// Useful to hide other servers in the local network. Check the WIKI for more details - icon_blacklist_regex: String, true, option; - /// Icon blacklist non global IPs |> Any IP which is not defined as a global IP will be blacklisted. + http_request_blacklist_regex: String, true, option; + /// Blacklist non global IPs |> Enabling this will cause the internal HTTP client to refuse to connect to any non global IP address. /// Useful to secure your internal environment: See https://en.wikipedia.org/wiki/Reserved_IP_addresses for a list of IPs which it will block - icon_blacklist_non_global_ips: bool, true, def, true; + http_request_blacklist_non_global_ips: bool, true, def, true; /// Disable Two-Factor remember |> Enabling this would force the users to use a second factor to login every time. /// Note that the checkbox would still be present, but ignored. @@ -900,11 +913,11 @@ fn validate_config(cfg: &ConfigItems) -> Result<(), Error> { } // Check if the icon blacklist regex is valid - if let Some(ref r) = cfg.icon_blacklist_regex { + if let Some(ref r) = cfg.http_request_blacklist_regex { let validate_regex = regex::Regex::new(r); match validate_regex { Ok(_) => (), - Err(e) => err!(format!("`ICON_BLACKLIST_REGEX` is invalid: {e:#?}")), + Err(e) => err!(format!("`HTTP_REQUEST_BLACKLIST_REGEX` is invalid: {e:#?}")), } } diff --git a/src/error.rs b/src/error.rs index afb1dc83..51d60982 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,6 +2,7 @@ // Error generator macro // use crate::db::models::EventType; +use crate::http_client::CustomHttpClientError; use std::error::Error as StdError; macro_rules! make_error { @@ -68,6 +69,10 @@ make_error! { Empty(Empty): _no_source, _serialize, // Used to represent err! calls Simple(String): _no_source, _api_error, + + // Used in our custom http client to handle non-global IPs and blacklisted domains + CustomHttpClient(CustomHttpClientError): _has_source, _api_error, + // Used for special return values, like 2FA errors Json(Value): _no_source, _serialize, Db(DieselErr): _has_source, _api_error, diff --git a/src/http_client.rs b/src/http_client.rs new file mode 100644 index 00000000..46f28d77 --- /dev/null +++ b/src/http_client.rs @@ -0,0 +1,247 @@ +use std::{ + fmt, + net::{IpAddr, SocketAddr}, + str::FromStr, + sync::{Arc, Mutex}, + time::Duration, +}; + +use hickory_resolver::{system_conf::read_system_conf, TokioAsyncResolver}; +use once_cell::sync::Lazy; +use regex::Regex; +use reqwest::{ + dns::{Name, Resolve, Resolving}, + header, Client, ClientBuilder, +}; +use url::Host; + +use crate::{util::is_global, CONFIG}; + +pub fn make_http_request(method: reqwest::Method, url: &str) -> Result<reqwest::RequestBuilder, crate::Error> { + let Ok(url) = url::Url::parse(url) else { + err!("Invalid URL"); + }; + let Some(host) = url.host() else { + err!("Invalid host"); + }; + + should_block_host(host)?; + + static INSTANCE: Lazy<Client> = Lazy::new(|| get_reqwest_client_builder().build().expect("Failed to build client")); + + Ok(INSTANCE.request(method, url)) +} + +pub fn get_reqwest_client_builder() -> ClientBuilder { + let mut headers = header::HeaderMap::new(); + headers.insert(header::USER_AGENT, header::HeaderValue::from_static("Vaultwarden")); + + let redirect_policy = reqwest::redirect::Policy::custom(|attempt| { + if attempt.previous().len() >= 5 { + return attempt.error("Too many redirects"); + } + + let Some(host) = attempt.url().host() else { + return attempt.error("Invalid host"); + }; + + if let Err(e) = should_block_host(host) { + return attempt.error(e); + } + + attempt.follow() + }); + + Client::builder() + .default_headers(headers) + .redirect(redirect_policy) + .dns_resolver(CustomDnsResolver::instance()) + .timeout(Duration::from_secs(10)) +} + +pub fn should_block_address(domain_or_ip: &str) -> bool { + if let Ok(ip) = IpAddr::from_str(domain_or_ip) { + if should_block_ip(ip) { + return true; + } + } + + should_block_address_blacklist(domain_or_ip) +} + +fn should_block_ip(ip: IpAddr) -> bool { + if !CONFIG.http_request_blacklist_non_global_ips() { + return false; + } + + !is_global(ip) +} + +fn should_block_address_blacklist(domain_or_ip: &str) -> bool { + let Some(config_blacklist) = CONFIG.http_request_blacklist_regex() else { + return false; + }; + + // Compiled domain blacklist + static COMPILED_BLACKLIST: Mutex<Option<(String, Regex)>> = Mutex::new(None); + let mut guard = COMPILED_BLACKLIST.lock().unwrap(); + + // If the stored regex is up to date, use it + if let Some((value, regex)) = &*guard { + if value == &config_blacklist { + return regex.is_match(domain_or_ip); + } + } + + // If we don't have a regex stored, or it's not up to date, recreate it + let regex = Regex::new(&config_blacklist).unwrap(); + let is_match = regex.is_match(domain_or_ip); + *guard = Some((config_blacklist, regex)); + + is_match +} + +fn should_block_host(host: Host<&str>) -> Result<(), CustomHttpClientError> { + let (ip, host_str): (Option<IpAddr>, String) = match host { + url::Host::Ipv4(ip) => (Some(ip.into()), ip.to_string()), + url::Host::Ipv6(ip) => (Some(ip.into()), ip.to_string()), + url::Host::Domain(d) => (None, d.to_string()), + }; + + if let Some(ip) = ip { + if should_block_ip(ip) { + return Err(CustomHttpClientError::NonGlobalIp { + domain: None, + ip, + }); + } + } + + if should_block_address_blacklist(&host_str) { + return Err(CustomHttpClientError::Blacklist { + domain: host_str, + }); + } + + Ok(()) +} + +#[derive(Debug, Clone)] +pub enum CustomHttpClientError { + Blacklist { + domain: String, + }, + NonGlobalIp { + domain: Option<String>, + ip: IpAddr, + }, +} + +impl CustomHttpClientError { + pub fn downcast_ref(e: &dyn std::error::Error) -> Option<&Self> { + let mut source = e.source(); + + while let Some(err) = source { + source = err.source(); + if let Some(err) = err.downcast_ref::<CustomHttpClientError>() { + return Some(err); + } + } + None + } +} + +impl fmt::Display for CustomHttpClientError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Blacklist { + domain, + } => write!(f, "Blacklisted domain: {domain} matched HTTP_REQUEST_BLACKLIST_REGEX"), + Self::NonGlobalIp { + domain: Some(domain), + ip, + } => write!(f, "IP {ip} for domain '{domain}' is not a global IP!"), + Self::NonGlobalIp { + domain: None, + ip, + } => write!(f, "IP {ip} is not a global IP!"), + } + } +} + +impl std::error::Error for CustomHttpClientError {} + +#[derive(Debug, Clone)] +enum CustomDnsResolver { + Default(), + Hickory(Arc<TokioAsyncResolver>), +} +type BoxError = Box<dyn std::error::Error + Send + Sync>; + +impl CustomDnsResolver { + fn instance() -> Arc<Self> { + static INSTANCE: Lazy<Arc<CustomDnsResolver>> = Lazy::new(CustomDnsResolver::new); + Arc::clone(&*INSTANCE) + } + + fn new() -> Arc<Self> { + match read_system_conf() { + Ok((config, opts)) => { + let resolver = TokioAsyncResolver::tokio(config.clone(), opts.clone()); + Arc::new(Self::Hickory(Arc::new(resolver))) + } + Err(e) => { + warn!("Error creating Hickory resolver, falling back to default: {e:?}"); + Arc::new(Self::Default()) + } + } + } + + // Note that we get an iterator of addresses, but we only grab the first one for convenience + async fn resolve_domain(&self, name: &str) -> Result<Option<SocketAddr>, BoxError> { + pre_resolve(name)?; + + let result = match self { + Self::Default() => tokio::net::lookup_host(name).await?.next(), + Self::Hickory(r) => r.lookup_ip(name).await?.iter().next().map(|a| SocketAddr::new(a, 0)), + }; + + if let Some(addr) = &result { + post_resolve(name, addr.ip())?; + } + + Ok(result) + } +} + +fn pre_resolve(name: &str) -> Result<(), CustomHttpClientError> { + if should_block_address(name) { + return Err(CustomHttpClientError::Blacklist { + domain: name.to_string(), + }); + } + + Ok(()) +} + +fn post_resolve(name: &str, ip: IpAddr) -> Result<(), CustomHttpClientError> { + if should_block_ip(ip) { + Err(CustomHttpClientError::NonGlobalIp { + domain: Some(name.to_string()), + ip, + }) + } else { + Ok(()) + } +} + +impl Resolve for CustomDnsResolver { + fn resolve(&self, name: Name) -> Resolving { + let this = self.clone(); + Box::pin(async move { + let name = name.as_str(); + let result = this.resolve_domain(name).await?; + Ok::<reqwest::dns::Addrs, _>(Box::new(result.into_iter())) + }) + } +} diff --git a/src/main.rs b/src/main.rs index 73085901..ecc4f320 100644 --- a/src/main.rs +++ b/src/main.rs @@ -47,6 +47,7 @@ mod config; mod crypto; #[macro_use] mod db; +mod http_client; mod mail; mod ratelimit; mod util; diff --git a/src/util.rs b/src/util.rs index 29df7bbc..04fedbfb 100644 --- a/src/util.rs +++ b/src/util.rs @@ -4,7 +4,6 @@ use std::{collections::HashMap, io::Cursor, ops::Deref, path::Path}; use num_traits::ToPrimitive; -use once_cell::sync::Lazy; use rocket::{ fairing::{Fairing, Info, Kind}, http::{ContentType, Header, HeaderMap, Method, Status}, @@ -686,19 +685,6 @@ where } } -use reqwest::{header, Client, ClientBuilder}; - -pub fn get_reqwest_client() -> &'static Client { - static INSTANCE: Lazy<Client> = Lazy::new(|| get_reqwest_client_builder().build().expect("Failed to build client")); - &INSTANCE -} - -pub fn get_reqwest_client_builder() -> ClientBuilder { - let mut headers = header::HeaderMap::new(); - headers.insert(header::USER_AGENT, header::HeaderValue::from_static("Vaultwarden")); - Client::builder().default_headers(headers).timeout(Duration::from_secs(10)) -} - pub fn convert_json_key_lcase_first(src_json: Value) -> Value { match src_json { Value::Array(elm) => { @@ -750,138 +736,6 @@ pub fn parse_experimental_client_feature_flags(experimental_client_feature_flags feature_states } -mod dns_resolver { - use std::{ - fmt, - net::{IpAddr, SocketAddr}, - sync::Arc, - }; - - use hickory_resolver::{system_conf::read_system_conf, TokioAsyncResolver}; - use once_cell::sync::Lazy; - use reqwest::dns::{Name, Resolve, Resolving}; - - use crate::{util::is_global, CONFIG}; - - #[derive(Debug, Clone)] - pub enum CustomResolverError { - Blacklist { - domain: String, - }, - NonGlobalIp { - domain: String, - ip: IpAddr, - }, - } - - impl CustomResolverError { - pub fn downcast_ref(e: &dyn std::error::Error) -> Option<&Self> { - let mut source = e.source(); - - while let Some(err) = source { - source = err.source(); - if let Some(err) = err.downcast_ref::<CustomResolverError>() { - return Some(err); - } - } - None - } - } - - impl fmt::Display for CustomResolverError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Blacklist { - domain, - } => write!(f, "Blacklisted domain: {domain} matched ICON_BLACKLIST_REGEX"), - Self::NonGlobalIp { - domain, - ip, - } => write!(f, "IP {ip} for domain '{domain}' is not a global IP!"), - } - } - } - - impl std::error::Error for CustomResolverError {} - - #[derive(Debug, Clone)] - pub enum CustomDnsResolver { - Default(), - Hickory(Arc<TokioAsyncResolver>), - } - type BoxError = Box<dyn std::error::Error + Send + Sync>; - - impl CustomDnsResolver { - pub fn instance() -> Arc<Self> { - static INSTANCE: Lazy<Arc<CustomDnsResolver>> = Lazy::new(CustomDnsResolver::new); - Arc::clone(&*INSTANCE) - } - - fn new() -> Arc<Self> { - match read_system_conf() { - Ok((config, opts)) => { - let resolver = TokioAsyncResolver::tokio(config.clone(), opts.clone()); - Arc::new(Self::Hickory(Arc::new(resolver))) - } - Err(e) => { - warn!("Error creating Hickory resolver, falling back to default: {e:?}"); - Arc::new(Self::Default()) - } - } - } - - // Note that we get an iterator of addresses, but we only grab the first one for convenience - async fn resolve_domain(&self, name: &str) -> Result<Option<SocketAddr>, BoxError> { - pre_resolve(name)?; - - let result = match self { - Self::Default() => tokio::net::lookup_host(name).await?.next(), - Self::Hickory(r) => r.lookup_ip(name).await?.iter().next().map(|a| SocketAddr::new(a, 0)), - }; - - if let Some(addr) = &result { - post_resolve(name, addr.ip())?; - } - - Ok(result) - } - } - - fn pre_resolve(name: &str) -> Result<(), CustomResolverError> { - if crate::api::is_domain_blacklisted(name) { - return Err(CustomResolverError::Blacklist { - domain: name.to_string(), - }); - } - - Ok(()) - } - - fn post_resolve(name: &str, ip: IpAddr) -> Result<(), CustomResolverError> { - if CONFIG.icon_blacklist_non_global_ips() && !is_global(ip) { - Err(CustomResolverError::NonGlobalIp { - domain: name.to_string(), - ip, - }) - } else { - Ok(()) - } - } - - impl Resolve for CustomDnsResolver { - fn resolve(&self, name: Name) -> Resolving { - let this = self.clone(); - Box::pin(async move { - let name = name.as_str(); - let result = this.resolve_domain(name).await?; - Ok::<reqwest::dns::Addrs, _>(Box::new(result.into_iter())) - }) - } - } -} - -pub use dns_resolver::{CustomDnsResolver, CustomResolverError}; - /// TODO: This is extracted from IpAddr::is_global, which is unstable: /// https://doc.rust-lang.org/nightly/std/net/enum.IpAddr.html#method.is_global /// Remove once https://github.com/rust-lang/rust/issues/27709 is merged |