diff options
author | Daniel GarcĂa <[email protected]> | 2024-04-27 20:25:34 +0200 |
---|---|---|
committer | GitHub <[email protected]> | 2024-04-27 20:25:34 +0200 |
commit | 27dc67fadd3d45b9f7d8d37407cef9453b8f5802 (patch) | |
tree | 6f2b4de80e0e6bb0b89c91e315139993e589a76b /src/api/icons.rs | |
parent | 2ad33ec97f415edb2af054f527efed52b3b93a9e (diff) | |
download | vaultwarden-27dc67fadd3d45b9f7d8d37407cef9453b8f5802.tar.gz vaultwarden-27dc67fadd3d45b9f7d8d37407cef9453b8f5802.zip |
Implement custom DNS resolver (#3988)
Diffstat (limited to 'src/api/icons.rs')
-rw-r--r-- | src/api/icons.rs | 322 |
1 files changed, 85 insertions, 237 deletions
diff --git a/src/api/icons.rs b/src/api/icons.rs index 2f76b86a..94fab3f8 100644 --- a/src/api/icons.rs +++ b/src/api/icons.rs @@ -1,6 +1,6 @@ use std::{ net::IpAddr, - sync::Arc, + sync::{Arc, Mutex}, time::{Duration, SystemTime}, }; @@ -16,14 +16,13 @@ use rocket::{http::ContentType, response::Redirect, Route}; use tokio::{ fs::{create_dir_all, remove_file, symlink_metadata, File}, io::{AsyncReadExt, AsyncWriteExt}, - net::lookup_host, }; use html5gum::{Emitter, HtmlString, InfallibleTokenizer, Readable, StringReader, Tokenizer}; use crate::{ error::Error, - util::{get_reqwest_client_builder, Cached}, + util::{get_reqwest_client_builder, Cached, CustomDnsResolver, CustomResolverError}, CONFIG, }; @@ -49,48 +48,32 @@ static CLIENT: Lazy<Client> = Lazy::new(|| { let icon_download_timeout = Duration::from_secs(CONFIG.icon_download_timeout()); let pool_idle_timeout = Duration::from_secs(10); // Reuse the client between requests - let client = get_reqwest_client_builder() + get_reqwest_client_builder() .cookie_provider(Arc::clone(&cookie_store)) .timeout(icon_download_timeout) .pool_max_idle_per_host(5) // Configure the Hyper Pool to only have max 5 idle connections .pool_idle_timeout(pool_idle_timeout) // Configure the Hyper Pool to timeout after 10 seconds - .hickory_dns(true) - .default_headers(default_headers.clone()); - - match client.build() { - Ok(client) => client, - Err(e) => { - error!("Possible trust-dns error, trying with trust-dns disabled: '{e}'"); - get_reqwest_client_builder() - .cookie_provider(cookie_store) - .timeout(icon_download_timeout) - .pool_max_idle_per_host(5) // Configure the Hyper Pool to only have max 5 idle connections - .pool_idle_timeout(pool_idle_timeout) // Configure the Hyper Pool to timeout after 10 seconds - .hickory_dns(false) - .default_headers(default_headers) - .build() - .expect("Failed to build client") - } - } + .dns_resolver(CustomDnsResolver::instance()) + .default_headers(default_headers.clone()) + .build() + .expect("Failed to build client") }); // Build Regex only once since this takes a lot of time. static ICON_SIZE_REGEX: Lazy<Regex> = Lazy::new(|| Regex::new(r"(?x)(\d+)\D*(\d+)").unwrap()); -// Special HashMap which holds the user defined Regex to speedup matching the regex. -static ICON_BLACKLIST_REGEX: Lazy<dashmap::DashMap<String, Regex>> = Lazy::new(dashmap::DashMap::new); - -async fn icon_redirect(domain: &str, template: &str) -> Option<Redirect> { +#[get("/<domain>/icon.png")] +fn icon_external(domain: &str) -> Option<Redirect> { if !is_valid_domain(domain) { warn!("Invalid domain: {}", domain); return None; } - if check_domain_blacklist_reason(domain).await.is_some() { + if is_domain_blacklisted(domain) { return None; } - let url = template.replace("{}", domain); + let url = CONFIG._icon_service_url().replace("{}", domain); match CONFIG.icon_redirect_code() { 301 => Some(Redirect::moved(url)), // legacy permanent redirect 302 => Some(Redirect::found(url)), // legacy temporary redirect @@ -104,11 +87,6 @@ async fn icon_redirect(domain: &str, template: &str) -> Option<Redirect> { } #[get("/<domain>/icon.png")] -async fn icon_external(domain: &str) -> Option<Redirect> { - icon_redirect(domain, &CONFIG._icon_service_url()).await -} - -#[get("/<domain>/icon.png")] async fn icon_internal(domain: &str) -> Cached<(ContentType, Vec<u8>)> { const FALLBACK_ICON: &[u8] = include_bytes!("../static/images/fallback-icon.png"); @@ -166,153 +144,28 @@ fn is_valid_domain(domain: &str) -> bool { true } -/// TODO: This is extracted from IpAddr::is_global, which is unstable: -/// https://doc.rust-lang.org/nightly/std/net/enum.IpAddr.html#method.is_global -/// Remove once https://github.com/rust-lang/rust/issues/27709 is merged -#[allow(clippy::nonminimal_bool)] -#[cfg(not(feature = "unstable"))] -fn is_global(ip: IpAddr) -> bool { - match ip { - IpAddr::V4(ip) => { - // check if this address is 192.0.0.9 or 192.0.0.10. These addresses are the only two - // globally routable addresses in the 192.0.0.0/24 range. - if u32::from(ip) == 0xc0000009 || u32::from(ip) == 0xc000000a { - return true; - } - !ip.is_private() - && !ip.is_loopback() - && !ip.is_link_local() - && !ip.is_broadcast() - && !ip.is_documentation() - && !(ip.octets()[0] == 100 && (ip.octets()[1] & 0b1100_0000 == 0b0100_0000)) - && !(ip.octets()[0] == 192 && ip.octets()[1] == 0 && ip.octets()[2] == 0) - && !(ip.octets()[0] & 240 == 240 && !ip.is_broadcast()) - && !(ip.octets()[0] == 198 && (ip.octets()[1] & 0xfe) == 18) - // Make sure the address is not in 0.0.0.0/8 - && ip.octets()[0] != 0 - } - IpAddr::V6(ip) => { - if ip.is_multicast() && ip.segments()[0] & 0x000f == 14 { - true - } else { - !ip.is_multicast() - && !ip.is_loopback() - && !((ip.segments()[0] & 0xffc0) == 0xfe80) - && !((ip.segments()[0] & 0xfe00) == 0xfc00) - && !ip.is_unspecified() - && !((ip.segments()[0] == 0x2001) && (ip.segments()[1] == 0xdb8)) - } - } - } -} - -#[cfg(feature = "unstable")] -fn is_global(ip: IpAddr) -> bool { - ip.is_global() -} - -/// These are some tests to check that the implementations match -/// The IPv4 can be all checked in 5 mins or so and they are correct as of nightly 2020-07-11 -/// The IPV6 can't be checked in a reasonable time, so we check about ten billion random ones, so far correct -/// Note that the is_global implementation is subject to change as new IP RFCs are created -/// -/// To run while showing progress output: -/// cargo test --features sqlite,unstable -- --nocapture --ignored -#[cfg(test)] -#[cfg(feature = "unstable")] -mod tests { - use super::*; - - #[test] - #[ignore] - fn test_ipv4_global() { - for a in 0..u8::MAX { - println!("Iter: {}/255", a); - for b in 0..u8::MAX { - for c in 0..u8::MAX { - for d in 0..u8::MAX { - let ip = IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d)); - assert_eq!(ip.is_global(), is_global(ip)) - } - } - } - } - } - - #[test] - #[ignore] - fn test_ipv6_global() { - use ring::rand::{SecureRandom, SystemRandom}; - let mut v = [0u8; 16]; - let rand = SystemRandom::new(); - for i in 0..1_000 { - println!("Iter: {}/1_000", i); - for _ in 0..10_000_000 { - rand.fill(&mut v).expect("Error generating random values"); - let ip = IpAddr::V6(std::net::Ipv6Addr::new( - (v[14] as u16) << 8 | v[15] as u16, - (v[12] as u16) << 8 | v[13] as u16, - (v[10] as u16) << 8 | v[11] as u16, - (v[8] as u16) << 8 | v[9] as u16, - (v[6] as u16) << 8 | v[7] as u16, - (v[4] as u16) << 8 | v[5] as u16, - (v[2] as u16) << 8 | v[3] as u16, - (v[0] as u16) << 8 | v[1] as u16, - )); - assert_eq!(ip.is_global(), is_global(ip)) - } - } - } -} - -#[derive(Clone)] -enum DomainBlacklistReason { - Regex, - IP, -} - -use cached::proc_macro::cached; -#[cached(key = "String", convert = r#"{ domain.to_string() }"#, size = 16, time = 60)] -async fn check_domain_blacklist_reason(domain: &str) -> Option<DomainBlacklistReason> { - // First check the blacklist regex if there is a match. - // This prevents the blocked domain(s) from being leaked via a DNS lookup. - if let Some(blacklist) = CONFIG.icon_blacklist_regex() { - // Use the pre-generate Regex stored in a Lazy HashMap if there's one, else generate it. - let is_match = if let Some(regex) = ICON_BLACKLIST_REGEX.get(&blacklist) { - regex.is_match(domain) - } else { - // Clear the current list if the previous key doesn't exists. - // To prevent growing of the HashMap after someone has changed it via the admin interface. - if ICON_BLACKLIST_REGEX.len() >= 1 { - ICON_BLACKLIST_REGEX.clear(); - } - - // Generate the regex to store in too the Lazy Static HashMap. - let blacklist_regex = Regex::new(&blacklist).unwrap(); - let is_match = blacklist_regex.is_match(domain); - ICON_BLACKLIST_REGEX.insert(blacklist.clone(), blacklist_regex); +pub fn is_domain_blacklisted(domain: &str) -> bool { + let Some(config_blacklist) = CONFIG.icon_blacklist_regex() else { + return false; + }; - is_match - }; + // Compiled domain blacklist + static COMPILED_BLACKLIST: Mutex<Option<(String, Regex)>> = Mutex::new(None); + let mut guard = COMPILED_BLACKLIST.lock().unwrap(); - if is_match { - debug!("Blacklisted domain: {} matched ICON_BLACKLIST_REGEX", domain); - return Some(DomainBlacklistReason::Regex); + // If the stored regex is up to date, use it + if let Some((value, regex)) = &*guard { + if value == &config_blacklist { + return regex.is_match(domain); } } - if CONFIG.icon_blacklist_non_global_ips() { - if let Ok(s) = lookup_host((domain, 0)).await { - for addr in s { - if !is_global(addr.ip()) { - debug!("IP {} for domain '{}' is not a global IP!", addr.ip(), domain); - return Some(DomainBlacklistReason::IP); - } - } - } - } + // If we don't have a regex stored, or it's not up to date, recreate it + let regex = Regex::new(&config_blacklist).unwrap(); + let is_match = regex.is_match(domain); + *guard = Some((config_blacklist, regex)); - None + is_match } async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> { @@ -342,6 +195,13 @@ async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> { Some((icon.to_vec(), icon_type.unwrap_or("x-icon").to_string())) } Err(e) => { + // If this error comes from the custom resolver, this means this is a blacklisted domain + // or non global IP, don't save the miss file in this case to avoid leaking it + if let Some(error) = CustomResolverError::downcast_ref(&e) { + warn!("{error}"); + return None; + } + warn!("Unable to download icon: {:?}", e); let miss_indicator = path + ".miss"; save_icon(&miss_indicator, &[]).await; @@ -491,42 +351,48 @@ async fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> { let ssldomain = format!("https://{domain}"); let httpdomain = format!("http://{domain}"); - // First check the domain as given during the request for both HTTPS and HTTP. - let resp = match get_page(&ssldomain).or_else(|_| get_page(&httpdomain)).await { - Ok(c) => Ok(c), - Err(e) => { - let mut sub_resp = Err(e); - - // When the domain is not an IP, and has more then one dot, remove all subdomains. - let is_ip = domain.parse::<IpAddr>(); - if is_ip.is_err() && domain.matches('.').count() > 1 { - let mut domain_parts = domain.split('.'); - let base_domain = format!( - "{base}.{tld}", - tld = domain_parts.next_back().unwrap(), - base = domain_parts.next_back().unwrap() - ); - if is_valid_domain(&base_domain) { - let sslbase = format!("https://{base_domain}"); - let httpbase = format!("http://{base_domain}"); - debug!("[get_icon_url]: Trying without subdomains '{base_domain}'"); - - sub_resp = get_page(&sslbase).or_else(|_| get_page(&httpbase)).await; - } + // First check the domain as given during the request for HTTPS. + let resp = match get_page(&ssldomain).await { + Err(e) if CustomResolverError::downcast_ref(&e).is_none() => { + // If we get an error that is not caused by the blacklist, we retry with HTTP + match get_page(&httpdomain).await { + mut sub_resp @ Err(_) => { + // When the domain is not an IP, and has more then one dot, remove all subdomains. + let is_ip = domain.parse::<IpAddr>(); + if is_ip.is_err() && domain.matches('.').count() > 1 { + let mut domain_parts = domain.split('.'); + let base_domain = format!( + "{base}.{tld}", + tld = domain_parts.next_back().unwrap(), + base = domain_parts.next_back().unwrap() + ); + if is_valid_domain(&base_domain) { + let sslbase = format!("https://{base_domain}"); + let httpbase = format!("http://{base_domain}"); + debug!("[get_icon_url]: Trying without subdomains '{base_domain}'"); + + sub_resp = get_page(&sslbase).or_else(|_| get_page(&httpbase)).await; + } - // When the domain is not an IP, and has less then 2 dots, try to add www. infront of it. - } else if is_ip.is_err() && domain.matches('.').count() < 2 { - let www_domain = format!("www.{domain}"); - if is_valid_domain(&www_domain) { - let sslwww = format!("https://{www_domain}"); - let httpwww = format!("http://{www_domain}"); - debug!("[get_icon_url]: Trying with www. prefix '{www_domain}'"); + // When the domain is not an IP, and has less then 2 dots, try to add www. infront of it. + } else if is_ip.is_err() && domain.matches('.').count() < 2 { + let www_domain = format!("www.{domain}"); + if is_valid_domain(&www_domain) { + let sslwww = format!("https://{www_domain}"); + let httpwww = format!("http://{www_domain}"); + debug!("[get_icon_url]: Trying with www. prefix '{www_domain}'"); - sub_resp = get_page(&sslwww).or_else(|_| get_page(&httpwww)).await; + sub_resp = get_page(&sslwww).or_else(|_| get_page(&httpwww)).await; + } + } + sub_resp } + res => res, } - sub_resp } + + // If we get a result or a blacklist error, just continue + res => res, }; // Create the iconlist @@ -573,21 +439,12 @@ async fn get_page(url: &str) -> Result<Response, Error> { } async fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> { - match check_domain_blacklist_reason(url::Url::parse(url).unwrap().host_str().unwrap_or_default()).await { - Some(DomainBlacklistReason::Regex) => warn!("Favicon '{}' is from a blacklisted domain!", url), - Some(DomainBlacklistReason::IP) => warn!("Favicon '{}' is hosted on a non-global IP!", url), - None => (), - } - let mut client = CLIENT.get(url); if !referer.is_empty() { client = client.header("Referer", referer) } - match client.send().await { - Ok(c) => c.error_for_status().map_err(Into::into), - Err(e) => err_silent!(format!("{e}")), - } + Ok(client.send().await?.error_for_status()?) } /// Returns a Integer with the priority of the type of the icon which to prefer. @@ -670,12 +527,6 @@ fn parse_sizes(sizes: &str) -> (u16, u16) { } async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> { - match check_domain_blacklist_reason(domain).await { - Some(DomainBlacklistReason::Regex) => err_silent!("Domain is blacklisted", domain), - Some(DomainBlacklistReason::IP) => err_silent!("Host resolves to a non-global IP", domain), - None => (), - } - let icon_result = get_icon_url(domain).await?; let mut buffer = Bytes::new(); @@ -711,22 +562,19 @@ async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> { _ => debug!("Extracted icon from data:image uri is invalid"), }; } else { - match get_page_with_referer(&icon.href, &icon_result.referer).await { - Ok(res) => { - buffer = stream_to_bytes_limit(res, 5120 * 1024).await?; // 5120KB/5MB for each icon max (Same as icons.bitwarden.net) - - // Check if the icon type is allowed, else try an icon from the list. - icon_type = get_icon_type(&buffer); - if icon_type.is_none() { - buffer.clear(); - debug!("Icon from {}, is not a valid image type", icon.href); - continue; - } - info!("Downloaded icon from {}", icon.href); - break; - } - Err(e) => debug!("{:?}", e), - }; + let res = get_page_with_referer(&icon.href, &icon_result.referer).await?; + + buffer = stream_to_bytes_limit(res, 5120 * 1024).await?; // 5120KB/5MB for each icon max (Same as icons.bitwarden.net) + + // Check if the icon type is allowed, else try an icon from the list. + icon_type = get_icon_type(&buffer); + if icon_type.is_none() { + buffer.clear(); + debug!("Icon from {}, is not a valid image type", icon.href); + continue; + } + info!("Downloaded icon from {}", icon.href); + break; } } |