diff options
-rw-r--r-- | src/api/admin.rs | 32 | ||||
-rw-r--r-- | src/db/mod.rs | 22 | ||||
-rw-r--r-- | src/main.rs | 65 | ||||
-rw-r--r-- | src/util.rs | 22 |
4 files changed, 108 insertions, 33 deletions
diff --git a/src/api/admin.rs b/src/api/admin.rs index e6be3783..03d86920 100644 --- a/src/api/admin.rs +++ b/src/api/admin.rs @@ -25,7 +25,8 @@ use crate::{ http_client::make_http_request, mail, util::{ - container_base_image, format_naive_datetime_local, get_display_size, is_running_in_container, NumberOrString, + container_base_image, format_naive_datetime_local, get_display_size, get_web_vault_version, + is_running_in_container, NumberOrString, }, CONFIG, VERSION, }; @@ -576,11 +577,6 @@ async fn delete_organization(uuid: &str, _token: AdminToken, mut conn: DbConn) - } #[derive(Deserialize)] -struct WebVaultVersion { - version: String, -} - -#[derive(Deserialize)] struct GitRelease { tag_name: String, } @@ -679,18 +675,6 @@ async fn diagnostics(_token: AdminToken, ip_header: IpHeader, mut conn: DbConn) use chrono::prelude::*; use std::net::ToSocketAddrs; - // Get current running versions - let web_vault_version: WebVaultVersion = - match std::fs::read_to_string(format!("{}/{}", CONFIG.web_vault_folder(), "vw-version.json")) { - Ok(s) => serde_json::from_str(&s)?, - _ => match std::fs::read_to_string(format!("{}/{}", CONFIG.web_vault_folder(), "version.json")) { - Ok(s) => serde_json::from_str(&s)?, - _ => WebVaultVersion { - version: String::from("Version file missing"), - }, - }, - }; - // Execute some environment checks let running_within_container = is_running_in_container(); let has_http_access = has_http_access().await; @@ -710,13 +694,16 @@ async fn diagnostics(_token: AdminToken, ip_header: IpHeader, mut conn: DbConn) let ip_header_name = &ip_header.0.unwrap_or_default(); + // Get current running versions + let web_vault_version = get_web_vault_version(); + let diagnostics_json = json!({ "dns_resolved": dns_resolved, "current_release": VERSION, "latest_release": latest_release, "latest_commit": latest_commit, "web_vault_enabled": &CONFIG.web_vault_enabled(), - "web_vault_version": web_vault_version.version.trim_start_matches('v'), + "web_vault_version": web_vault_version, "latest_web_build": latest_web_build, "running_within_container": running_within_container, "container_base_image": if running_within_container { container_base_image() } else { "Not applicable" }, @@ -765,9 +752,12 @@ fn delete_config(_token: AdminToken) -> EmptyResult { } #[post("/config/backup_db")] -async fn backup_db(_token: AdminToken, mut conn: DbConn) -> EmptyResult { +async fn backup_db(_token: AdminToken, mut conn: DbConn) -> ApiResult<String> { if *CAN_BACKUP { - backup_database(&mut conn).await + match backup_database(&mut conn).await { + Ok(f) => Ok(format!("Backup to '{f}' was successful")), + Err(e) => err!(format!("Backup was unsuccessful {e}")), + } } else { err!("Can't back up current DB (Only SQLite supports this feature)"); } diff --git a/src/db/mod.rs b/src/db/mod.rs index 824b3c71..51ffba9c 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -368,23 +368,31 @@ pub mod models; /// Creates a back-up of the sqlite database /// MySQL/MariaDB and PostgreSQL are not supported. -pub async fn backup_database(conn: &mut DbConn) -> Result<(), Error> { +pub async fn backup_database(conn: &mut DbConn) -> Result<String, Error> { db_run! {@raw conn: postgresql, mysql { let _ = conn; err!("PostgreSQL and MySQL/MariaDB do not support this backup feature"); } sqlite { - use std::path::Path; - let db_url = CONFIG.database_url(); - let db_path = Path::new(&db_url).parent().unwrap().to_string_lossy(); - let file_date = chrono::Utc::now().format("%Y%m%d_%H%M%S").to_string(); - diesel::sql_query(format!("VACUUM INTO '{db_path}/db_{file_date}.sqlite3'")).execute(conn)?; - Ok(()) + backup_sqlite_database(conn) } } } +#[cfg(sqlite)] +pub fn backup_sqlite_database(conn: &mut diesel::sqlite::SqliteConnection) -> Result<String, Error> { + use diesel::RunQueryDsl; + let db_url = CONFIG.database_url(); + let db_path = std::path::Path::new(&db_url).parent().unwrap(); + let backup_file = db_path + .join(format!("db_{}.sqlite3", chrono::Utc::now().format("%Y%m%d_%H%M%S"))) + .to_string_lossy() + .into_owned(); + diesel::sql_query(format!("VACUUM INTO '{backup_file}'")).execute(conn)?; + Ok(backup_file) +} + /// Get the SQL Server version pub async fn get_sql_server_version(conn: &mut DbConn) -> String { db_run! {@raw conn: diff --git a/src/main.rs b/src/main.rs index 9f96dc60..6e725483 100644 --- a/src/main.rs +++ b/src/main.rs @@ -38,6 +38,7 @@ use std::{ use tokio::{ fs::File, io::{AsyncBufReadExt, BufReader}, + signal::unix::SignalKind, }; #[macro_use] @@ -97,10 +98,12 @@ USAGE: FLAGS: -h, --help Prints help information - -v, --version Prints the app version + -v, --version Prints the app and web-vault version COMMAND: hash [--preset {bitwarden|owasp}] Generate an Argon2id PHC ADMIN_TOKEN + backup Create a backup of the SQLite database + You can also send the USR1 signal to trigger a backup PRESETS: m= t= p= bitwarden (default) 64MiB, 3 Iterations, 4 Threads @@ -115,11 +118,13 @@ fn parse_args() { let version = VERSION.unwrap_or("(Version info from Git not present)"); if pargs.contains(["-h", "--help"]) { - println!("vaultwarden {version}"); + println!("Vaultwarden {version}"); print!("{HELP}"); exit(0); } else if pargs.contains(["-v", "--version"]) { - println!("vaultwarden {version}"); + let web_vault_version = util::get_web_vault_version(); + println!("Vaultwarden {version}"); + println!("Web-Vault {web_vault_version}"); exit(0); } @@ -174,13 +179,47 @@ fn parse_args() { argon2_timer.elapsed() ); } else { - error!("Unable to generate Argon2id PHC hash."); + println!("Unable to generate Argon2id PHC hash."); exit(1); } + } else if command == "backup" { + match backup_sqlite() { + Ok(f) => { + println!("Backup to '{f}' was successful"); + exit(0); + } + Err(e) => { + println!("Backup failed. {e:?}"); + exit(1); + } + } } exit(0); } } + +fn backup_sqlite() -> Result<String, Error> { + #[cfg(sqlite)] + { + use crate::db::{backup_sqlite_database, DbConnType}; + if DbConnType::from_url(&CONFIG.database_url()).map(|t| t == DbConnType::sqlite).unwrap_or(false) { + use diesel::Connection; + let url = crate::CONFIG.database_url(); + + // Establish a connection to the sqlite database + let mut conn = diesel::sqlite::SqliteConnection::establish(&url)?; + let backup_file = backup_sqlite_database(&mut conn)?; + Ok(backup_file) + } else { + err_silent!("The database type is not SQLite. Backups only works for SQLite databases") + } + } + #[cfg(not(sqlite))] + { + err_silent!("The 'sqlite' feature is not enabled. Backups only works for SQLite databases") + } +} + fn launch_info() { println!( "\ @@ -346,7 +385,7 @@ fn init_logging() -> Result<log::LevelFilter, Error> { } #[cfg(not(windows))] { - const SIGHUP: i32 = tokio::signal::unix::SignalKind::hangup().as_raw_value(); + const SIGHUP: i32 = SignalKind::hangup().as_raw_value(); let path = Path::new(&log_file); logger = logger.chain(fern::log_reopen1(path, [SIGHUP])?); } @@ -560,6 +599,22 @@ async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error> CONFIG.shutdown(); }); + #[cfg(unix)] + { + tokio::spawn(async move { + let mut signal_user1 = tokio::signal::unix::signal(SignalKind::user_defined1()).unwrap(); + loop { + // If we need more signals to act upon, we might want to use select! here. + // With only one item to listen for this is enough. + let _ = signal_user1.recv().await; + match backup_sqlite() { + Ok(f) => info!("Backup to '{f}' was successful"), + Err(e) => error!("Backup failed. {e:?}"), + } + } + }); + } + let _ = instance.launch().await?; info!("Vaultwarden process exited!"); diff --git a/src/util.rs b/src/util.rs index 9d58a53f..c586798c 100644 --- a/src/util.rs +++ b/src/util.rs @@ -513,6 +513,28 @@ pub fn container_base_image() -> &'static str { } } +#[derive(Deserialize)] +struct WebVaultVersion { + version: String, +} + +pub fn get_web_vault_version() -> String { + let version_files = [ + format!("{}/vw-version.json", CONFIG.web_vault_folder()), + format!("{}/version.json", CONFIG.web_vault_folder()), + ]; + + for version_file in version_files { + if let Ok(version_str) = std::fs::read_to_string(&version_file) { + if let Ok(version) = serde_json::from_str::<WebVaultVersion>(&version_str) { + return String::from(version.version.trim_start_matches('v')); + } + } + } + + String::from("Version file missing") +} + // // Deserialization methods // |