aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDaniel García <[email protected]>2024-07-12 17:51:19 +0200
committerDaniel García <[email protected]>2024-07-12 19:01:23 +0200
commitbbbbd2fb482684149b573aa1acbc98f7fb1408b9 (patch)
treee37ab37d386a76d6b3b5dc6429b8003d34c8c231 /src
parenta4ab014ade53e4e60bda0b9cbce3af9de7eac753 (diff)
downloadvaultwarden-bbbbd2fb482684149b573aa1acbc98f7fb1408b9.tar.gz
vaultwarden-bbbbd2fb482684149b573aa1acbc98f7fb1408b9.zip
Improved HTTP client
Diffstat (limited to 'src')
-rw-r--r--src/api/admin.rs17
-rw-r--r--src/api/core/mod.rs8
-rw-r--r--src/api/core/two_factor/duo.rs7
-rw-r--r--src/api/icons.rs46
-rw-r--r--src/api/mod.rs2
-rw-r--r--src/api/push.rs27
-rw-r--r--src/config.rs25
-rw-r--r--src/error.rs5
-rw-r--r--src/http_client.rs247
-rw-r--r--src/main.rs1
-rw-r--r--src/util.rs146
11 files changed, 321 insertions, 210 deletions
diff --git a/src/api/admin.rs b/src/api/admin.rs
index 58a056b6..1ea9aa59 100644
--- a/src/api/admin.rs
+++ b/src/api/admin.rs
@@ -1,4 +1,5 @@
use once_cell::sync::Lazy;
+use reqwest::Method;
use serde::de::DeserializeOwned;
use serde_json::Value;
use std::env;
@@ -21,10 +22,10 @@ use crate::{
config::ConfigBuilder,
db::{backup_database, get_sql_server_version, models::*, DbConn, DbConnType},
error::{Error, MapResult},
+ http_client::make_http_request,
mail,
util::{
- container_base_image, format_naive_datetime_local, get_display_size, get_reqwest_client,
- is_running_in_container, NumberOrString,
+ container_base_image, format_naive_datetime_local, get_display_size, is_running_in_container, NumberOrString,
},
CONFIG, VERSION,
};
@@ -594,15 +595,15 @@ struct TimeApi {
}
async fn get_json_api<T: DeserializeOwned>(url: &str) -> Result<T, Error> {
- let json_api = get_reqwest_client();
-
- Ok(json_api.get(url).send().await?.error_for_status()?.json::<T>().await?)
+ Ok(make_http_request(Method::GET, url)?.send().await?.error_for_status()?.json::<T>().await?)
}
async fn has_http_access() -> bool {
- let http_access = get_reqwest_client();
-
- match http_access.head("https://github.com/dani-garcia/vaultwarden").send().await {
+ let req = match make_http_request(Method::HEAD, "https://github.com/dani-garcia/vaultwarden") {
+ Ok(r) => r,
+ Err(_) => return false,
+ };
+ match req.send().await {
Ok(r) => r.status().is_success(),
_ => false,
}
diff --git a/src/api/core/mod.rs b/src/api/core/mod.rs
index 9da0e886..41bd4d6b 100644
--- a/src/api/core/mod.rs
+++ b/src/api/core/mod.rs
@@ -12,6 +12,7 @@ pub use accounts::purge_auth_requests;
pub use ciphers::{purge_trashed_ciphers, CipherData, CipherSyncData, CipherSyncType};
pub use emergency_access::{emergency_notification_reminder_job, emergency_request_timeout_job};
pub use events::{event_cleanup_job, log_event, log_user_event};
+use reqwest::Method;
pub use sends::purge_sends;
pub fn routes() -> Vec<Route> {
@@ -53,7 +54,8 @@ use crate::{
auth::Headers,
db::DbConn,
error::Error,
- util::{get_reqwest_client, parse_experimental_client_feature_flags},
+ http_client::make_http_request,
+ util::parse_experimental_client_feature_flags,
};
#[derive(Debug, Serialize, Deserialize)]
@@ -139,9 +141,7 @@ async fn hibp_breach(username: &str) -> JsonResult {
);
if let Some(api_key) = crate::CONFIG.hibp_api_key() {
- let hibp_client = get_reqwest_client();
-
- let res = hibp_client.get(&url).header("hibp-api-key", api_key).send().await?;
+ let res = make_http_request(Method::GET, &url)?.header("hibp-api-key", api_key).send().await?;
// If we get a 404, return a 404, it means no breached accounts
if res.status() == 404 {
diff --git a/src/api/core/two_factor/duo.rs b/src/api/core/two_factor/duo.rs
index c5bfa9e5..8554999c 100644
--- a/src/api/core/two_factor/duo.rs
+++ b/src/api/core/two_factor/duo.rs
@@ -15,7 +15,7 @@ use crate::{
DbConn,
},
error::MapResult,
- util::get_reqwest_client,
+ http_client::make_http_request,
CONFIG,
};
@@ -210,10 +210,7 @@ async fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData)
let m = Method::from_str(method).unwrap_or_default();
- let client = get_reqwest_client();
-
- client
- .request(m, &url)
+ make_http_request(m, &url)?
.basic_auth(username, Some(password))
.header(header::USER_AGENT, "vaultwarden:Duo/1.0 (Rust)")
.header(header::DATE, date)
diff --git a/src/api/icons.rs b/src/api/icons.rs
index 94fab3f8..f8fe059b 100644
--- a/src/api/icons.rs
+++ b/src/api/icons.rs
@@ -1,6 +1,6 @@
use std::{
net::IpAddr,
- sync::{Arc, Mutex},
+ sync::Arc,
time::{Duration, SystemTime},
};
@@ -22,7 +22,8 @@ use html5gum::{Emitter, HtmlString, InfallibleTokenizer, Readable, StringReader,
use crate::{
error::Error,
- util::{get_reqwest_client_builder, Cached, CustomDnsResolver, CustomResolverError},
+ http_client::{get_reqwest_client_builder, should_block_address, CustomHttpClientError},
+ util::Cached,
CONFIG,
};
@@ -53,7 +54,6 @@ static CLIENT: Lazy<Client> = Lazy::new(|| {
.timeout(icon_download_timeout)
.pool_max_idle_per_host(5) // Configure the Hyper Pool to only have max 5 idle connections
.pool_idle_timeout(pool_idle_timeout) // Configure the Hyper Pool to timeout after 10 seconds
- .dns_resolver(CustomDnsResolver::instance())
.default_headers(default_headers.clone())
.build()
.expect("Failed to build client")
@@ -69,7 +69,8 @@ fn icon_external(domain: &str) -> Option<Redirect> {
return None;
}
- if is_domain_blacklisted(domain) {
+ if should_block_address(domain) {
+ warn!("Blocked address: {}", domain);
return None;
}
@@ -99,6 +100,15 @@ async fn icon_internal(domain: &str) -> Cached<(ContentType, Vec<u8>)> {
);
}
+ if should_block_address(domain) {
+ warn!("Blocked address: {}", domain);
+ return Cached::ttl(
+ (ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
+ CONFIG.icon_cache_negttl(),
+ true,
+ );
+ }
+
match get_icon(domain).await {
Some((icon, icon_type)) => {
Cached::ttl((ContentType::new("image", icon_type), icon), CONFIG.icon_cache_ttl(), true)
@@ -144,30 +154,6 @@ fn is_valid_domain(domain: &str) -> bool {
true
}
-pub fn is_domain_blacklisted(domain: &str) -> bool {
- let Some(config_blacklist) = CONFIG.icon_blacklist_regex() else {
- return false;
- };
-
- // Compiled domain blacklist
- static COMPILED_BLACKLIST: Mutex<Option<(String, Regex)>> = Mutex::new(None);
- let mut guard = COMPILED_BLACKLIST.lock().unwrap();
-
- // If the stored regex is up to date, use it
- if let Some((value, regex)) = &*guard {
- if value == &config_blacklist {
- return regex.is_match(domain);
- }
- }
-
- // If we don't have a regex stored, or it's not up to date, recreate it
- let regex = Regex::new(&config_blacklist).unwrap();
- let is_match = regex.is_match(domain);
- *guard = Some((config_blacklist, regex));
-
- is_match
-}
-
async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
let path = format!("{}/{}.png", CONFIG.icon_cache_folder(), domain);
@@ -197,7 +183,7 @@ async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
Err(e) => {
// If this error comes from the custom resolver, this means this is a blacklisted domain
// or non global IP, don't save the miss file in this case to avoid leaking it
- if let Some(error) = CustomResolverError::downcast_ref(&e) {
+ if let Some(error) = CustomHttpClientError::downcast_ref(&e) {
warn!("{error}");
return None;
}
@@ -353,7 +339,7 @@ async fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
// First check the domain as given during the request for HTTPS.
let resp = match get_page(&ssldomain).await {
- Err(e) if CustomResolverError::downcast_ref(&e).is_none() => {
+ Err(e) if CustomHttpClientError::downcast_ref(&e).is_none() => {
// If we get an error that is not caused by the blacklist, we retry with HTTP
match get_page(&httpdomain).await {
mut sub_resp @ Err(_) => {
diff --git a/src/api/mod.rs b/src/api/mod.rs
index d5281bda..27a3775f 100644
--- a/src/api/mod.rs
+++ b/src/api/mod.rs
@@ -20,7 +20,7 @@ pub use crate::api::{
core::two_factor::send_incomplete_2fa_notifications,
core::{emergency_notification_reminder_job, emergency_request_timeout_job},
core::{event_cleanup_job, events_routes as core_events_routes},
- icons::{is_domain_blacklisted, routes as icons_routes},
+ icons::routes as icons_routes,
identity::routes as identity_routes,
notifications::routes as notifications_routes,
notifications::{AnonymousNotify, Notify, UpdateType, WS_ANONYMOUS_SUBSCRIPTIONS, WS_USERS},
diff --git a/src/api/push.rs b/src/api/push.rs
index 607fb7ea..eaf304f9 100644
--- a/src/api/push.rs
+++ b/src/api/push.rs
@@ -1,11 +1,14 @@
-use reqwest::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE};
+use reqwest::{
+ header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE},
+ Method,
+};
use serde_json::Value;
use tokio::sync::RwLock;
use crate::{
api::{ApiResult, EmptyResult, UpdateType},
db::models::{Cipher, Device, Folder, Send, User},
- util::get_reqwest_client,
+ http_client::make_http_request,
CONFIG,
};
@@ -50,8 +53,7 @@ async fn get_auth_push_token() -> ApiResult<String> {
("client_secret", &client_secret),
];
- let res = match get_reqwest_client()
- .post(&format!("{}/connect/token", CONFIG.push_identity_uri()))
+ let res = match make_http_request(Method::POST, &format!("{}/connect/token", CONFIG.push_identity_uri()))?
.form(&params)
.send()
.await
@@ -104,8 +106,7 @@ pub async fn register_push_device(device: &mut Device, conn: &mut crate::db::DbC
let auth_push_token = get_auth_push_token().await?;
let auth_header = format!("Bearer {}", &auth_push_token);
- if let Err(e) = get_reqwest_client()
- .post(CONFIG.push_relay_uri() + "/push/register")
+ if let Err(e) = make_http_request(Method::POST, &(CONFIG.push_relay_uri() + "/push/register"))?
.header(CONTENT_TYPE, "application/json")
.header(ACCEPT, "application/json")
.header(AUTHORIZATION, auth_header)
@@ -132,8 +133,7 @@ pub async fn unregister_push_device(push_uuid: Option<String>) -> EmptyResult {
let auth_header = format!("Bearer {}", &auth_push_token);
- match get_reqwest_client()
- .delete(CONFIG.push_relay_uri() + "/push/" + &push_uuid.unwrap())
+ match make_http_request(Method::DELETE, &(CONFIG.push_relay_uri() + "/push/" + &push_uuid.unwrap()))?
.header(AUTHORIZATION, auth_header)
.send()
.await
@@ -266,8 +266,15 @@ async fn send_to_push_relay(notification_data: Value) {
let auth_header = format!("Bearer {}", &auth_push_token);
- if let Err(e) = get_reqwest_client()
- .post(CONFIG.push_relay_uri() + "/push/send")
+ let req = match make_http_request(Method::POST, &(CONFIG.push_relay_uri() + "/push/send")) {
+ Ok(r) => r,
+ Err(e) => {
+ error!("An error occurred while sending a send update to the push relay: {}", e);
+ return;
+ }
+ };
+
+ if let Err(e) = req
.header(ACCEPT, "application/json")
.header(CONTENT_TYPE, "application/json")
.header(AUTHORIZATION, &auth_header)
diff --git a/src/config.rs b/src/config.rs
index 489a229d..7ce2a255 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -146,6 +146,13 @@ macro_rules! make_config {
config.signups_domains_whitelist = config.signups_domains_whitelist.trim().to_lowercase();
config.org_creation_users = config.org_creation_users.trim().to_lowercase();
+
+ // Copy the values from the deprecated flags to the new ones
+ if config.http_request_blacklist_regex.is_none() {
+ config.http_request_blacklist_non_global_ips = config.icon_blacklist_non_global_ips;
+ config.http_request_blacklist_regex = config.icon_blacklist_regex.clone();
+ }
+
config
}
}
@@ -531,12 +538,18 @@ make_config! {
icon_cache_negttl: u64, true, def, 259_200;
/// Icon download timeout |> Number of seconds when to stop attempting to download an icon.
icon_download_timeout: u64, true, def, 10;
- /// Icon blacklist Regex |> Any domains or IPs that match this regex won't be fetched by the icon service.
+
+ /// [Deprecated] Icon blacklist Regex |> Use `icon_blacklist_regex` instead
+ icon_blacklist_regex: String, false, option;
+ /// [Deprecated] Icon blacklist non global IPs |> Use `http_request_blacklist_non_global_ips` instead
+ icon_blacklist_non_global_ips: bool, false, def, true;
+
+ /// HTTP blacklist Regex |> Any domains or IPs that match this regex won't be fetched by the internal HTTP client.
/// Useful to hide other servers in the local network. Check the WIKI for more details
- icon_blacklist_regex: String, true, option;
- /// Icon blacklist non global IPs |> Any IP which is not defined as a global IP will be blacklisted.
+ http_request_blacklist_regex: String, true, option;
+ /// Blacklist non global IPs |> Enabling this will cause the internal HTTP client to refuse to connect to any non global IP address.
/// Useful to secure your internal environment: See https://en.wikipedia.org/wiki/Reserved_IP_addresses for a list of IPs which it will block
- icon_blacklist_non_global_ips: bool, true, def, true;
+ http_request_blacklist_non_global_ips: bool, true, def, true;
/// Disable Two-Factor remember |> Enabling this would force the users to use a second factor to login every time.
/// Note that the checkbox would still be present, but ignored.
@@ -900,11 +913,11 @@ fn validate_config(cfg: &ConfigItems) -> Result<(), Error> {
}
// Check if the icon blacklist regex is valid
- if let Some(ref r) = cfg.icon_blacklist_regex {
+ if let Some(ref r) = cfg.http_request_blacklist_regex {
let validate_regex = regex::Regex::new(r);
match validate_regex {
Ok(_) => (),
- Err(e) => err!(format!("`ICON_BLACKLIST_REGEX` is invalid: {e:#?}")),
+ Err(e) => err!(format!("`HTTP_REQUEST_BLACKLIST_REGEX` is invalid: {e:#?}")),
}
}
diff --git a/src/error.rs b/src/error.rs
index afb1dc83..51d60982 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -2,6 +2,7 @@
// Error generator macro
//
use crate::db::models::EventType;
+use crate::http_client::CustomHttpClientError;
use std::error::Error as StdError;
macro_rules! make_error {
@@ -68,6 +69,10 @@ make_error! {
Empty(Empty): _no_source, _serialize,
// Used to represent err! calls
Simple(String): _no_source, _api_error,
+
+ // Used in our custom http client to handle non-global IPs and blacklisted domains
+ CustomHttpClient(CustomHttpClientError): _has_source, _api_error,
+
// Used for special return values, like 2FA errors
Json(Value): _no_source, _serialize,
Db(DieselErr): _has_source, _api_error,
diff --git a/src/http_client.rs b/src/http_client.rs
new file mode 100644
index 00000000..46f28d77
--- /dev/null
+++ b/src/http_client.rs
@@ -0,0 +1,247 @@
+use std::{
+ fmt,
+ net::{IpAddr, SocketAddr},
+ str::FromStr,
+ sync::{Arc, Mutex},
+ time::Duration,
+};
+
+use hickory_resolver::{system_conf::read_system_conf, TokioAsyncResolver};
+use once_cell::sync::Lazy;
+use regex::Regex;
+use reqwest::{
+ dns::{Name, Resolve, Resolving},
+ header, Client, ClientBuilder,
+};
+use url::Host;
+
+use crate::{util::is_global, CONFIG};
+
+pub fn make_http_request(method: reqwest::Method, url: &str) -> Result<reqwest::RequestBuilder, crate::Error> {
+ let Ok(url) = url::Url::parse(url) else {
+ err!("Invalid URL");
+ };
+ let Some(host) = url.host() else {
+ err!("Invalid host");
+ };
+
+ should_block_host(host)?;
+
+ static INSTANCE: Lazy<Client> = Lazy::new(|| get_reqwest_client_builder().build().expect("Failed to build client"));
+
+ Ok(INSTANCE.request(method, url))
+}
+
+pub fn get_reqwest_client_builder() -> ClientBuilder {
+ let mut headers = header::HeaderMap::new();
+ headers.insert(header::USER_AGENT, header::HeaderValue::from_static("Vaultwarden"));
+
+ let redirect_policy = reqwest::redirect::Policy::custom(|attempt| {
+ if attempt.previous().len() >= 5 {
+ return attempt.error("Too many redirects");
+ }
+
+ let Some(host) = attempt.url().host() else {
+ return attempt.error("Invalid host");
+ };
+
+ if let Err(e) = should_block_host(host) {
+ return attempt.error(e);
+ }
+
+ attempt.follow()
+ });
+
+ Client::builder()
+ .default_headers(headers)
+ .redirect(redirect_policy)
+ .dns_resolver(CustomDnsResolver::instance())
+ .timeout(Duration::from_secs(10))
+}
+
+pub fn should_block_address(domain_or_ip: &str) -> bool {
+ if let Ok(ip) = IpAddr::from_str(domain_or_ip) {
+ if should_block_ip(ip) {
+ return true;
+ }
+ }
+
+ should_block_address_blacklist(domain_or_ip)
+}
+
+fn should_block_ip(ip: IpAddr) -> bool {
+ if !CONFIG.http_request_blacklist_non_global_ips() {
+ return false;
+ }
+
+ !is_global(ip)
+}
+
+fn should_block_address_blacklist(domain_or_ip: &str) -> bool {
+ let Some(config_blacklist) = CONFIG.http_request_blacklist_regex() else {
+ return false;
+ };
+
+ // Compiled domain blacklist
+ static COMPILED_BLACKLIST: Mutex<Option<(String, Regex)>> = Mutex::new(None);
+ let mut guard = COMPILED_BLACKLIST.lock().unwrap();
+
+ // If the stored regex is up to date, use it
+ if let Some((value, regex)) = &*guard {
+ if value == &config_blacklist {
+ return regex.is_match(domain_or_ip);
+ }
+ }
+
+ // If we don't have a regex stored, or it's not up to date, recreate it
+ let regex = Regex::new(&config_blacklist).unwrap();
+ let is_match = regex.is_match(domain_or_ip);
+ *guard = Some((config_blacklist, regex));
+
+ is_match
+}
+
+fn should_block_host(host: Host<&str>) -> Result<(), CustomHttpClientError> {
+ let (ip, host_str): (Option<IpAddr>, String) = match host {
+ url::Host::Ipv4(ip) => (Some(ip.into()), ip.to_string()),
+ url::Host::Ipv6(ip) => (Some(ip.into()), ip.to_string()),
+ url::Host::Domain(d) => (None, d.to_string()),
+ };
+
+ if let Some(ip) = ip {
+ if should_block_ip(ip) {
+ return Err(CustomHttpClientError::NonGlobalIp {
+ domain: None,
+ ip,
+ });
+ }
+ }
+
+ if should_block_address_blacklist(&host_str) {
+ return Err(CustomHttpClientError::Blacklist {
+ domain: host_str,
+ });
+ }
+
+ Ok(())
+}
+
+#[derive(Debug, Clone)]
+pub enum CustomHttpClientError {
+ Blacklist {
+ domain: String,
+ },
+ NonGlobalIp {
+ domain: Option<String>,
+ ip: IpAddr,
+ },
+}
+
+impl CustomHttpClientError {
+ pub fn downcast_ref(e: &dyn std::error::Error) -> Option<&Self> {
+ let mut source = e.source();
+
+ while let Some(err) = source {
+ source = err.source();
+ if let Some(err) = err.downcast_ref::<CustomHttpClientError>() {
+ return Some(err);
+ }
+ }
+ None
+ }
+}
+
+impl fmt::Display for CustomHttpClientError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ Self::Blacklist {
+ domain,
+ } => write!(f, "Blacklisted domain: {domain} matched HTTP_REQUEST_BLACKLIST_REGEX"),
+ Self::NonGlobalIp {
+ domain: Some(domain),
+ ip,
+ } => write!(f, "IP {ip} for domain '{domain}' is not a global IP!"),
+ Self::NonGlobalIp {
+ domain: None,
+ ip,
+ } => write!(f, "IP {ip} is not a global IP!"),
+ }
+ }
+}
+
+impl std::error::Error for CustomHttpClientError {}
+
+#[derive(Debug, Clone)]
+enum CustomDnsResolver {
+ Default(),
+ Hickory(Arc<TokioAsyncResolver>),
+}
+type BoxError = Box<dyn std::error::Error + Send + Sync>;
+
+impl CustomDnsResolver {
+ fn instance() -> Arc<Self> {
+ static INSTANCE: Lazy<Arc<CustomDnsResolver>> = Lazy::new(CustomDnsResolver::new);
+ Arc::clone(&*INSTANCE)
+ }
+
+ fn new() -> Arc<Self> {
+ match read_system_conf() {
+ Ok((config, opts)) => {
+ let resolver = TokioAsyncResolver::tokio(config.clone(), opts.clone());
+ Arc::new(Self::Hickory(Arc::new(resolver)))
+ }
+ Err(e) => {
+ warn!("Error creating Hickory resolver, falling back to default: {e:?}");
+ Arc::new(Self::Default())
+ }
+ }
+ }
+
+ // Note that we get an iterator of addresses, but we only grab the first one for convenience
+ async fn resolve_domain(&self, name: &str) -> Result<Option<SocketAddr>, BoxError> {
+ pre_resolve(name)?;
+
+ let result = match self {
+ Self::Default() => tokio::net::lookup_host(name).await?.next(),
+ Self::Hickory(r) => r.lookup_ip(name).await?.iter().next().map(|a| SocketAddr::new(a, 0)),
+ };
+
+ if let Some(addr) = &result {
+ post_resolve(name, addr.ip())?;
+ }
+
+ Ok(result)
+ }
+}
+
+fn pre_resolve(name: &str) -> Result<(), CustomHttpClientError> {
+ if should_block_address(name) {
+ return Err(CustomHttpClientError::Blacklist {
+ domain: name.to_string(),
+ });
+ }
+
+ Ok(())
+}
+
+fn post_resolve(name: &str, ip: IpAddr) -> Result<(), CustomHttpClientError> {
+ if should_block_ip(ip) {
+ Err(CustomHttpClientError::NonGlobalIp {
+ domain: Some(name.to_string()),
+ ip,
+ })
+ } else {
+ Ok(())
+ }
+}
+
+impl Resolve for CustomDnsResolver {
+ fn resolve(&self, name: Name) -> Resolving {
+ let this = self.clone();
+ Box::pin(async move {
+ let name = name.as_str();
+ let result = this.resolve_domain(name).await?;
+ Ok::<reqwest::dns::Addrs, _>(Box::new(result.into_iter()))
+ })
+ }
+}
diff --git a/src/main.rs b/src/main.rs
index 73085901..ecc4f320 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -47,6 +47,7 @@ mod config;
mod crypto;
#[macro_use]
mod db;
+mod http_client;
mod mail;
mod ratelimit;
mod util;
diff --git a/src/util.rs b/src/util.rs
index 29df7bbc..04fedbfb 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -4,7 +4,6 @@
use std::{collections::HashMap, io::Cursor, ops::Deref, path::Path};
use num_traits::ToPrimitive;
-use once_cell::sync::Lazy;
use rocket::{
fairing::{Fairing, Info, Kind},
http::{ContentType, Header, HeaderMap, Method, Status},
@@ -686,19 +685,6 @@ where
}
}
-use reqwest::{header, Client, ClientBuilder};
-
-pub fn get_reqwest_client() -> &'static Client {
- static INSTANCE: Lazy<Client> = Lazy::new(|| get_reqwest_client_builder().build().expect("Failed to build client"));
- &INSTANCE
-}
-
-pub fn get_reqwest_client_builder() -> ClientBuilder {
- let mut headers = header::HeaderMap::new();
- headers.insert(header::USER_AGENT, header::HeaderValue::from_static("Vaultwarden"));
- Client::builder().default_headers(headers).timeout(Duration::from_secs(10))
-}
-
pub fn convert_json_key_lcase_first(src_json: Value) -> Value {
match src_json {
Value::Array(elm) => {
@@ -750,138 +736,6 @@ pub fn parse_experimental_client_feature_flags(experimental_client_feature_flags
feature_states
}
-mod dns_resolver {
- use std::{
- fmt,
- net::{IpAddr, SocketAddr},
- sync::Arc,
- };
-
- use hickory_resolver::{system_conf::read_system_conf, TokioAsyncResolver};
- use once_cell::sync::Lazy;
- use reqwest::dns::{Name, Resolve, Resolving};
-
- use crate::{util::is_global, CONFIG};
-
- #[derive(Debug, Clone)]
- pub enum CustomResolverError {
- Blacklist {
- domain: String,
- },
- NonGlobalIp {
- domain: String,
- ip: IpAddr,
- },
- }
-
- impl CustomResolverError {
- pub fn downcast_ref(e: &dyn std::error::Error) -> Option<&Self> {
- let mut source = e.source();
-
- while let Some(err) = source {
- source = err.source();
- if let Some(err) = err.downcast_ref::<CustomResolverError>() {
- return Some(err);
- }
- }
- None
- }
- }
-
- impl fmt::Display for CustomResolverError {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- match self {
- Self::Blacklist {
- domain,
- } => write!(f, "Blacklisted domain: {domain} matched ICON_BLACKLIST_REGEX"),
- Self::NonGlobalIp {
- domain,
- ip,
- } => write!(f, "IP {ip} for domain '{domain}' is not a global IP!"),
- }
- }
- }
-
- impl std::error::Error for CustomResolverError {}
-
- #[derive(Debug, Clone)]
- pub enum CustomDnsResolver {
- Default(),
- Hickory(Arc<TokioAsyncResolver>),
- }
- type BoxError = Box<dyn std::error::Error + Send + Sync>;
-
- impl CustomDnsResolver {
- pub fn instance() -> Arc<Self> {
- static INSTANCE: Lazy<Arc<CustomDnsResolver>> = Lazy::new(CustomDnsResolver::new);
- Arc::clone(&*INSTANCE)
- }
-
- fn new() -> Arc<Self> {
- match read_system_conf() {
- Ok((config, opts)) => {
- let resolver = TokioAsyncResolver::tokio(config.clone(), opts.clone());
- Arc::new(Self::Hickory(Arc::new(resolver)))
- }
- Err(e) => {
- warn!("Error creating Hickory resolver, falling back to default: {e:?}");
- Arc::new(Self::Default())
- }
- }
- }
-
- // Note that we get an iterator of addresses, but we only grab the first one for convenience
- async fn resolve_domain(&self, name: &str) -> Result<Option<SocketAddr>, BoxError> {
- pre_resolve(name)?;
-
- let result = match self {
- Self::Default() => tokio::net::lookup_host(name).await?.next(),
- Self::Hickory(r) => r.lookup_ip(name).await?.iter().next().map(|a| SocketAddr::new(a, 0)),
- };
-
- if let Some(addr) = &result {
- post_resolve(name, addr.ip())?;
- }
-
- Ok(result)
- }
- }
-
- fn pre_resolve(name: &str) -> Result<(), CustomResolverError> {
- if crate::api::is_domain_blacklisted(name) {
- return Err(CustomResolverError::Blacklist {
- domain: name.to_string(),
- });
- }
-
- Ok(())
- }
-
- fn post_resolve(name: &str, ip: IpAddr) -> Result<(), CustomResolverError> {
- if CONFIG.icon_blacklist_non_global_ips() && !is_global(ip) {
- Err(CustomResolverError::NonGlobalIp {
- domain: name.to_string(),
- ip,
- })
- } else {
- Ok(())
- }
- }
-
- impl Resolve for CustomDnsResolver {
- fn resolve(&self, name: Name) -> Resolving {
- let this = self.clone();
- Box::pin(async move {
- let name = name.as_str();
- let result = this.resolve_domain(name).await?;
- Ok::<reqwest::dns::Addrs, _>(Box::new(result.into_iter()))
- })
- }
- }
-}
-
-pub use dns_resolver::{CustomDnsResolver, CustomResolverError};
-
/// TODO: This is extracted from IpAddr::is_global, which is unstable:
/// https://doc.rust-lang.org/nightly/std/net/enum.IpAddr.html#method.is_global
/// Remove once https://github.com/rust-lang/rust/issues/27709 is merged