Fixes and improvements based on review feedback

pull/4881/head
privacyguard 2024-07-06 21:51:59 +03:00
parent 6dd5613f46
commit 7f9b897c69
9 changed files with 38 additions and 77 deletions

View File

@ -57,16 +57,16 @@ pub struct DeleteOAuthProvider {
}
#[skip_serializing_none]
#[derive(Debug, Serialize, Deserialize, Clone, Default)]
#[derive(Debug, Serialize, Deserialize, Clone)]
#[cfg_attr(feature = "full", derive(TS))]
#[cfg_attr(feature = "full", ts(export))]
/// Logging in with an OAuth 2.0 authorization
pub struct OAuth {
pub struct AuthenticateWithOauth {
pub code: String,
#[cfg_attr(feature = "full", ts(type = "string"))]
pub oauth_provider_id: OAuthProviderId,
#[cfg_attr(feature = "full", ts(type = "string"))]
pub redirect_uri: Option<Url>,
pub redirect_uri: Url,
}
#[skip_serializing_none]

View File

@ -313,7 +313,7 @@ pub struct GetSiteResponse {
/// A list of custom emojis your site supports.
pub custom_emojis: Vec<CustomEmojiView>,
/// A list of external auth methods your site supports.
pub oauth_providers: Vec<Option<OAuthProvider>>,
pub oauth_providers: Vec<OAuthProvider>,
pub blocked_urls: Vec<LocalSiteUrlBlocklist>,
}

View File

@ -7,7 +7,7 @@ use lemmy_db_schema::{
utils::naive_now,
};
use lemmy_db_views::structs::LocalUserView;
use lemmy_utils::error::LemmyError;
use lemmy_utils::{error::LemmyError, LemmyErrorType};
use url::Url;
#[tracing::instrument(skip(context))]
@ -41,7 +41,8 @@ pub async fn update_oauth_provider(
let update_result =
UnsafeOAuthProvider::update(&mut context.pool(), data.id, &oauth_provider_form.build()).await?;
let unsafe_oauth_provider =
UnsafeOAuthProvider::get(&mut context.pool(), update_result.id).await?;
let unsafe_oauth_provider = UnsafeOAuthProvider::read(&mut context.pool(), update_result.id)
.await?
.ok_or(LemmyErrorType::CouldntFindOauthProvider)?;
Ok(Json(OAuthProvider::from_unsafe(&unsafe_oauth_provider)))
}

View File

@ -115,26 +115,23 @@ pub async fn get_site(
Ok(Json(site_response))
}
fn filter_oauth_providers(oauth_providers: &mut [Option<OAuthProvider>]) {
for oauth_provider_opt in oauth_providers {
if let Some(oauth_provider) = oauth_provider_opt {
if oauth_provider.enabled.is_some()
&& oauth_provider.enabled.expect("unexpected enabled value")
{
oauth_provider.issuer = None;
oauth_provider.token_endpoint = None;
oauth_provider.userinfo_endpoint = None;
oauth_provider.id_claim = None;
oauth_provider.name_claim = None;
oauth_provider.auto_verify_email = None;
oauth_provider.auto_approve_application = None;
oauth_provider.account_linking_enabled = None;
oauth_provider.enabled = None;
oauth_provider.published = None;
oauth_provider.updated = None;
} else {
*oauth_provider_opt = None;
}
fn filter_oauth_providers(oauth_providers: &mut Vec<OAuthProvider>) {
oauth_providers.retain_mut(|oauth_provider| {
if oauth_provider.enabled.unwrap_or(false) {
oauth_provider.issuer = None;
oauth_provider.token_endpoint = None;
oauth_provider.userinfo_endpoint = None;
oauth_provider.id_claim = None;
oauth_provider.name_claim = None;
oauth_provider.auto_verify_email = None;
oauth_provider.auto_approve_application = None;
oauth_provider.account_linking_enabled = None;
oauth_provider.enabled = None;
oauth_provider.published = None;
oauth_provider.updated = None;
true
} else {
false
}
}
})
}

View File

@ -3,7 +3,7 @@ use actix_web::{web::Json, HttpRequest};
use lemmy_api_common::{
claims::Claims,
context::LemmyContext,
oauth_provider::{OAuth, TokenResponse},
oauth_provider::{AuthenticateWithOauth, TokenResponse},
person::{LoginResponse, Register},
utils::{
check_email_verified,
@ -229,8 +229,8 @@ pub async fn register(
}
#[tracing::instrument(skip(context))]
pub async fn register_from_oauth(
data: Json<OAuth>,
pub async fn authenticate_with_oauth(
data: Json<AuthenticateWithOauth>,
req: HttpRequest,
context: Data<LemmyContext>,
) -> LemmyResult<Json<LoginResponse>> {
@ -242,7 +242,6 @@ pub async fn register_from_oauth(
// validate inputs
if data.oauth_provider_id == OAuthProviderId(0i64)
|| data.redirect_uri.is_none()
|| data.code.is_empty()
|| data.code.len() > 300
{
@ -250,10 +249,7 @@ pub async fn register_from_oauth(
}
// validate the redirect_uri
let redirect_uri = data
.redirect_uri
.as_ref()
.ok_or(LemmyErrorType::OauthAuthorizationInvalid)?;
let redirect_uri = &data.redirect_uri;
if !redirect_uri
.host_str()
.unwrap_or("")
@ -269,9 +265,10 @@ pub async fn register_from_oauth(
// Fetch the OAUTH provider and make sure it's enabled
let oauth_provider_id = data.oauth_provider_id;
let oauth_provider = UnsafeOAuthProvider::get(&mut context.pool(), oauth_provider_id)
let oauth_provider = UnsafeOAuthProvider::read(&mut context.pool(), oauth_provider_id)
.await
.ok()
.ok_or(LemmyErrorType::OauthAuthorizationInvalid)?
.ok_or(LemmyErrorType::OauthAuthorizationInvalid)?;
if !oauth_provider.enabled {

View File

@ -33,20 +33,3 @@ impl Crud for OAuthAccount {
.await
}
}
impl OAuthAccount {
pub async fn get(pool: &mut DbPool<'_>, oauth_account_id: OAuthAccountId) -> Result<Self, Error> {
let conn = &mut get_conn(pool).await?;
let oauth_accounts = oauth_account::table
.find(oauth_account_id)
.select(oauth_account::all_columns)
.limit(1)
.load::<OAuthAccount>(conn)
.await?;
if let Some(oauth_account) = oauth_accounts.into_iter().next() {
Ok(oauth_account)
} else {
Err(diesel::result::Error::NotFound)
}
}
}

View File

@ -41,24 +41,6 @@ impl Crud for UnsafeOAuthProvider {
}
impl UnsafeOAuthProvider {
pub async fn get(
pool: &mut DbPool<'_>,
oauth_provider_id: OAuthProviderId,
) -> Result<Self, Error> {
let conn = &mut get_conn(pool).await?;
let oauth_providers = oauth_provider::table
.find(oauth_provider_id)
.select(oauth_provider::all_columns)
.limit(1)
.load::<UnsafeOAuthProvider>(conn)
.await?;
if let Some(oauth_provider) = oauth_providers.into_iter().next() {
Ok(oauth_provider)
} else {
Err(diesel::result::Error::NotFound)
}
}
pub async fn get_all(pool: &mut DbPool<'_>) -> Result<Vec<Self>, Error> {
let conn = &mut get_conn(pool).await?;
let oauth_providers = oauth_provider::table
@ -72,12 +54,12 @@ impl UnsafeOAuthProvider {
}
impl OAuthProvider {
pub async fn get_all(pool: &mut DbPool<'_>) -> Result<Vec<Option<Self>>, Error> {
pub async fn get_all(pool: &mut DbPool<'_>) -> Result<Vec<Self>, Error> {
let oauth_providers = UnsafeOAuthProvider::get_all(pool).await?;
let mut result = Vec::<Option<OAuthProvider>>::new();
let mut result = Vec::<OAuthProvider>::new();
for oauth_provider in &oauth_providers {
result.push(Some(Self::from_unsafe(oauth_provider)));
result.push(Self::from_unsafe(oauth_provider));
}
Ok(result)

View File

@ -53,6 +53,7 @@ pub enum LemmyErrorType {
CouldntFindCommentReply,
CouldntFindPrivateMessage,
CouldntFindActivity,
CouldntFindOauthProvider,
PersonIsBlocked,
CommunityIsBlocked,
InstanceIsBlocked,

View File

@ -128,7 +128,7 @@ use lemmy_api_crud::{
},
site::{create::create_site, read::get_site, update::update_site},
user::{
create::{register, register_from_oauth},
create::{authenticate_with_oauth, register},
delete::delete_account,
},
};
@ -395,7 +395,7 @@ pub fn config(cfg: &mut web::ServiceConfig, rate_limit: &RateLimitCell) {
.service(
web::scope("/oauth")
.wrap(rate_limit.register())
.route("/register", web::post().to(register_from_oauth)),
.route("/authenticate", web::post().to(authenticate_with_oauth)),
),
);
cfg.service(