aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/api/identity.rs102
1 files changed, 56 insertions, 46 deletions
diff --git a/src/api/identity.rs b/src/api/identity.rs
index ecdf37cb..02896129 100644
--- a/src/api/identity.rs
+++ b/src/api/identity.rs
@@ -1,4 +1,4 @@
-use rocket::request::LenientForm;
+use rocket::request::{Form, FormItems, FromForm};
use rocket::Route;
use rocket_contrib::json::Json;
@@ -22,13 +22,27 @@ pub fn routes() -> Vec<Route> {
}
#[post("/connect/token", data = "<data>")]
-fn login(data: LenientForm<ConnectData>, conn: DbConn, ip: ClientIp) -> JsonResult {
+fn login(data: Form<ConnectData>, conn: DbConn, ip: ClientIp) -> JsonResult {
let data: ConnectData = data.into_inner();
- validate_data(&data)?;
- match data.grant_type {
- GrantType::refresh_token => _refresh_login(data, conn),
- GrantType::password => _password_login(data, conn, ip),
+ match data.grant_type.as_ref() {
+ "refresh_token" => {
+ _check_is_some(&data.refresh_token, "refresh_token cannot be blank")?;
+ _refresh_login(data, conn)
+ }
+ "password" => {
+ _check_is_some(&data.client_id, "client_id cannot be blank")?;
+ _check_is_some(&data.password, "password cannot be blank")?;
+ _check_is_some(&data.scope, "scope cannot be blank")?;
+ _check_is_some(&data.username, "username cannot be blank")?;
+
+ _check_is_some(&data.device_identifier, "device_identifier cannot be blank")?;
+ _check_is_some(&data.device_name, "device_name cannot be blank")?;
+ _check_is_some(&data.device_type, "device_type cannot be blank")?;
+
+ _password_login(data, conn, ip)
+ }
+ t => err!("Invalid type", t),
}
}
@@ -86,9 +100,10 @@ fn _password_login(data: ConnectData, conn: DbConn, ip: ClientIp) -> JsonResult
))
}
+ // On iOS, device_type sends "iOS", on others it sends a number
let device_type = util::try_parse_string(data.device_type.as_ref()).unwrap_or(0);
- let device_id = data.device_identifier.clone().unwrap_or_else(crate::util::get_uuid);
- let device_name = data.device_name.clone().unwrap_or("unknown_device".into());
+ let device_id = data.device_identifier.clone().expect("No device id provided");
+ let device_name = data.device_name.clone().expect("No device name provided");
// Find device or create new
let mut device = match Device::find_by_uuid(&device_id, &conn) {
@@ -101,7 +116,7 @@ fn _password_login(data: ConnectData, conn: DbConn, ip: ClientIp) -> JsonResult
device
}
}
- None => Device::new(device_id, user.uuid.clone(), device_name, device_type)
+ None => Device::new(device_id, user.uuid.clone(), device_name, device_type),
};
let twofactor_token = twofactor_auth(&user.uuid, &data.clone(), &mut device, &conn)?;
@@ -133,12 +148,7 @@ fn _password_login(data: ConnectData, conn: DbConn, ip: ClientIp) -> JsonResult
Ok(Json(result))
}
-fn twofactor_auth(
- user_uuid: &str,
- data: &ConnectData,
- device: &mut Device,
- conn: &DbConn,
-) -> ApiResult<Option<String>> {
+fn twofactor_auth(user_uuid: &str, data: &ConnectData, device: &mut Device, conn: &DbConn) -> ApiResult<Option<String>> {
let twofactors_raw = TwoFactor::find_by_user(user_uuid, conn);
// Remove u2f challenge twofactors (impl detail)
let twofactors: Vec<_> = twofactors_raw.iter().filter(|tf| tf.type_ < 1000).collect();
@@ -230,13 +240,9 @@ fn _json_err_twofactor(providers: &[i32], user_uuid: &str, conn: &DbConn) -> Api
let mut challenge_map = JsonMap::new();
challenge_map.insert("appId".into(), Value::String(request.app_id.clone()));
- challenge_map
- .insert("challenge".into(), Value::String(request.challenge.clone()));
+ challenge_map.insert("challenge".into(), Value::String(request.challenge.clone()));
challenge_map.insert("version".into(), Value::String(key.version));
- challenge_map.insert(
- "keyHandle".into(),
- Value::String(key.key_handle.unwrap_or_default()),
- );
+ challenge_map.insert("keyHandle".into(), Value::String(key.key_handle.unwrap_or_default()));
challenge_list.push(Value::Object(challenge_map));
}
@@ -269,10 +275,10 @@ fn _json_err_twofactor(providers: &[i32], user_uuid: &str, conn: &DbConn) -> Api
Ok(result)
}
-#[derive(FromForm, Debug, Clone)]
+#[derive(Debug, Clone, Default)]
#[allow(non_snake_case)]
struct ConnectData {
- grant_type: GrantType,
+ grant_type: String, // refresh_token, password
// Needed for grant_type="refresh_token"
refresh_token: Option<String>,
@@ -283,40 +289,44 @@ struct ConnectData {
scope: Option<String>,
username: Option<String>,
- #[form(field = "deviceIdentifier")]
device_identifier: Option<String>,
- #[form(field = "deviceName")]
device_name: Option<String>,
- #[form(field = "deviceType")]
device_type: Option<String>,
// Needed for two-factor auth
- #[form(field = "twoFactorProvider")]
two_factor_provider: Option<i32>,
- #[form(field = "twoFactorToken")]
two_factor_token: Option<String>,
- #[form(field = "twoFactorRemember")]
two_factor_remember: Option<i32>,
}
-#[derive(FromFormValue, Debug, Clone, Copy)]
-#[allow(non_camel_case_types)]
-enum GrantType {
- refresh_token,
- password,
-}
-
-fn validate_data(data: &ConnectData) -> EmptyResult {
- match data.grant_type {
- GrantType::refresh_token => {
- _check_is_some(&data.refresh_token, "refresh_token cannot be blank")
- }
- GrantType::password => {
- _check_is_some(&data.client_id, "client_id cannot be blank")?;
- _check_is_some(&data.password, "password cannot be blank")?;
- _check_is_some(&data.scope, "scope cannot be blank")?;
- _check_is_some(&data.username, "username cannot be blank")
+impl<'f> FromForm<'f> for ConnectData {
+ type Error = String;
+
+ fn from_form(items: &mut FormItems<'f>, _strict: bool) -> Result<Self, Self::Error> {
+ let mut form = Self::default();
+ for item in items {
+ let (key, value) = item.key_value_decoded();
+ let mut normalized_key = key.to_lowercase();
+ normalized_key.retain(|c| c != '_'); // Remove '_'
+
+ match normalized_key.as_ref() {
+ "granttype" => form.grant_type = value,
+ "refreshtoken" => form.refresh_token = Some(value),
+ "clientid" => form.client_id = Some(value),
+ "password" => form.password = Some(value),
+ "scope" => form.scope = Some(value),
+ "username" => form.username = Some(value),
+ "deviceidentifier" => form.device_identifier = Some(value),
+ "devicename" => form.device_name = Some(value),
+ "devicetype" => form.device_type = Some(value),
+ "twofactorprovider" => form.two_factor_provider = value.parse().ok(),
+ "twofactortoken" => form.two_factor_token = Some(value),
+ "twofactorremember" => form.two_factor_remember = value.parse().ok(),
+ key => warn!("Detected unexpected parameter during login: {}", key),
+ }
}
+
+ Ok(form)
}
}