diff options
-rw-r--r-- | src/api/identity.rs | 102 |
1 files changed, 56 insertions, 46 deletions
diff --git a/src/api/identity.rs b/src/api/identity.rs index ecdf37cb..02896129 100644 --- a/src/api/identity.rs +++ b/src/api/identity.rs @@ -1,4 +1,4 @@ -use rocket::request::LenientForm; +use rocket::request::{Form, FormItems, FromForm}; use rocket::Route; use rocket_contrib::json::Json; @@ -22,13 +22,27 @@ pub fn routes() -> Vec<Route> { } #[post("/connect/token", data = "<data>")] -fn login(data: LenientForm<ConnectData>, conn: DbConn, ip: ClientIp) -> JsonResult { +fn login(data: Form<ConnectData>, conn: DbConn, ip: ClientIp) -> JsonResult { let data: ConnectData = data.into_inner(); - validate_data(&data)?; - match data.grant_type { - GrantType::refresh_token => _refresh_login(data, conn), - GrantType::password => _password_login(data, conn, ip), + match data.grant_type.as_ref() { + "refresh_token" => { + _check_is_some(&data.refresh_token, "refresh_token cannot be blank")?; + _refresh_login(data, conn) + } + "password" => { + _check_is_some(&data.client_id, "client_id cannot be blank")?; + _check_is_some(&data.password, "password cannot be blank")?; + _check_is_some(&data.scope, "scope cannot be blank")?; + _check_is_some(&data.username, "username cannot be blank")?; + + _check_is_some(&data.device_identifier, "device_identifier cannot be blank")?; + _check_is_some(&data.device_name, "device_name cannot be blank")?; + _check_is_some(&data.device_type, "device_type cannot be blank")?; + + _password_login(data, conn, ip) + } + t => err!("Invalid type", t), } } @@ -86,9 +100,10 @@ fn _password_login(data: ConnectData, conn: DbConn, ip: ClientIp) -> JsonResult )) } + // On iOS, device_type sends "iOS", on others it sends a number let device_type = util::try_parse_string(data.device_type.as_ref()).unwrap_or(0); - let device_id = data.device_identifier.clone().unwrap_or_else(crate::util::get_uuid); - let device_name = data.device_name.clone().unwrap_or("unknown_device".into()); + let device_id = data.device_identifier.clone().expect("No device id provided"); + let device_name = data.device_name.clone().expect("No device name provided"); // Find device or create new let mut device = match Device::find_by_uuid(&device_id, &conn) { @@ -101,7 +116,7 @@ fn _password_login(data: ConnectData, conn: DbConn, ip: ClientIp) -> JsonResult device } } - None => Device::new(device_id, user.uuid.clone(), device_name, device_type) + None => Device::new(device_id, user.uuid.clone(), device_name, device_type), }; let twofactor_token = twofactor_auth(&user.uuid, &data.clone(), &mut device, &conn)?; @@ -133,12 +148,7 @@ fn _password_login(data: ConnectData, conn: DbConn, ip: ClientIp) -> JsonResult Ok(Json(result)) } -fn twofactor_auth( - user_uuid: &str, - data: &ConnectData, - device: &mut Device, - conn: &DbConn, -) -> ApiResult<Option<String>> { +fn twofactor_auth(user_uuid: &str, data: &ConnectData, device: &mut Device, conn: &DbConn) -> ApiResult<Option<String>> { let twofactors_raw = TwoFactor::find_by_user(user_uuid, conn); // Remove u2f challenge twofactors (impl detail) let twofactors: Vec<_> = twofactors_raw.iter().filter(|tf| tf.type_ < 1000).collect(); @@ -230,13 +240,9 @@ fn _json_err_twofactor(providers: &[i32], user_uuid: &str, conn: &DbConn) -> Api let mut challenge_map = JsonMap::new(); challenge_map.insert("appId".into(), Value::String(request.app_id.clone())); - challenge_map - .insert("challenge".into(), Value::String(request.challenge.clone())); + challenge_map.insert("challenge".into(), Value::String(request.challenge.clone())); challenge_map.insert("version".into(), Value::String(key.version)); - challenge_map.insert( - "keyHandle".into(), - Value::String(key.key_handle.unwrap_or_default()), - ); + challenge_map.insert("keyHandle".into(), Value::String(key.key_handle.unwrap_or_default())); challenge_list.push(Value::Object(challenge_map)); } @@ -269,10 +275,10 @@ fn _json_err_twofactor(providers: &[i32], user_uuid: &str, conn: &DbConn) -> Api Ok(result) } -#[derive(FromForm, Debug, Clone)] +#[derive(Debug, Clone, Default)] #[allow(non_snake_case)] struct ConnectData { - grant_type: GrantType, + grant_type: String, // refresh_token, password // Needed for grant_type="refresh_token" refresh_token: Option<String>, @@ -283,40 +289,44 @@ struct ConnectData { scope: Option<String>, username: Option<String>, - #[form(field = "deviceIdentifier")] device_identifier: Option<String>, - #[form(field = "deviceName")] device_name: Option<String>, - #[form(field = "deviceType")] device_type: Option<String>, // Needed for two-factor auth - #[form(field = "twoFactorProvider")] two_factor_provider: Option<i32>, - #[form(field = "twoFactorToken")] two_factor_token: Option<String>, - #[form(field = "twoFactorRemember")] two_factor_remember: Option<i32>, } -#[derive(FromFormValue, Debug, Clone, Copy)] -#[allow(non_camel_case_types)] -enum GrantType { - refresh_token, - password, -} - -fn validate_data(data: &ConnectData) -> EmptyResult { - match data.grant_type { - GrantType::refresh_token => { - _check_is_some(&data.refresh_token, "refresh_token cannot be blank") - } - GrantType::password => { - _check_is_some(&data.client_id, "client_id cannot be blank")?; - _check_is_some(&data.password, "password cannot be blank")?; - _check_is_some(&data.scope, "scope cannot be blank")?; - _check_is_some(&data.username, "username cannot be blank") +impl<'f> FromForm<'f> for ConnectData { + type Error = String; + + fn from_form(items: &mut FormItems<'f>, _strict: bool) -> Result<Self, Self::Error> { + let mut form = Self::default(); + for item in items { + let (key, value) = item.key_value_decoded(); + let mut normalized_key = key.to_lowercase(); + normalized_key.retain(|c| c != '_'); // Remove '_' + + match normalized_key.as_ref() { + "granttype" => form.grant_type = value, + "refreshtoken" => form.refresh_token = Some(value), + "clientid" => form.client_id = Some(value), + "password" => form.password = Some(value), + "scope" => form.scope = Some(value), + "username" => form.username = Some(value), + "deviceidentifier" => form.device_identifier = Some(value), + "devicename" => form.device_name = Some(value), + "devicetype" => form.device_type = Some(value), + "twofactorprovider" => form.two_factor_provider = value.parse().ok(), + "twofactortoken" => form.two_factor_token = Some(value), + "twofactorremember" => form.two_factor_remember = value.parse().ok(), + key => warn!("Detected unexpected parameter during login: {}", key), + } } + + Ok(form) } } |