aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/database.rs
diff options
context:
space:
mode:
authorrustdesk <[email protected]>2022-05-12 20:00:33 +0800
committerrustdesk <[email protected]>2022-05-12 20:00:33 +0800
commitb3f39598a7324dacec0cd84d5e09b95724805cc8 (patch)
tree9a80ecd503ec53089bfd03d65612d10e615b76af /src/database.rs
parentd36d6da4452f6250e4b9c9e821367e2dc6783f11 (diff)
downloadrustdesk-server-b3f39598a7324dacec0cd84d5e09b95724805cc8.tar.gz
rustdesk-server-b3f39598a7324dacec0cd84d5e09b95724805cc8.zip
change sled to sqlite and remove lic
Diffstat (limited to 'src/database.rs')
-rw-r--r--src/database.rs231
1 files changed, 231 insertions, 0 deletions
diff --git a/src/database.rs b/src/database.rs
new file mode 100644
index 0000000..5a199e6
--- /dev/null
+++ b/src/database.rs
@@ -0,0 +1,231 @@
+use async_trait::async_trait;
+use hbb_common::{log, ResultType};
+use serde_json::value::Value;
+use sqlx::{
+ sqlite::SqliteConnectOptions, ConnectOptions, Connection, Error as SqlxError, SqliteConnection,
+};
+use std::{ops::DerefMut, str::FromStr};
+//use sqlx::postgres::PgPoolOptions;
+//use sqlx::mysql::MySqlPoolOptions;
+
+pub(crate) type DB = sqlx::Sqlite;
+pub(crate) type MapValue = serde_json::map::Map<String, Value>;
+pub(crate) type MapStr = std::collections::HashMap<String, String>;
+type Pool = deadpool::managed::Pool<DbPool>;
+
/// deadpool connection manager: holds the SQLite database URL/path and
/// opens one `SqliteConnection` per pooled object (see the
/// `deadpool::managed::Manager` impl below).
pub struct DbPool {
    // connection string passed to `SqliteConnectOptions::from_str`
    url: String,
}
+
+#[async_trait]
+impl deadpool::managed::Manager for DbPool {
+ type Type = SqliteConnection;
+ type Error = SqlxError;
+ async fn create(&self) -> Result<SqliteConnection, SqlxError> {
+ let mut opt = SqliteConnectOptions::from_str(&self.url).unwrap();
+ opt.log_statements(log::LevelFilter::Debug);
+ SqliteConnection::connect_with(&opt).await
+ }
+ async fn recycle(
+ &self,
+ obj: &mut SqliteConnection,
+ ) -> deadpool::managed::RecycleResult<SqlxError> {
+ Ok(obj.ping().await?)
+ }
+}
+
/// Cheap, cloneable handle to the database: cloning shares the same
/// underlying connection pool.
#[derive(Clone)]
pub struct Database {
    pool: Pool,
}
+
/// One row of the `peer` table as selected by `get_peer` via
/// `sqlx::query_as!` — field names and types must match the selected
/// columns exactly.
#[derive(Default)]
pub struct Peer {
    pub guid: Vec<u8>,         // blob primary key (generated as a 16-byte uuid on insert)
    pub id: String,            // peer id; has a unique index
    pub uuid: Vec<u8>,
    pub pk: Vec<u8>,           // presumably the peer's public key — TODO confirm with callers
    pub user: Option<Vec<u8>>, // nullable `user` blob column
    pub info: String,          // `info` text column, not null
    pub status: Option<i64>,   // nullable `status` tinyint column
}
+
+impl Database {
+ pub async fn new(url: &str) -> ResultType<Database> {
+ if !std::path::Path::new(url).exists() {
+ std::fs::File::create(url).ok();
+ }
+ let n: usize = std::env::var("MAX_CONNECTIONS")
+ .unwrap_or("1".to_owned())
+ .parse()
+ .unwrap_or(1);
+ log::info!("MAX_CONNECTIONS={}", n);
+ let pool = Pool::new(
+ DbPool {
+ url: url.to_owned(),
+ },
+ n,
+ );
+ let _ = pool.get().await?; // test
+ let db = Database { pool };
+ db.create_tables().await?;
+ Ok(db)
+ }
+
+ async fn create_tables(&self) -> ResultType<()> {
+ sqlx::query!(
+ "
+ create table if not exists peer (
+ guid blob primary key not null,
+ id varchar(100) not null,
+ uuid blob not null,
+ pk blob not null,
+ created_at datetime not null default(current_timestamp),
+ user blob,
+ status tinyint,
+ note varchar(300),
+ info text not null
+ ) without rowid;
+ create unique index if not exists index_peer_id on peer (id);
+ create index if not exists index_peer_user on peer (user);
+ create index if not exists index_peer_created_at on peer (created_at);
+ create index if not exists index_peer_status on peer (status);
+ "
+ )
+ .execute(self.pool.get().await?.deref_mut())
+ .await?;
+ Ok(())
+ }
+
+ pub async fn get_peer(&self, id: &str) -> ResultType<Option<Peer>> {
+ Ok(sqlx::query_as!(
+ Peer,
+ "select guid, id, uuid, pk, user, status, info from peer where id = ?",
+ id
+ )
+ .fetch_optional(self.pool.get().await?.deref_mut())
+ .await?)
+ }
+
+ pub async fn get_peer_id(&self, guid: &[u8]) -> ResultType<Option<String>> {
+ Ok(sqlx::query!("select id from peer where guid = ?", guid)
+ .fetch_optional(self.pool.get().await?.deref_mut())
+ .await?
+ .map(|x| x.id))
+ }
+
+ #[inline]
+ pub async fn get_conn(&self) -> ResultType<deadpool::managed::Object<DbPool>> {
+ Ok(self.pool.get().await?)
+ }
+
+ pub async fn update_peer(&self, payload: MapValue, guid: &[u8]) -> ResultType<()> {
+ let mut conn = self.get_conn().await?;
+ let mut tx = conn.begin().await?;
+ if let Some(v) = payload.get("note") {
+ let v = get_str(v);
+ sqlx::query!("update peer set note = ? where guid = ?", v, guid)
+ .execute(&mut tx)
+ .await?;
+ }
+ tx.commit().await?;
+ Ok(())
+ }
+
+ pub async fn insert_peer(
+ &self,
+ id: &str,
+ uuid: &Vec<u8>,
+ pk: &Vec<u8>,
+ info: &str,
+ ) -> ResultType<Vec<u8>> {
+ let guid = uuid::Uuid::new_v4().as_bytes().to_vec();
+ sqlx::query!(
+ "insert into peer(guid, id, uuid, pk, info) values(?, ?, ?, ?, ?)",
+ guid,
+ id,
+ uuid,
+ pk,
+ info
+ )
+ .execute(self.pool.get().await?.deref_mut())
+ .await?;
+ Ok(guid)
+ }
+
+ pub async fn update_pk(
+ &self,
+ guid: &Vec<u8>,
+ id: &str,
+ pk: &Vec<u8>,
+ info: &str,
+ ) -> ResultType<()> {
+ sqlx::query!(
+ "update peer set id=?, pk=?, info=? where guid=?",
+ id,
+ pk,
+ info,
+ guid
+ )
+ .execute(self.pool.get().await?.deref_mut())
+ .await?;
+ Ok(())
+ }
+}
+
#[cfg(test)]
mod tests {
    use hbb_common::tokio;

    #[test]
    fn test_insert() {
        insert();
    }

    /// Hammers the pool with 10k concurrent inserts followed by 10k
    /// concurrent lookups against an on-disk scratch database.
    #[tokio::main(flavor = "multi_thread")]
    async fn insert() {
        let db = super::Database::new("test.sqlite3").await.unwrap();
        let mut tasks = Vec::new();
        for n in 0..10000 {
            let handle = db.clone();
            let peer_id = n.to_string();
            tasks.push(tokio::spawn(async move {
                handle
                    .insert_peer(&peer_id, &Vec::new(), &Vec::new(), "")
                    .await
                    .unwrap();
            }));
        }
        for n in 0..10000 {
            let handle = db.clone();
            let peer_id = n.to_string();
            tasks.push(tokio::spawn(async move {
                handle.get_peer(&peer_id).await.unwrap();
            }));
        }
        hbb_common::futures::future::join_all(tasks).await;
    }
}
+
+#[inline]
+pub fn guid2str(guid: &Vec<u8>) -> String {
+ let mut bytes = [0u8; 16];
+ bytes[..].copy_from_slice(&guid);
+ uuid::Uuid::from_bytes(bytes).to_string()
+}
+
+pub(crate) fn get_str(v: &Value) -> Option<&str> {
+ match v {
+ Value::String(v) => {
+ let v = v.trim();
+ if v.is_empty() {
+ None
+ } else {
+ Some(v)
+ }
+ }
+ _ => None,
+ }
+}