aboutsummaryrefslogtreecommitdiff
path: root/src/api/icons.rs
diff options
context:
space:
mode:
authorDaniel GarcĂ­a <[email protected]>2024-04-27 20:25:34 +0200
committerGitHub <[email protected]>2024-04-27 20:25:34 +0200
commit27dc67fadd3d45b9f7d8d37407cef9453b8f5802 (patch)
tree6f2b4de80e0e6bb0b89c91e315139993e589a76b /src/api/icons.rs
parent2ad33ec97f415edb2af054f527efed52b3b93a9e (diff)
downloadvaultwarden-27dc67fadd3d45b9f7d8d37407cef9453b8f5802.tar.gz
vaultwarden-27dc67fadd3d45b9f7d8d37407cef9453b8f5802.zip
Implement custom DNS resolver (#3988)
Diffstat (limited to 'src/api/icons.rs')
-rw-r--r--src/api/icons.rs322
1 files changed, 85 insertions, 237 deletions
diff --git a/src/api/icons.rs b/src/api/icons.rs
index 2f76b86a..94fab3f8 100644
--- a/src/api/icons.rs
+++ b/src/api/icons.rs
@@ -1,6 +1,6 @@
use std::{
net::IpAddr,
- sync::Arc,
+ sync::{Arc, Mutex},
time::{Duration, SystemTime},
};
@@ -16,14 +16,13 @@ use rocket::{http::ContentType, response::Redirect, Route};
use tokio::{
fs::{create_dir_all, remove_file, symlink_metadata, File},
io::{AsyncReadExt, AsyncWriteExt},
- net::lookup_host,
};
use html5gum::{Emitter, HtmlString, InfallibleTokenizer, Readable, StringReader, Tokenizer};
use crate::{
error::Error,
- util::{get_reqwest_client_builder, Cached},
+ util::{get_reqwest_client_builder, Cached, CustomDnsResolver, CustomResolverError},
CONFIG,
};
@@ -49,48 +48,32 @@ static CLIENT: Lazy<Client> = Lazy::new(|| {
let icon_download_timeout = Duration::from_secs(CONFIG.icon_download_timeout());
let pool_idle_timeout = Duration::from_secs(10);
// Reuse the client between requests
- let client = get_reqwest_client_builder()
+ get_reqwest_client_builder()
.cookie_provider(Arc::clone(&cookie_store))
.timeout(icon_download_timeout)
.pool_max_idle_per_host(5) // Configure the Hyper Pool to only have max 5 idle connections
.pool_idle_timeout(pool_idle_timeout) // Configure the Hyper Pool to timeout after 10 seconds
- .hickory_dns(true)
- .default_headers(default_headers.clone());
-
- match client.build() {
- Ok(client) => client,
- Err(e) => {
- error!("Possible trust-dns error, trying with trust-dns disabled: '{e}'");
- get_reqwest_client_builder()
- .cookie_provider(cookie_store)
- .timeout(icon_download_timeout)
- .pool_max_idle_per_host(5) // Configure the Hyper Pool to only have max 5 idle connections
- .pool_idle_timeout(pool_idle_timeout) // Configure the Hyper Pool to timeout after 10 seconds
- .hickory_dns(false)
- .default_headers(default_headers)
- .build()
- .expect("Failed to build client")
- }
- }
+ .dns_resolver(CustomDnsResolver::instance())
+ .default_headers(default_headers.clone())
+ .build()
+ .expect("Failed to build client")
});
// Build Regex only once since this takes a lot of time.
static ICON_SIZE_REGEX: Lazy<Regex> = Lazy::new(|| Regex::new(r"(?x)(\d+)\D*(\d+)").unwrap());
-// Special HashMap which holds the user defined Regex to speedup matching the regex.
-static ICON_BLACKLIST_REGEX: Lazy<dashmap::DashMap<String, Regex>> = Lazy::new(dashmap::DashMap::new);
-
-async fn icon_redirect(domain: &str, template: &str) -> Option<Redirect> {
+#[get("/<domain>/icon.png")]
+fn icon_external(domain: &str) -> Option<Redirect> {
if !is_valid_domain(domain) {
warn!("Invalid domain: {}", domain);
return None;
}
- if check_domain_blacklist_reason(domain).await.is_some() {
+ if is_domain_blacklisted(domain) {
return None;
}
- let url = template.replace("{}", domain);
+ let url = CONFIG._icon_service_url().replace("{}", domain);
match CONFIG.icon_redirect_code() {
301 => Some(Redirect::moved(url)), // legacy permanent redirect
302 => Some(Redirect::found(url)), // legacy temporary redirect
@@ -104,11 +87,6 @@ async fn icon_redirect(domain: &str, template: &str) -> Option<Redirect> {
}
#[get("/<domain>/icon.png")]
-async fn icon_external(domain: &str) -> Option<Redirect> {
- icon_redirect(domain, &CONFIG._icon_service_url()).await
-}
-
-#[get("/<domain>/icon.png")]
async fn icon_internal(domain: &str) -> Cached<(ContentType, Vec<u8>)> {
const FALLBACK_ICON: &[u8] = include_bytes!("../static/images/fallback-icon.png");
@@ -166,153 +144,28 @@ fn is_valid_domain(domain: &str) -> bool {
true
}
-/// TODO: This is extracted from IpAddr::is_global, which is unstable:
-/// https://doc.rust-lang.org/nightly/std/net/enum.IpAddr.html#method.is_global
-/// Remove once https://github.com/rust-lang/rust/issues/27709 is merged
-#[allow(clippy::nonminimal_bool)]
-#[cfg(not(feature = "unstable"))]
-fn is_global(ip: IpAddr) -> bool {
- match ip {
- IpAddr::V4(ip) => {
- // check if this address is 192.0.0.9 or 192.0.0.10. These addresses are the only two
- // globally routable addresses in the 192.0.0.0/24 range.
- if u32::from(ip) == 0xc0000009 || u32::from(ip) == 0xc000000a {
- return true;
- }
- !ip.is_private()
- && !ip.is_loopback()
- && !ip.is_link_local()
- && !ip.is_broadcast()
- && !ip.is_documentation()
- && !(ip.octets()[0] == 100 && (ip.octets()[1] & 0b1100_0000 == 0b0100_0000))
- && !(ip.octets()[0] == 192 && ip.octets()[1] == 0 && ip.octets()[2] == 0)
- && !(ip.octets()[0] & 240 == 240 && !ip.is_broadcast())
- && !(ip.octets()[0] == 198 && (ip.octets()[1] & 0xfe) == 18)
- // Make sure the address is not in 0.0.0.0/8
- && ip.octets()[0] != 0
- }
- IpAddr::V6(ip) => {
- if ip.is_multicast() && ip.segments()[0] & 0x000f == 14 {
- true
- } else {
- !ip.is_multicast()
- && !ip.is_loopback()
- && !((ip.segments()[0] & 0xffc0) == 0xfe80)
- && !((ip.segments()[0] & 0xfe00) == 0xfc00)
- && !ip.is_unspecified()
- && !((ip.segments()[0] == 0x2001) && (ip.segments()[1] == 0xdb8))
- }
- }
- }
-}
-
-#[cfg(feature = "unstable")]
-fn is_global(ip: IpAddr) -> bool {
- ip.is_global()
-}
-
-/// These are some tests to check that the implementations match
-/// The IPv4 can be all checked in 5 mins or so and they are correct as of nightly 2020-07-11
-/// The IPV6 can't be checked in a reasonable time, so we check about ten billion random ones, so far correct
-/// Note that the is_global implementation is subject to change as new IP RFCs are created
-///
-/// To run while showing progress output:
-/// cargo test --features sqlite,unstable -- --nocapture --ignored
-#[cfg(test)]
-#[cfg(feature = "unstable")]
-mod tests {
- use super::*;
-
- #[test]
- #[ignore]
- fn test_ipv4_global() {
- for a in 0..u8::MAX {
- println!("Iter: {}/255", a);
- for b in 0..u8::MAX {
- for c in 0..u8::MAX {
- for d in 0..u8::MAX {
- let ip = IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d));
- assert_eq!(ip.is_global(), is_global(ip))
- }
- }
- }
- }
- }
-
- #[test]
- #[ignore]
- fn test_ipv6_global() {
- use ring::rand::{SecureRandom, SystemRandom};
- let mut v = [0u8; 16];
- let rand = SystemRandom::new();
- for i in 0..1_000 {
- println!("Iter: {}/1_000", i);
- for _ in 0..10_000_000 {
- rand.fill(&mut v).expect("Error generating random values");
- let ip = IpAddr::V6(std::net::Ipv6Addr::new(
- (v[14] as u16) << 8 | v[15] as u16,
- (v[12] as u16) << 8 | v[13] as u16,
- (v[10] as u16) << 8 | v[11] as u16,
- (v[8] as u16) << 8 | v[9] as u16,
- (v[6] as u16) << 8 | v[7] as u16,
- (v[4] as u16) << 8 | v[5] as u16,
- (v[2] as u16) << 8 | v[3] as u16,
- (v[0] as u16) << 8 | v[1] as u16,
- ));
- assert_eq!(ip.is_global(), is_global(ip))
- }
- }
- }
-}
-
-#[derive(Clone)]
-enum DomainBlacklistReason {
- Regex,
- IP,
-}
-
-use cached::proc_macro::cached;
-#[cached(key = "String", convert = r#"{ domain.to_string() }"#, size = 16, time = 60)]
-async fn check_domain_blacklist_reason(domain: &str) -> Option<DomainBlacklistReason> {
- // First check the blacklist regex if there is a match.
- // This prevents the blocked domain(s) from being leaked via a DNS lookup.
- if let Some(blacklist) = CONFIG.icon_blacklist_regex() {
- // Use the pre-generate Regex stored in a Lazy HashMap if there's one, else generate it.
- let is_match = if let Some(regex) = ICON_BLACKLIST_REGEX.get(&blacklist) {
- regex.is_match(domain)
- } else {
- // Clear the current list if the previous key doesn't exists.
- // To prevent growing of the HashMap after someone has changed it via the admin interface.
- if ICON_BLACKLIST_REGEX.len() >= 1 {
- ICON_BLACKLIST_REGEX.clear();
- }
-
- // Generate the regex to store in too the Lazy Static HashMap.
- let blacklist_regex = Regex::new(&blacklist).unwrap();
- let is_match = blacklist_regex.is_match(domain);
- ICON_BLACKLIST_REGEX.insert(blacklist.clone(), blacklist_regex);
+pub fn is_domain_blacklisted(domain: &str) -> bool {
+ let Some(config_blacklist) = CONFIG.icon_blacklist_regex() else {
+ return false;
+ };
- is_match
- };
+ // Compiled domain blacklist
+ static COMPILED_BLACKLIST: Mutex<Option<(String, Regex)>> = Mutex::new(None);
+ let mut guard = COMPILED_BLACKLIST.lock().unwrap();
- if is_match {
- debug!("Blacklisted domain: {} matched ICON_BLACKLIST_REGEX", domain);
- return Some(DomainBlacklistReason::Regex);
+ // If the stored regex is up to date, use it
+ if let Some((value, regex)) = &*guard {
+ if value == &config_blacklist {
+ return regex.is_match(domain);
}
}
- if CONFIG.icon_blacklist_non_global_ips() {
- if let Ok(s) = lookup_host((domain, 0)).await {
- for addr in s {
- if !is_global(addr.ip()) {
- debug!("IP {} for domain '{}' is not a global IP!", addr.ip(), domain);
- return Some(DomainBlacklistReason::IP);
- }
- }
- }
- }
+ // If we don't have a regex stored, or it's not up to date, recreate it
+ let regex = Regex::new(&config_blacklist).unwrap();
+ let is_match = regex.is_match(domain);
+ *guard = Some((config_blacklist, regex));
- None
+ is_match
}
async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
@@ -342,6 +195,13 @@ async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
Some((icon.to_vec(), icon_type.unwrap_or("x-icon").to_string()))
}
Err(e) => {
+ // If this error comes from the custom resolver, this means this is a blacklisted domain
+ // or non global IP, don't save the miss file in this case to avoid leaking it
+ if let Some(error) = CustomResolverError::downcast_ref(&e) {
+ warn!("{error}");
+ return None;
+ }
+
warn!("Unable to download icon: {:?}", e);
let miss_indicator = path + ".miss";
save_icon(&miss_indicator, &[]).await;
@@ -491,42 +351,48 @@ async fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
let ssldomain = format!("https://{domain}");
let httpdomain = format!("http://{domain}");
- // First check the domain as given during the request for both HTTPS and HTTP.
- let resp = match get_page(&ssldomain).or_else(|_| get_page(&httpdomain)).await {
- Ok(c) => Ok(c),
- Err(e) => {
- let mut sub_resp = Err(e);
-
- // When the domain is not an IP, and has more then one dot, remove all subdomains.
- let is_ip = domain.parse::<IpAddr>();
- if is_ip.is_err() && domain.matches('.').count() > 1 {
- let mut domain_parts = domain.split('.');
- let base_domain = format!(
- "{base}.{tld}",
- tld = domain_parts.next_back().unwrap(),
- base = domain_parts.next_back().unwrap()
- );
- if is_valid_domain(&base_domain) {
- let sslbase = format!("https://{base_domain}");
- let httpbase = format!("http://{base_domain}");
- debug!("[get_icon_url]: Trying without subdomains '{base_domain}'");
-
- sub_resp = get_page(&sslbase).or_else(|_| get_page(&httpbase)).await;
- }
+ // First check the domain as given during the request for HTTPS.
+ let resp = match get_page(&ssldomain).await {
+ Err(e) if CustomResolverError::downcast_ref(&e).is_none() => {
+ // If we get an error that is not caused by the blacklist, we retry with HTTP
+ match get_page(&httpdomain).await {
+ mut sub_resp @ Err(_) => {
+ // When the domain is not an IP, and has more then one dot, remove all subdomains.
+ let is_ip = domain.parse::<IpAddr>();
+ if is_ip.is_err() && domain.matches('.').count() > 1 {
+ let mut domain_parts = domain.split('.');
+ let base_domain = format!(
+ "{base}.{tld}",
+ tld = domain_parts.next_back().unwrap(),
+ base = domain_parts.next_back().unwrap()
+ );
+ if is_valid_domain(&base_domain) {
+ let sslbase = format!("https://{base_domain}");
+ let httpbase = format!("http://{base_domain}");
+ debug!("[get_icon_url]: Trying without subdomains '{base_domain}'");
+
+ sub_resp = get_page(&sslbase).or_else(|_| get_page(&httpbase)).await;
+ }
- // When the domain is not an IP, and has less then 2 dots, try to add www. infront of it.
- } else if is_ip.is_err() && domain.matches('.').count() < 2 {
- let www_domain = format!("www.{domain}");
- if is_valid_domain(&www_domain) {
- let sslwww = format!("https://{www_domain}");
- let httpwww = format!("http://{www_domain}");
- debug!("[get_icon_url]: Trying with www. prefix '{www_domain}'");
+ // When the domain is not an IP, and has less then 2 dots, try to add www. infront of it.
+ } else if is_ip.is_err() && domain.matches('.').count() < 2 {
+ let www_domain = format!("www.{domain}");
+ if is_valid_domain(&www_domain) {
+ let sslwww = format!("https://{www_domain}");
+ let httpwww = format!("http://{www_domain}");
+ debug!("[get_icon_url]: Trying with www. prefix '{www_domain}'");
- sub_resp = get_page(&sslwww).or_else(|_| get_page(&httpwww)).await;
+ sub_resp = get_page(&sslwww).or_else(|_| get_page(&httpwww)).await;
+ }
+ }
+ sub_resp
}
+ res => res,
}
- sub_resp
}
+
+ // If we get a result or a blacklist error, just continue
+ res => res,
};
// Create the iconlist
@@ -573,21 +439,12 @@ async fn get_page(url: &str) -> Result<Response, Error> {
}
async fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
- match check_domain_blacklist_reason(url::Url::parse(url).unwrap().host_str().unwrap_or_default()).await {
- Some(DomainBlacklistReason::Regex) => warn!("Favicon '{}' is from a blacklisted domain!", url),
- Some(DomainBlacklistReason::IP) => warn!("Favicon '{}' is hosted on a non-global IP!", url),
- None => (),
- }
-
let mut client = CLIENT.get(url);
if !referer.is_empty() {
client = client.header("Referer", referer)
}
- match client.send().await {
- Ok(c) => c.error_for_status().map_err(Into::into),
- Err(e) => err_silent!(format!("{e}")),
- }
+ Ok(client.send().await?.error_for_status()?)
}
/// Returns a Integer with the priority of the type of the icon which to prefer.
@@ -670,12 +527,6 @@ fn parse_sizes(sizes: &str) -> (u16, u16) {
}
async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> {
- match check_domain_blacklist_reason(domain).await {
- Some(DomainBlacklistReason::Regex) => err_silent!("Domain is blacklisted", domain),
- Some(DomainBlacklistReason::IP) => err_silent!("Host resolves to a non-global IP", domain),
- None => (),
- }
-
let icon_result = get_icon_url(domain).await?;
let mut buffer = Bytes::new();
@@ -711,22 +562,19 @@ async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> {
_ => debug!("Extracted icon from data:image uri is invalid"),
};
} else {
- match get_page_with_referer(&icon.href, &icon_result.referer).await {
- Ok(res) => {
- buffer = stream_to_bytes_limit(res, 5120 * 1024).await?; // 5120KB/5MB for each icon max (Same as icons.bitwarden.net)
-
- // Check if the icon type is allowed, else try an icon from the list.
- icon_type = get_icon_type(&buffer);
- if icon_type.is_none() {
- buffer.clear();
- debug!("Icon from {}, is not a valid image type", icon.href);
- continue;
- }
- info!("Downloaded icon from {}", icon.href);
- break;
- }
- Err(e) => debug!("{:?}", e),
- };
+ let res = get_page_with_referer(&icon.href, &icon_result.referer).await?;
+
+ buffer = stream_to_bytes_limit(res, 5120 * 1024).await?; // 5120KB/5MB for each icon max (Same as icons.bitwarden.net)
+
+ // Check if the icon type is allowed, else try an icon from the list.
+ icon_type = get_icon_type(&buffer);
+ if icon_type.is_none() {
+ buffer.clear();
+ debug!("Icon from {}, is not a valid image type", icon.href);
+ continue;
+ }
+ info!("Downloaded icon from {}", icon.href);
+ break;
}
}