From 7f9b897c69ac9b5c42f6095d08bbe29401b6aac5 Mon Sep 17 00:00:00 2001 From: privacyguard Date: Sat, 6 Jul 2024 21:51:59 +0300 Subject: [PATCH] Fixes and improvements based on review feedback --- crates/api_common/src/oauth_provider.rs | 6 +-- crates/api_common/src/site.rs | 2 +- crates/api_crud/src/oauth_provider/update.rs | 7 ++-- crates/api_crud/src/site/read.rs | 39 +++++++++----------- crates/api_crud/src/user/create.rs | 15 +++----- crates/db_schema/src/impls/oauth_account.rs | 17 --------- crates/db_schema/src/impls/oauth_provider.rs | 24 ++---------- crates/utils/src/error.rs | 1 + src/api_routes_http.rs | 4 +- 9 files changed, 38 insertions(+), 77 deletions(-) diff --git a/crates/api_common/src/oauth_provider.rs b/crates/api_common/src/oauth_provider.rs index 052e9fa61..a338ee8fa 100644 --- a/crates/api_common/src/oauth_provider.rs +++ b/crates/api_common/src/oauth_provider.rs @@ -57,16 +57,16 @@ pub struct DeleteOAuthProvider { } #[skip_serializing_none] -#[derive(Debug, Serialize, Deserialize, Clone, Default)] +#[derive(Debug, Serialize, Deserialize, Clone)] #[cfg_attr(feature = "full", derive(TS))] #[cfg_attr(feature = "full", ts(export))] /// Logging in with an OAuth 2.0 authorization -pub struct OAuth { +pub struct AuthenticateWithOauth { pub code: String, #[cfg_attr(feature = "full", ts(type = "string"))] pub oauth_provider_id: OAuthProviderId, #[cfg_attr(feature = "full", ts(type = "string"))] - pub redirect_uri: Option, + pub redirect_uri: Url, } #[skip_serializing_none] diff --git a/crates/api_common/src/site.rs b/crates/api_common/src/site.rs index 3e41d387b..0253d47bc 100644 --- a/crates/api_common/src/site.rs +++ b/crates/api_common/src/site.rs @@ -313,7 +313,7 @@ pub struct GetSiteResponse { /// A list of custom emojis your site supports. pub custom_emojis: Vec, /// A list of external auth methods your site supports. - pub oauth_providers: Vec>, + pub oauth_providers: Vec, pub blocked_urls: Vec, } diff --git a/crates/api_crud/src/oauth_provider/update.rs b/crates/api_crud/src/oauth_provider/update.rs index f19920b57..5a21c5f39 100644 --- a/crates/api_crud/src/oauth_provider/update.rs +++ b/crates/api_crud/src/oauth_provider/update.rs @@ -7,7 +7,7 @@ use lemmy_db_schema::{ utils::naive_now, }; use lemmy_db_views::structs::LocalUserView; -use lemmy_utils::error::LemmyError; +use lemmy_utils::{error::LemmyError, LemmyErrorType}; use url::Url; #[tracing::instrument(skip(context))] @@ -41,7 +41,8 @@ pub async fn update_oauth_provider( let update_result = UnsafeOAuthProvider::update(&mut context.pool(), data.id, &oauth_provider_form.build()).await?; - let unsafe_oauth_provider = - UnsafeOAuthProvider::get(&mut context.pool(), update_result.id).await?; + let unsafe_oauth_provider = UnsafeOAuthProvider::read(&mut context.pool(), update_result.id) + .await? + .ok_or(LemmyErrorType::CouldntFindOauthProvider)?; Ok(Json(OAuthProvider::from_unsafe(&unsafe_oauth_provider))) } diff --git a/crates/api_crud/src/site/read.rs b/crates/api_crud/src/site/read.rs index de9623d5c..49d48d320 100644 --- a/crates/api_crud/src/site/read.rs +++ b/crates/api_crud/src/site/read.rs @@ -115,26 +115,23 @@ pub async fn get_site( Ok(Json(site_response)) } -fn filter_oauth_providers(oauth_providers: &mut [Option]) { - for oauth_provider_opt in oauth_providers { - if let Some(oauth_provider) = oauth_provider_opt { - if oauth_provider.enabled.is_some() - && oauth_provider.enabled.expect("unexpected enabled value") - { - oauth_provider.issuer = None; - oauth_provider.token_endpoint = None; - oauth_provider.userinfo_endpoint = None; - oauth_provider.id_claim = None; - oauth_provider.name_claim = None; - oauth_provider.auto_verify_email = None; - oauth_provider.auto_approve_application = None; - oauth_provider.account_linking_enabled = None; - oauth_provider.enabled = None; - oauth_provider.published = None; - oauth_provider.updated = None; - } else { - *oauth_provider_opt = None; - } +fn filter_oauth_providers(oauth_providers: &mut Vec) { + oauth_providers.retain_mut(|oauth_provider| { + if oauth_provider.enabled.unwrap_or(false) { + oauth_provider.issuer = None; + oauth_provider.token_endpoint = None; + oauth_provider.userinfo_endpoint = None; + oauth_provider.id_claim = None; + oauth_provider.name_claim = None; + oauth_provider.auto_verify_email = None; + oauth_provider.auto_approve_application = None; + oauth_provider.account_linking_enabled = None; + oauth_provider.enabled = None; + oauth_provider.published = None; + oauth_provider.updated = None; + true + } else { + false } - } + }) } diff --git a/crates/api_crud/src/user/create.rs b/crates/api_crud/src/user/create.rs index c65e9686f..074bf6ae8 100644 --- a/crates/api_crud/src/user/create.rs +++ b/crates/api_crud/src/user/create.rs @@ -3,7 +3,7 @@ use actix_web::{web::Json, HttpRequest}; use lemmy_api_common::{ claims::Claims, context::LemmyContext, - oauth_provider::{OAuth, TokenResponse}, + oauth_provider::{AuthenticateWithOauth, TokenResponse}, person::{LoginResponse, Register}, utils::{ check_email_verified, @@ -229,8 +229,8 @@ pub async fn register( } #[tracing::instrument(skip(context))] -pub async fn register_from_oauth( - data: Json, +pub async fn authenticate_with_oauth( + data: Json, req: HttpRequest, context: Data, ) -> LemmyResult> { @@ -242,7 +242,6 @@ pub async fn register_from_oauth( // validate inputs if data.oauth_provider_id == OAuthProviderId(0i64) - || data.redirect_uri.is_none() || data.code.is_empty() || data.code.len() > 300 { @@ -250,10 +249,7 @@ pub async fn register_from_oauth( } // validate the redirect_uri - let redirect_uri = data - .redirect_uri - .as_ref() - .ok_or(LemmyErrorType::OauthAuthorizationInvalid)?; + let redirect_uri = &data.redirect_uri; if !redirect_uri .host_str() .unwrap_or("") @@ -269,9 +265,10 @@ pub async fn register_from_oauth( // Fetch the OAUTH provider and make sure it's enabled let oauth_provider_id = data.oauth_provider_id; - let oauth_provider = UnsafeOAuthProvider::get(&mut context.pool(), oauth_provider_id) + let oauth_provider = UnsafeOAuthProvider::read(&mut context.pool(), oauth_provider_id) .await .ok() + .ok_or(LemmyErrorType::OauthAuthorizationInvalid)? .ok_or(LemmyErrorType::OauthAuthorizationInvalid)?; if !oauth_provider.enabled { diff --git a/crates/db_schema/src/impls/oauth_account.rs b/crates/db_schema/src/impls/oauth_account.rs index 64a3787c2..81db68f25 100644 --- a/crates/db_schema/src/impls/oauth_account.rs +++ b/crates/db_schema/src/impls/oauth_account.rs @@ -33,20 +33,3 @@ impl Crud for OAuthAccount { .await } } - -impl OAuthAccount { - pub async fn get(pool: &mut DbPool<'_>, oauth_account_id: OAuthAccountId) -> Result { - let conn = &mut get_conn(pool).await?; - let oauth_accounts = oauth_account::table - .find(oauth_account_id) - .select(oauth_account::all_columns) - .limit(1) - .load::(conn) - .await?; - if let Some(oauth_account) = oauth_accounts.into_iter().next() { - Ok(oauth_account) - } else { - Err(diesel::result::Error::NotFound) - } - } -} diff --git a/crates/db_schema/src/impls/oauth_provider.rs b/crates/db_schema/src/impls/oauth_provider.rs index 9e594ef54..e464fa6bc 100644 --- a/crates/db_schema/src/impls/oauth_provider.rs +++ b/crates/db_schema/src/impls/oauth_provider.rs @@ -41,24 +41,6 @@ impl Crud for UnsafeOAuthProvider { } impl UnsafeOAuthProvider { - pub async fn get( - pool: &mut DbPool<'_>, - oauth_provider_id: OAuthProviderId, - ) -> Result { - let conn = &mut get_conn(pool).await?; - let oauth_providers = oauth_provider::table - .find(oauth_provider_id) - .select(oauth_provider::all_columns) - .limit(1) - .load::(conn) - .await?; - if let Some(oauth_provider) = oauth_providers.into_iter().next() { - Ok(oauth_provider) - } else { - Err(diesel::result::Error::NotFound) - } - } - pub async fn get_all(pool: &mut DbPool<'_>) -> Result, Error> { let conn = &mut get_conn(pool).await?; let oauth_providers = oauth_provider::table @@ -72,12 +54,12 @@ impl UnsafeOAuthProvider { } impl OAuthProvider { - pub async fn get_all(pool: &mut DbPool<'_>) -> Result>, Error> { + pub async fn get_all(pool: &mut DbPool<'_>) -> Result, Error> { let oauth_providers = UnsafeOAuthProvider::get_all(pool).await?; - let mut result = Vec::>::new(); + let mut result = Vec::::new(); for oauth_provider in &oauth_providers { - result.push(Some(Self::from_unsafe(oauth_provider))); + result.push(Self::from_unsafe(oauth_provider)); } Ok(result) diff --git a/crates/utils/src/error.rs b/crates/utils/src/error.rs index 9e4caced6..50903e589 100644 --- a/crates/utils/src/error.rs +++ b/crates/utils/src/error.rs @@ -53,6 +53,7 @@ pub enum LemmyErrorType { CouldntFindCommentReply, CouldntFindPrivateMessage, CouldntFindActivity, + CouldntFindOauthProvider, PersonIsBlocked, CommunityIsBlocked, InstanceIsBlocked, diff --git a/src/api_routes_http.rs b/src/api_routes_http.rs index 280d62488..5625a8cfd 100644 --- a/src/api_routes_http.rs +++ b/src/api_routes_http.rs @@ -128,7 +128,7 @@ use lemmy_api_crud::{ }, site::{create::create_site, read::get_site, update::update_site}, user::{ - create::{register, register_from_oauth}, + create::{authenticate_with_oauth, register}, delete::delete_account, }, }; @@ -395,7 +395,7 @@ pub fn config(cfg: &mut web::ServiceConfig, rate_limit: &RateLimitCell) { .service( web::scope("/oauth") .wrap(rate_limit.register()) - .route("/register", web::post().to(register_from_oauth)), + .route("/authenticate", web::post().to(authenticate_with_oauth)), ), ); cfg.service(