Compare commits

...

3 Commits

Author SHA1 Message Date
privacyguard 45c0a0030a use derive_new::new instead of TypedBuilder 2024-07-07 02:31:15 +03:00
privacyguard d9c7e96f31 update submodule to the latest version 2024-07-07 01:09:03 +03:00
privacyguard 7f9b897c69 Fixes and improvements based on review feedback 2024-07-07 00:51:24 +03:00
13 changed files with 81 additions and 127 deletions

View File

@ -57,16 +57,16 @@ pub struct DeleteOAuthProvider {
} }
#[skip_serializing_none] #[skip_serializing_none]
#[derive(Debug, Serialize, Deserialize, Clone, Default)] #[derive(Debug, Serialize, Deserialize, Clone)]
#[cfg_attr(feature = "full", derive(TS))] #[cfg_attr(feature = "full", derive(TS))]
#[cfg_attr(feature = "full", ts(export))] #[cfg_attr(feature = "full", ts(export))]
/// Logging in with an OAuth 2.0 authorization /// Logging in with an OAuth 2.0 authorization
pub struct OAuth { pub struct AuthenticateWithOauth {
pub code: String, pub code: String,
#[cfg_attr(feature = "full", ts(type = "string"))] #[cfg_attr(feature = "full", ts(type = "string"))]
pub oauth_provider_id: OAuthProviderId, pub oauth_provider_id: OAuthProviderId,
#[cfg_attr(feature = "full", ts(type = "string"))] #[cfg_attr(feature = "full", ts(type = "string"))]
pub redirect_uri: Option<Url>, pub redirect_uri: Url,
} }
#[skip_serializing_none] #[skip_serializing_none]

View File

@ -313,7 +313,7 @@ pub struct GetSiteResponse {
/// A list of custom emojis your site supports. /// A list of custom emojis your site supports.
pub custom_emojis: Vec<CustomEmojiView>, pub custom_emojis: Vec<CustomEmojiView>,
/// A list of external auth methods your site supports. /// A list of external auth methods your site supports.
pub oauth_providers: Vec<Option<OAuthProvider>>, pub oauth_providers: Vec<OAuthProvider>,
pub blocked_urls: Vec<LocalSiteUrlBlocklist>, pub blocked_urls: Vec<LocalSiteUrlBlocklist>,
} }

View File

@ -36,23 +36,23 @@ pub async fn create_oauth_provider(
reader.read(&mut id_bytes); reader.read(&mut id_bytes);
let cloned_data = data.clone(); let cloned_data = data.clone();
let oauth_provider_form = OAuthProviderInsertForm::builder() let oauth_provider_form = OAuthProviderInsertForm {
.id(OAuthProviderId(i64::from_ne_bytes(id_bytes))) id: OAuthProviderId(i64::from_ne_bytes(id_bytes)),
.display_name(cloned_data.display_name) display_name: cloned_data.display_name,
.issuer(Url::parse(&cloned_data.issuer)?.into()) issuer: Url::parse(&cloned_data.issuer)?.into(),
.authorization_endpoint(Url::parse(&cloned_data.authorization_endpoint)?.into()) authorization_endpoint: Url::parse(&cloned_data.authorization_endpoint)?.into(),
.token_endpoint(Url::parse(&cloned_data.token_endpoint)?.into()) token_endpoint: Url::parse(&cloned_data.token_endpoint)?.into(),
.userinfo_endpoint(Url::parse(&cloned_data.userinfo_endpoint)?.into()) userinfo_endpoint: Url::parse(&cloned_data.userinfo_endpoint)?.into(),
.id_claim(cloned_data.id_claim) id_claim: cloned_data.id_claim,
.name_claim(cloned_data.name_claim) name_claim: cloned_data.name_claim,
.client_id(data.client_id.to_string()) client_id: data.client_id.to_string(),
.client_secret(data.client_secret.to_string()) client_secret: data.client_secret.to_string(),
.scopes(data.scopes.to_string()) scopes: data.scopes.to_string(),
.auto_verify_email(data.auto_verify_email) auto_verify_email: data.auto_verify_email,
.auto_approve_application(data.auto_approve_application) auto_approve_application: data.auto_approve_application,
.account_linking_enabled(data.account_linking_enabled) account_linking_enabled: data.account_linking_enabled,
.enabled(data.enabled) enabled: data.enabled,
.build(); };
let unsafe_oauth_provider = let unsafe_oauth_provider =
UnsafeOAuthProvider::create(&mut context.pool(), &oauth_provider_form).await?; UnsafeOAuthProvider::create(&mut context.pool(), &oauth_provider_form).await?;
Ok(Json(OAuthProvider::from_unsafe(&unsafe_oauth_provider))) Ok(Json(OAuthProvider::from_unsafe(&unsafe_oauth_provider)))

View File

@ -7,7 +7,7 @@ use lemmy_db_schema::{
utils::naive_now, utils::naive_now,
}; };
use lemmy_db_views::structs::LocalUserView; use lemmy_db_views::structs::LocalUserView;
use lemmy_utils::error::LemmyError; use lemmy_utils::{error::LemmyError, LemmyErrorType};
use url::Url; use url::Url;
#[tracing::instrument(skip(context))] #[tracing::instrument(skip(context))]
@ -20,28 +20,30 @@ pub async fn update_oauth_provider(
is_admin(&local_user_view)?; is_admin(&local_user_view)?;
let cloned_data = data.clone(); let cloned_data = data.clone();
let oauth_provider_form = OAuthProviderUpdateForm::builder() let oauth_provider_form = OAuthProviderUpdateForm {
.display_name(cloned_data.display_name) display_name: cloned_data.display_name,
.authorization_endpoint(Url::parse(&cloned_data.authorization_endpoint)?.into()) authorization_endpoint: Url::parse(&cloned_data.authorization_endpoint)?.into(),
.token_endpoint(Url::parse(&cloned_data.token_endpoint)?.into()) token_endpoint: Url::parse(&cloned_data.token_endpoint)?.into(),
.userinfo_endpoint(Url::parse(&cloned_data.userinfo_endpoint)?.into()) userinfo_endpoint: Url::parse(&cloned_data.userinfo_endpoint)?.into(),
.id_claim(data.id_claim.to_string()) id_claim: data.id_claim.to_string(),
.name_claim(data.name_claim.to_string()) name_claim: data.name_claim.to_string(),
.client_secret(if !data.client_secret.is_empty() { client_secret: if !data.client_secret.is_empty() {
Some(data.client_secret.to_string()) Some(data.client_secret.to_string())
} else { } else {
None None
}) },
.scopes(data.scopes.to_string()) scopes: data.scopes.to_string(),
.auto_verify_email(data.auto_verify_email) auto_verify_email: data.auto_verify_email,
.auto_approve_application(data.auto_approve_application) auto_approve_application: data.auto_approve_application,
.account_linking_enabled(data.account_linking_enabled) account_linking_enabled: data.account_linking_enabled,
.enabled(data.enabled) enabled: data.enabled,
.updated(naive_now()); updated: naive_now(),
};
let update_result = let update_result =
UnsafeOAuthProvider::update(&mut context.pool(), data.id, &oauth_provider_form.build()).await?; UnsafeOAuthProvider::update(&mut context.pool(), data.id, &oauth_provider_form).await?;
let unsafe_oauth_provider = let unsafe_oauth_provider = UnsafeOAuthProvider::read(&mut context.pool(), update_result.id)
UnsafeOAuthProvider::get(&mut context.pool(), update_result.id).await?; .await?
.ok_or(LemmyErrorType::CouldntFindOauthProvider)?;
Ok(Json(OAuthProvider::from_unsafe(&unsafe_oauth_provider))) Ok(Json(OAuthProvider::from_unsafe(&unsafe_oauth_provider)))
} }

View File

@ -115,12 +115,9 @@ pub async fn get_site(
Ok(Json(site_response)) Ok(Json(site_response))
} }
fn filter_oauth_providers(oauth_providers: &mut [Option<OAuthProvider>]) { fn filter_oauth_providers(oauth_providers: &mut Vec<OAuthProvider>) {
for oauth_provider_opt in oauth_providers { oauth_providers.retain_mut(|oauth_provider| {
if let Some(oauth_provider) = oauth_provider_opt { if oauth_provider.enabled.unwrap_or(false) {
if oauth_provider.enabled.is_some()
&& oauth_provider.enabled.expect("unexpected enabled value")
{
oauth_provider.issuer = None; oauth_provider.issuer = None;
oauth_provider.token_endpoint = None; oauth_provider.token_endpoint = None;
oauth_provider.userinfo_endpoint = None; oauth_provider.userinfo_endpoint = None;
@ -132,9 +129,9 @@ fn filter_oauth_providers(oauth_providers: &mut [Option<OAuthProvider>]) {
oauth_provider.enabled = None; oauth_provider.enabled = None;
oauth_provider.published = None; oauth_provider.published = None;
oauth_provider.updated = None; oauth_provider.updated = None;
true
} else { } else {
*oauth_provider_opt = None; false
}
}
} }
})
} }

View File

@ -3,7 +3,7 @@ use actix_web::{web::Json, HttpRequest};
use lemmy_api_common::{ use lemmy_api_common::{
claims::Claims, claims::Claims,
context::LemmyContext, context::LemmyContext,
oauth_provider::{OAuth, TokenResponse}, oauth_provider::{AuthenticateWithOauth, TokenResponse},
person::{LoginResponse, Register}, person::{LoginResponse, Register},
utils::{ utils::{
check_email_verified, check_email_verified,
@ -229,8 +229,8 @@ pub async fn register(
} }
#[tracing::instrument(skip(context))] #[tracing::instrument(skip(context))]
pub async fn register_from_oauth( pub async fn authenticate_with_oauth(
data: Json<OAuth>, data: Json<AuthenticateWithOauth>,
req: HttpRequest, req: HttpRequest,
context: Data<LemmyContext>, context: Data<LemmyContext>,
) -> LemmyResult<Json<LoginResponse>> { ) -> LemmyResult<Json<LoginResponse>> {
@ -242,7 +242,6 @@ pub async fn register_from_oauth(
// validate inputs // validate inputs
if data.oauth_provider_id == OAuthProviderId(0i64) if data.oauth_provider_id == OAuthProviderId(0i64)
|| data.redirect_uri.is_none()
|| data.code.is_empty() || data.code.is_empty()
|| data.code.len() > 300 || data.code.len() > 300
{ {
@ -250,10 +249,7 @@ pub async fn register_from_oauth(
} }
// validate the redirect_uri // validate the redirect_uri
let redirect_uri = data let redirect_uri = &data.redirect_uri;
.redirect_uri
.as_ref()
.ok_or(LemmyErrorType::OauthAuthorizationInvalid)?;
if !redirect_uri if !redirect_uri
.host_str() .host_str()
.unwrap_or("") .unwrap_or("")
@ -269,9 +265,10 @@ pub async fn register_from_oauth(
// Fetch the OAUTH provider and make sure it's enabled // Fetch the OAUTH provider and make sure it's enabled
let oauth_provider_id = data.oauth_provider_id; let oauth_provider_id = data.oauth_provider_id;
let oauth_provider = UnsafeOAuthProvider::get(&mut context.pool(), oauth_provider_id) let oauth_provider = UnsafeOAuthProvider::read(&mut context.pool(), oauth_provider_id)
.await .await
.ok() .ok()
.ok_or(LemmyErrorType::OauthAuthorizationInvalid)?
.ok_or(LemmyErrorType::OauthAuthorizationInvalid)?; .ok_or(LemmyErrorType::OauthAuthorizationInvalid)?;
if !oauth_provider.enabled { if !oauth_provider.enabled {
@ -393,11 +390,8 @@ pub async fn register_from_oauth(
if oauth_provider.account_linking_enabled { if oauth_provider.account_linking_enabled {
// Link with OAUTH => Login user // Link with OAUTH => Login user
let oauth_account_form = OAuthAccountInsertForm::builder() let oauth_account_form =
.local_user_id(user_view.local_user.id) OAuthAccountInsertForm::new(user_view.local_user.id, oauth_provider.id, oauth_user_id);
.oauth_provider_id(oauth_provider.id)
.oauth_user_id(oauth_user_id)
.build();
OAuthAccount::create(&mut context.pool(), &oauth_account_form) OAuthAccount::create(&mut context.pool(), &oauth_account_form)
.await .await
@ -455,11 +449,8 @@ pub async fn register_from_oauth(
.ok_or(LemmyErrorType::OauthLoginFailed)?; .ok_or(LemmyErrorType::OauthLoginFailed)?;
// Create the oauth account // Create the oauth account
let oauth_account_form = OAuthAccountInsertForm::builder() let oauth_account_form =
.local_user_id(local_user.id) OAuthAccountInsertForm::new(local_user.id, oauth_provider.id, oauth_user_id);
.oauth_provider_id(oauth_provider.id)
.oauth_user_id(oauth_user_id)
.build();
OAuthAccount::create(&mut context.pool(), &oauth_account_form) OAuthAccount::create(&mut context.pool(), &oauth_account_form)
.await .await

View File

@ -33,20 +33,3 @@ impl Crud for OAuthAccount {
.await .await
} }
} }
impl OAuthAccount {
pub async fn get(pool: &mut DbPool<'_>, oauth_account_id: OAuthAccountId) -> Result<Self, Error> {
let conn = &mut get_conn(pool).await?;
let oauth_accounts = oauth_account::table
.find(oauth_account_id)
.select(oauth_account::all_columns)
.limit(1)
.load::<OAuthAccount>(conn)
.await?;
if let Some(oauth_account) = oauth_accounts.into_iter().next() {
Ok(oauth_account)
} else {
Err(diesel::result::Error::NotFound)
}
}
}

View File

@ -41,24 +41,6 @@ impl Crud for UnsafeOAuthProvider {
} }
impl UnsafeOAuthProvider { impl UnsafeOAuthProvider {
pub async fn get(
pool: &mut DbPool<'_>,
oauth_provider_id: OAuthProviderId,
) -> Result<Self, Error> {
let conn = &mut get_conn(pool).await?;
let oauth_providers = oauth_provider::table
.find(oauth_provider_id)
.select(oauth_provider::all_columns)
.limit(1)
.load::<UnsafeOAuthProvider>(conn)
.await?;
if let Some(oauth_provider) = oauth_providers.into_iter().next() {
Ok(oauth_provider)
} else {
Err(diesel::result::Error::NotFound)
}
}
pub async fn get_all(pool: &mut DbPool<'_>) -> Result<Vec<Self>, Error> { pub async fn get_all(pool: &mut DbPool<'_>) -> Result<Vec<Self>, Error> {
let conn = &mut get_conn(pool).await?; let conn = &mut get_conn(pool).await?;
let oauth_providers = oauth_provider::table let oauth_providers = oauth_provider::table
@ -72,12 +54,12 @@ impl UnsafeOAuthProvider {
} }
impl OAuthProvider { impl OAuthProvider {
pub async fn get_all(pool: &mut DbPool<'_>) -> Result<Vec<Option<Self>>, Error> { pub async fn get_all(pool: &mut DbPool<'_>) -> Result<Vec<Self>, Error> {
let oauth_providers = UnsafeOAuthProvider::get_all(pool).await?; let oauth_providers = UnsafeOAuthProvider::get_all(pool).await?;
let mut result = Vec::<Option<OAuthProvider>>::new(); let mut result = Vec::<OAuthProvider>::new();
for oauth_provider in &oauth_providers { for oauth_provider in &oauth_providers {
result.push(Some(Self::from_unsafe(oauth_provider))); result.push(Self::from_unsafe(oauth_provider));
} }
Ok(result) Ok(result)

View File

@ -6,7 +6,6 @@ use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none; use serde_with::skip_serializing_none;
#[cfg(feature = "full")] #[cfg(feature = "full")]
use ts_rs::TS; use ts_rs::TS;
use typed_builder::TypedBuilder;
#[skip_serializing_none] #[skip_serializing_none]
#[derive(Clone, PartialEq, Eq, Debug, Serialize, Deserialize)] #[derive(Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
@ -24,7 +23,7 @@ pub struct OAuthAccount {
pub updated: Option<DateTime<Utc>>, pub updated: Option<DateTime<Utc>>,
} }
#[derive(Debug, Clone, TypedBuilder)] #[derive(Debug, Clone, derive_new::new)]
#[cfg_attr(feature = "full", derive(Insertable, AsChangeset))] #[cfg_attr(feature = "full", derive(Insertable, AsChangeset))]
#[cfg_attr(feature = "full", diesel(table_name = oauth_account))] #[cfg_attr(feature = "full", diesel(table_name = oauth_account))]
pub struct OAuthAccountInsertForm { pub struct OAuthAccountInsertForm {
@ -33,7 +32,7 @@ pub struct OAuthAccountInsertForm {
pub oauth_user_id: String, pub oauth_user_id: String,
} }
#[derive(Debug, Clone, TypedBuilder)] #[derive(Debug, Clone, derive_new::new)]
#[cfg_attr(feature = "full", derive(Insertable, AsChangeset))] #[cfg_attr(feature = "full", derive(Insertable, AsChangeset))]
#[cfg_attr(feature = "full", diesel(table_name = oauth_account))] #[cfg_attr(feature = "full", diesel(table_name = oauth_account))]
pub struct OAuthAccountUpdateForm { pub struct OAuthAccountUpdateForm {

View File

@ -6,7 +6,6 @@ use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none; use serde_with::skip_serializing_none;
#[cfg(feature = "full")] #[cfg(feature = "full")]
use ts_rs::TS; use ts_rs::TS;
use typed_builder::TypedBuilder;
#[skip_serializing_none] #[skip_serializing_none]
#[derive(Clone, PartialEq, Eq, Debug, Serialize, Deserialize)] #[derive(Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
@ -110,7 +109,7 @@ pub struct OAuthProvider {
pub updated: Option<DateTime<Utc>>, pub updated: Option<DateTime<Utc>>,
} }
#[derive(Debug, Clone, TypedBuilder)] #[derive(Debug, Clone)]
#[cfg_attr(feature = "full", derive(Insertable, AsChangeset, TS))] #[cfg_attr(feature = "full", derive(Insertable, AsChangeset, TS))]
#[cfg_attr(feature = "full", diesel(table_name = oauth_provider))] #[cfg_attr(feature = "full", diesel(table_name = oauth_provider))]
#[cfg_attr(feature = "full", ts(export))] #[cfg_attr(feature = "full", ts(export))]
@ -136,7 +135,7 @@ pub struct OAuthProviderInsertForm {
pub enabled: bool, pub enabled: bool,
} }
#[derive(Debug, Clone, TypedBuilder)] #[derive(Debug, Clone)]
#[cfg_attr(feature = "full", derive(Insertable, AsChangeset, TS))] #[cfg_attr(feature = "full", derive(Insertable, AsChangeset, TS))]
#[cfg_attr(feature = "full", diesel(table_name = oauth_provider))] #[cfg_attr(feature = "full", diesel(table_name = oauth_provider))]
#[cfg_attr(feature = "full", ts(export))] #[cfg_attr(feature = "full", ts(export))]

View File

@ -53,6 +53,7 @@ pub enum LemmyErrorType {
CouldntFindCommentReply, CouldntFindCommentReply,
CouldntFindPrivateMessage, CouldntFindPrivateMessage,
CouldntFindActivity, CouldntFindActivity,
CouldntFindOauthProvider,
PersonIsBlocked, PersonIsBlocked,
CommunityIsBlocked, CommunityIsBlocked,
InstanceIsBlocked, InstanceIsBlocked,

@ -1 +1 @@
Subproject commit e9b3b25fa1af7e06c4ffab86624d95da0836ef36 Subproject commit 94f0c7e44e967ea6d003ee03b1753f08011fcf53

View File

@ -128,7 +128,7 @@ use lemmy_api_crud::{
}, },
site::{create::create_site, read::get_site, update::update_site}, site::{create::create_site, read::get_site, update::update_site},
user::{ user::{
create::{register, register_from_oauth}, create::{authenticate_with_oauth, register},
delete::delete_account, delete::delete_account,
}, },
}; };
@ -395,7 +395,7 @@ pub fn config(cfg: &mut web::ServiceConfig, rate_limit: &RateLimitCell) {
.service( .service(
web::scope("/oauth") web::scope("/oauth")
.wrap(rate_limit.register()) .wrap(rate_limit.register())
.route("/register", web::post().to(register_from_oauth)), .route("/authenticate", web::post().to(authenticate_with_oauth)),
), ),
); );
cfg.service( cfg.service(