use crate::{error::LemmyError, IpAddr}; use actix_web::dev::{ConnectionInfo, Service, ServiceRequest, ServiceResponse, Transform}; use futures::future::{ok, Ready}; use rate_limiter::{RateLimitStorage, RateLimitType}; use serde::{Deserialize, Serialize}; use std::{ future::Future, pin::Pin, rc::Rc, sync::{Arc, Mutex}, task::{Context, Poll}, }; use tokio::sync::{mpsc, mpsc::Sender, OnceCell}; use typed_builder::TypedBuilder; pub mod rate_limiter; #[derive(Debug, Deserialize, Serialize, Clone, TypedBuilder)] pub struct RateLimitConfig { #[builder(default = 180)] /// Maximum number of messages created in interval pub message: i32, #[builder(default = 60)] /// Interval length for message limit, in seconds pub message_per_second: i32, #[builder(default = 6)] /// Maximum number of posts created in interval pub post: i32, #[builder(default = 300)] /// Interval length for post limit, in seconds pub post_per_second: i32, #[builder(default = 3)] /// Maximum number of registrations in interval pub register: i32, #[builder(default = 3600)] /// Interval length for registration limit, in seconds pub register_per_second: i32, #[builder(default = 6)] /// Maximum number of image uploads in interval pub image: i32, #[builder(default = 3600)] /// Interval length for image uploads, in seconds pub image_per_second: i32, #[builder(default = 6)] /// Maximum number of comments created in interval pub comment: i32, #[builder(default = 600)] /// Interval length for comment limit, in seconds pub comment_per_second: i32, #[builder(default = 60)] /// Maximum number of searches created in interval pub search: i32, #[builder(default = 600)] /// Interval length for search limit, in seconds pub search_per_second: i32, } #[derive(Debug, Clone)] struct RateLimit { pub rate_limiter: RateLimitStorage, pub rate_limit_config: RateLimitConfig, } #[derive(Debug, Clone)] pub struct RateLimitedGuard { rate_limit: Arc>, type_: RateLimitType, } /// Single instance of rate limit config and buckets, which is shared across all threads. #[derive(Clone)] pub struct RateLimitCell { tx: Sender, rate_limit: Arc>, } impl RateLimitCell { /// Initialize cell if it wasnt initialized yet. Otherwise returns the existing cell. pub async fn new(rate_limit_config: RateLimitConfig) -> &'static Self { static LOCAL_INSTANCE: OnceCell = OnceCell::const_new(); LOCAL_INSTANCE .get_or_init(|| async { let (tx, mut rx) = mpsc::channel::(4); let rate_limit = Arc::new(Mutex::new(RateLimit { rate_limiter: Default::default(), rate_limit_config, })); let rate_limit2 = rate_limit.clone(); tokio::spawn(async move { while let Some(r) = rx.recv().await { rate_limit2 .lock() .expect("Failed to lock rate limit mutex for updating") .rate_limit_config = r; } }); RateLimitCell { tx, rate_limit } }) .await } /// Call this when the config was updated, to update all in-memory cells. pub async fn send(&self, config: RateLimitConfig) -> Result<(), LemmyError> { self.tx.send(config).await?; Ok(()) } pub fn message(&self) -> RateLimitedGuard { self.kind(RateLimitType::Message) } pub fn post(&self) -> RateLimitedGuard { self.kind(RateLimitType::Post) } pub fn register(&self) -> RateLimitedGuard { self.kind(RateLimitType::Register) } pub fn image(&self) -> RateLimitedGuard { self.kind(RateLimitType::Image) } pub fn comment(&self) -> RateLimitedGuard { self.kind(RateLimitType::Comment) } pub fn search(&self) -> RateLimitedGuard { self.kind(RateLimitType::Search) } fn kind(&self, type_: RateLimitType) -> RateLimitedGuard { RateLimitedGuard { rate_limit: self.rate_limit.clone(), type_, } } } pub struct RateLimitedMiddleware { rate_limited: RateLimitedGuard, service: Rc, } impl RateLimitedGuard { /// Returns true if the request passed the rate limit, false if it failed and should be rejected. pub fn check(self, ip_addr: IpAddr) -> bool { // Does not need to be blocking because the RwLock in settings never held across await points, // and the operation here locks only long enough to clone let mut guard = self .rate_limit .lock() .expect("Failed to lock rate limit mutex for reading"); let rate_limit = &guard.rate_limit_config; let (kind, interval) = match self.type_ { RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second), RateLimitType::Post => (rate_limit.post, rate_limit.post_per_second), RateLimitType::Register => (rate_limit.register, rate_limit.register_per_second), RateLimitType::Image => (rate_limit.image, rate_limit.image_per_second), RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second), RateLimitType::Search => (rate_limit.search, rate_limit.search_per_second), }; let limiter = &mut guard.rate_limiter; limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval) } } impl Transform for RateLimitedGuard where S: Service + 'static, S::Future: 'static, { type Response = S::Response; type Error = actix_web::Error; type InitError = (); type Transform = RateLimitedMiddleware; type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(RateLimitedMiddleware { rate_limited: self.clone(), service: Rc::new(service), }) } } type FutResult = dyn Future>; impl Service for RateLimitedMiddleware where S: Service + 'static, S::Future: 'static, { type Response = S::Response; type Error = actix_web::Error; type Future = Pin>>; fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { self.service.poll_ready(cx) } fn call(&self, req: ServiceRequest) -> Self::Future { let ip_addr = get_ip(&req.connection_info()); let rate_limited = self.rate_limited.clone(); let service = self.service.clone(); Box::pin(async move { if rate_limited.check(ip_addr) { service.call(req).await } else { let (http_req, _) = req.into_parts(); Ok(ServiceResponse::from_err( LemmyError::from_message("rate_limit_error"), http_req, )) } }) } } fn get_ip(conn_info: &ConnectionInfo) -> IpAddr { IpAddr( conn_info .realip_remote_addr() .unwrap_or("127.0.0.1:12345") .split(':') .next() .unwrap_or("127.0.0.1") .to_string(), ) }