aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel García <[email protected]>2022-05-11 21:36:11 +0200
committerDaniel García <[email protected]>2022-05-11 21:36:11 +0200
commitb9c434addbbfcdf4ce430182181d7962053eb951 (patch)
treeb9750d07b24146209acefc33b8f0d1f5c35c81ed
parent7f61dd5fe3d27fc3dc5c580d842972c273678362 (diff)
parent451ad47327339cb51b0be83d735229496fa56c4d (diff)
downloadvaultwarden-b9c434addbbfcdf4ce430182181d7962053eb951.tar.gz
vaultwarden-b9c434addbbfcdf4ce430182181d7962053eb951.zip
Merge branch 'jjlin-db-conn-init' into main
-rw-r--r--.env.template9
-rw-r--r--src/config.rs5
-rw-r--r--src/db/mod.rs46
3 files changed, 57 insertions, 3 deletions
diff --git a/.env.template b/.env.template
index 3c8a5ebb..a835200a 100644
--- a/.env.template
+++ b/.env.template
@@ -29,6 +29,15 @@
## Define the size of the connection pool used for connecting to the database.
# DATABASE_MAX_CONNS=10
+## Database connection initialization
+## Allows SQL statements to be run whenever a new database connection is created.
+## This is mainly useful for connection-scoped pragmas.
+## If empty, a database-specific default is used:
+## - SQLite: "PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;"
+## - MySQL: ""
+## - PostgreSQL: ""
+# DATABASE_CONN_INIT=""
+
## Individual folders, these override %DATA_FOLDER%
# RSA_KEY_FILENAME=data/rsa_key
# ICON_CACHE_FOLDER=data/icon_cache
diff --git a/src/config.rs b/src/config.rs
index 2cef76e2..cd90caa1 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -515,11 +515,14 @@ make_config! {
db_connection_retries: u32, false, def, 15;
/// Timeout when aquiring database connection
- database_timeout: u64, false, def, 30;
+ database_timeout: u64, false, def, 30;
/// Database connection pool size
database_max_conns: u32, false, def, 10;
+ /// Database connection init |> SQL statements to run when creating a new database connection, mainly useful for connection-scoped pragmas. If empty, a database-specific default is used.
+ database_conn_init: String, false, def, "".to_string();
+
/// Bypass admin page security (Know the risks!) |> Disables the Admin Token for the admin page so you may use your own auth in-front
disable_admin_token: bool, true, def, false;
diff --git a/src/db/mod.rs b/src/db/mod.rs
index 3223fb65..0b3b7a5b 100644
--- a/src/db/mod.rs
+++ b/src/db/mod.rs
@@ -1,6 +1,10 @@
use std::{sync::Arc, time::Duration};
-use diesel::r2d2::{ConnectionManager, Pool, PooledConnection};
+use diesel::{
+ connection::SimpleConnection,
+ r2d2::{ConnectionManager, CustomizeConnection, Pool, PooledConnection},
+};
+
use rocket::{
http::Status,
outcome::IntoOutcome,
@@ -62,6 +66,23 @@ macro_rules! generate_connections {
#[allow(non_camel_case_types)]
pub enum DbConnInner { $( #[cfg($name)] $name(PooledConnection<ConnectionManager< $ty >>), )+ }
+ #[derive(Debug)]
+ pub struct DbConnOptions {
+ pub init_stmts: String,
+ }
+
+ $( // Based on <https://stackoverflow.com/a/57717533>.
+ #[cfg($name)]
+ impl CustomizeConnection<$ty, diesel::r2d2::Error> for DbConnOptions {
+ fn on_acquire(&self, conn: &mut $ty) -> Result<(), diesel::r2d2::Error> {
+ (|| {
+ if !self.init_stmts.is_empty() {
+ conn.batch_execute(&self.init_stmts)?;
+ }
+ Ok(())
+ })().map_err(diesel::r2d2::Error::QueryError)
+ }
+ })+
#[derive(Clone)]
pub struct DbPool {
@@ -103,7 +124,8 @@ macro_rules! generate_connections {
}
impl DbPool {
- // For the given database URL, guess it's type, run migrations create pool and return it
+ // For the given database URL, guess its type, run migrations, create pool, and return it
+ #[allow(clippy::diverging_sub_expression)]
pub fn from_config() -> Result<Self, Error> {
let url = CONFIG.database_url();
let conn_type = DbConnType::from_url(&url)?;
@@ -117,6 +139,9 @@ macro_rules! generate_connections {
let pool = Pool::builder()
.max_size(CONFIG.database_max_conns())
.connection_timeout(Duration::from_secs(CONFIG.database_timeout()))
+ .connection_customizer(Box::new(DbConnOptions{
+ init_stmts: conn_type.get_init_stmts()
+ }))
.build(manager)
.map_res("Failed to create pool")?;
return Ok(DbPool {
@@ -190,6 +215,23 @@ impl DbConnType {
err!("`DATABASE_URL` looks like a SQLite URL, but 'sqlite' feature is not enabled")
}
}
+
+ pub fn get_init_stmts(&self) -> String {
+ let init_stmts = CONFIG.database_conn_init();
+ if !init_stmts.is_empty() {
+ init_stmts
+ } else {
+ self.default_init_stmts()
+ }
+ }
+
+ pub fn default_init_stmts(&self) -> String {
+ match self {
+ Self::sqlite => "PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;".to_string(),
+ Self::mysql => "".to_string(),
+ Self::postgresql => "".to_string(),
+ }
+ }
}
#[macro_export]