diff options
author | rustdesk <[email protected]> | 2022-05-12 20:00:33 +0800 |
---|---|---|
committer | rustdesk <[email protected]> | 2022-05-12 20:00:33 +0800 |
commit | b3f39598a7324dacec0cd84d5e09b95724805cc8 (patch) | |
tree | 9a80ecd503ec53089bfd03d65612d10e615b76af /src/database.rs | |
parent | d36d6da4452f6250e4b9c9e821367e2dc6783f11 (diff) | |
download | rustdesk-server-b3f39598a7324dacec0cd84d5e09b95724805cc8.tar.gz rustdesk-server-b3f39598a7324dacec0cd84d5e09b95724805cc8.zip |
change sled to sqlite and remove lic
Diffstat (limited to 'src/database.rs')
-rw-r--r-- | src/database.rs | 231 |
1 files changed, 231 insertions, 0 deletions
diff --git a/src/database.rs b/src/database.rs new file mode 100644 index 0000000..5a199e6 --- /dev/null +++ b/src/database.rs @@ -0,0 +1,231 @@ +use async_trait::async_trait; +use hbb_common::{log, ResultType}; +use serde_json::value::Value; +use sqlx::{ + sqlite::SqliteConnectOptions, ConnectOptions, Connection, Error as SqlxError, SqliteConnection, +}; +use std::{ops::DerefMut, str::FromStr}; +//use sqlx::postgres::PgPoolOptions; +//use sqlx::mysql::MySqlPoolOptions; + +pub(crate) type DB = sqlx::Sqlite; +pub(crate) type MapValue = serde_json::map::Map<String, Value>; +pub(crate) type MapStr = std::collections::HashMap<String, String>; +type Pool = deadpool::managed::Pool<DbPool>; + +pub struct DbPool { + url: String, +} + +#[async_trait] +impl deadpool::managed::Manager for DbPool { + type Type = SqliteConnection; + type Error = SqlxError; + async fn create(&self) -> Result<SqliteConnection, SqlxError> { + let mut opt = SqliteConnectOptions::from_str(&self.url).unwrap(); + opt.log_statements(log::LevelFilter::Debug); + SqliteConnection::connect_with(&opt).await + } + async fn recycle( + &self, + obj: &mut SqliteConnection, + ) -> deadpool::managed::RecycleResult<SqlxError> { + Ok(obj.ping().await?) + } +} + +#[derive(Clone)] +pub struct Database { + pool: Pool, +} + +#[derive(Default)] +pub struct Peer { + pub guid: Vec<u8>, + pub id: String, + pub uuid: Vec<u8>, + pub pk: Vec<u8>, + pub user: Option<Vec<u8>>, + pub info: String, + pub status: Option<i64>, +} + +impl Database { + pub async fn new(url: &str) -> ResultType<Database> { + if !std::path::Path::new(url).exists() { + std::fs::File::create(url).ok(); + } + let n: usize = std::env::var("MAX_CONNECTIONS") + .unwrap_or("1".to_owned()) + .parse() + .unwrap_or(1); + log::info!("MAX_CONNECTIONS={}", n); + let pool = Pool::new( + DbPool { + url: url.to_owned(), + }, + n, + ); + let _ = pool.get().await?; // test + let db = Database { pool }; + db.create_tables().await?; + Ok(db) + } + + async fn create_tables(&self) -> ResultType<()> { + sqlx::query!( + " + create table if not exists peer ( + guid blob primary key not null, + id varchar(100) not null, + uuid blob not null, + pk blob not null, + created_at datetime not null default(current_timestamp), + user blob, + status tinyint, + note varchar(300), + info text not null + ) without rowid; + create unique index if not exists index_peer_id on peer (id); + create index if not exists index_peer_user on peer (user); + create index if not exists index_peer_created_at on peer (created_at); + create index if not exists index_peer_status on peer (status); + " + ) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(()) + } + + pub async fn get_peer(&self, id: &str) -> ResultType<Option<Peer>> { + Ok(sqlx::query_as!( + Peer, + "select guid, id, uuid, pk, user, status, info from peer where id = ?", + id + ) + .fetch_optional(self.pool.get().await?.deref_mut()) + .await?) + } + + pub async fn get_peer_id(&self, guid: &[u8]) -> ResultType<Option<String>> { + Ok(sqlx::query!("select id from peer where guid = ?", guid) + .fetch_optional(self.pool.get().await?.deref_mut()) + .await? + .map(|x| x.id)) + } + + #[inline] + pub async fn get_conn(&self) -> ResultType<deadpool::managed::Object<DbPool>> { + Ok(self.pool.get().await?) + } + + pub async fn update_peer(&self, payload: MapValue, guid: &[u8]) -> ResultType<()> { + let mut conn = self.get_conn().await?; + let mut tx = conn.begin().await?; + if let Some(v) = payload.get("note") { + let v = get_str(v); + sqlx::query!("update peer set note = ? where guid = ?", v, guid) + .execute(&mut tx) + .await?; + } + tx.commit().await?; + Ok(()) + } + + pub async fn insert_peer( + &self, + id: &str, + uuid: &Vec<u8>, + pk: &Vec<u8>, + info: &str, + ) -> ResultType<Vec<u8>> { + let guid = uuid::Uuid::new_v4().as_bytes().to_vec(); + sqlx::query!( + "insert into peer(guid, id, uuid, pk, info) values(?, ?, ?, ?, ?)", + guid, + id, + uuid, + pk, + info + ) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(guid) + } + + pub async fn update_pk( + &self, + guid: &Vec<u8>, + id: &str, + pk: &Vec<u8>, + info: &str, + ) -> ResultType<()> { + sqlx::query!( + "update peer set id=?, pk=?, info=? where guid=?", + id, + pk, + info, + guid + ) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use hbb_common::tokio; + #[test] + fn test_insert() { + insert(); + } + + #[tokio::main(flavor = "multi_thread")] + async fn insert() { + let db = super::Database::new("test.sqlite3").await.unwrap(); + let mut jobs = vec![]; + for i in 0..10000 { + let cloned = db.clone(); + let id = i.to_string(); + let a = tokio::spawn(async move { + let empty_vec = Vec::new(); + cloned + .insert_peer(&id, &empty_vec, &empty_vec, "") + .await + .unwrap(); + }); + jobs.push(a); + } + for i in 0..10000 { + let cloned = db.clone(); + let id = i.to_string(); + let a = tokio::spawn(async move { + cloned.get_peer(&id).await.unwrap(); + }); + jobs.push(a); + } + hbb_common::futures::future::join_all(jobs).await; + } +} + +#[inline] +pub fn guid2str(guid: &Vec<u8>) -> String { + let mut bytes = [0u8; 16]; + bytes[..].copy_from_slice(&guid); + uuid::Uuid::from_bytes(bytes).to_string() +} + +pub(crate) fn get_str(v: &Value) -> Option<&str> { + match v { + Value::String(v) => { + let v = v.trim(); + if v.is_empty() { + None + } else { + Some(v) + } + } + _ => None, + } +} |