diff --git a/conduwuit-example.toml b/conduwuit-example.toml
index 9807746a..e80d5554 100644
--- a/conduwuit-example.toml
+++ b/conduwuit-example.toml
@@ -299,6 +299,16 @@
 #
 #max_concurrent_inbound_transactions = 150
 
+# Maximum age (in seconds) for cached federation transaction responses.
+# Entries older than this will be removed during cleanup. Defaults to 2 hours.
+#
+#transaction_id_cache_max_age_secs = 7200
+
+# Maximum number of cached federation transaction responses.
+# When the cache exceeds this limit, older entries will be removed.
+#
+#transaction_id_cache_max_entries = 8192
+
 # Default/base connection timeout (seconds). This is used only by URL
 # previews and update/news endpoint checks.
 #
diff --git a/src/api/server/send.rs b/src/api/server/send.rs
index 336df99d..7f5fb502 100644
--- a/src/api/server/send.rs
+++ b/src/api/server/send.rs
@@ -42,7 +42,7 @@ use ruma::{
 	to_device::DeviceIdOrAllDevices,
 	uint,
 };
-use service::transaction_ids::{TxnKey, WrappedTransactionResponse};
+use service::transaction_ids::{FederationTxnState, TxnKey, WrappedTransactionResponse};
 use tokio::sync::watch::{Receiver, Sender};
 use tracing::instrument;
 
@@ -65,18 +65,6 @@ pub(crate) async fn send_transaction_message_route(
 		)));
 	}
 
-	let txn_key = (body.origin().to_owned(), body.transaction_id.clone());
-
-	// Did we already process this transaction
-	if let Some(response) = services.transaction_ids.get_cached_txn(&txn_key) {
-		return Ok(response);
-	}
-	// Or are currently processing it
-	if let Some(receiver) = services.transaction_ids.get_active_federation_txn(&txn_key) {
-		// Wait up to 50 seconds for a result
-		return wait_for_result(receiver).await;
-	}
-
 	if body.pdus.len() > PDU_LIMIT {
 		return Err!(Request(Forbidden(
 			"Not allowed to send more than {PDU_LIMIT} PDUs in one transaction"
@@ -89,21 +77,31 @@ pub(crate) async fn send_transaction_message_route(
 		)));
 	}
 
-	let sender = services
+	let txn_key = (body.origin().to_owned(), body.transaction_id.clone());
+
+	// Atomically check cache, join active, or start new transaction
+	match services
 		.transaction_ids
-		.start_federation_txn(txn_key.clone())?;
-	services.server.runtime().spawn(process_inbound_transaction(
-		services,
-		body,
-		client,
-		txn_key.clone(),
-		sender,
-	));
-	let receiver = services
-		.transaction_ids
-		.get_active_federation_txn(&txn_key)
-		.expect("just-created transaction was missing");
-	wait_for_result(receiver).await
+		.get_or_start_federation_txn(txn_key.clone())?
+	{
+		| FederationTxnState::Cached(response) => {
+			// Already responded
+			Ok(response)
+		},
+		| FederationTxnState::Active(receiver) => {
+			// Another thread is processing
+			wait_for_result(receiver).await
+		},
+		| FederationTxnState::Started { receiver, sender } => {
+			// We're the first, spawn the processing task
+			services
+				.server
+				.runtime()
+				.spawn(process_inbound_transaction(services, body, client, txn_key, sender));
+			// and wait for it
+			wait_for_result(receiver).await
+		},
+	}
 }
 
 async fn wait_for_result(
@@ -161,7 +159,7 @@ async fn process_inbound_transaction(
 		// think we processed it properly (and lose events), but we also can't return
 		// an actual error.
 		drop(sender);
-		services.transaction_ids.finish_federation_txn(&txn_key);
+		services.transaction_ids.remove_federation_txn(&txn_key);
 		panic!("failed to handle incoming transaction");
 	};
 
@@ -186,14 +184,10 @@ async fn process_inbound_transaction(
 			.map(|(e, r)| (e, r.map_err(error::sanitized_message)))
 			.collect(),
 	};
+
 	services
 		.transaction_ids
-		.set_cached_txn(txn_key.clone(), response.clone());
-	sender
-		.send(Some(response))
-		.expect("couldn't send response to channel");
-	services.transaction_ids.finish_federation_txn(&txn_key);
-	drop(sender);
+		.finish_federation_txn(txn_key, sender, response);
 }
 
 async fn handle(
diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs
index 993f455d..9792766d 100644
--- a/src/core/config/mod.rs
+++ b/src/core/config/mod.rs
@@ -379,6 +379,20 @@ pub struct Config {
 	#[serde(default = "default_max_concurrent_inbound_transactions")]
 	pub max_concurrent_inbound_transactions: usize,
 
+	/// Maximum age (in seconds) for cached federation transaction responses.
+	/// Entries older than this will be removed during cleanup. Defaults to 2 hours.
+	///
+	/// default: 7200
+	#[serde(default = "default_transaction_id_cache_max_age_secs")]
+	pub transaction_id_cache_max_age_secs: u64,
+
+	/// Maximum number of cached federation transaction responses.
+	/// When the cache exceeds this limit, older entries will be removed.
+	///
+	/// default: 8192
+	#[serde(default = "default_transaction_id_cache_max_entries")]
+	pub transaction_id_cache_max_entries: usize,
+
 	/// Default/base connection timeout (seconds). This is used only by URL
 	/// previews and update/news endpoint checks.
 	///
@@ -2553,6 +2567,10 @@ fn default_max_fetch_prev_events() -> u16 { 192_u16 }
 
 fn default_max_concurrent_inbound_transactions() -> usize { 150 }
 
+fn default_transaction_id_cache_max_age_secs() -> u64 { 60 * 60 * 2 }
+
+fn default_transaction_id_cache_max_entries() -> usize { 8192 }
+
 fn default_tracing_flame_filter() -> String {
 	cfg!(debug_assertions)
 		.then_some("trace,h2=off")
diff --git a/src/service/transaction_ids/mod.rs b/src/service/transaction_ids/mod.rs
index bb3fc466..35b83121 100644
--- a/src/service/transaction_ids/mod.rs
+++ b/src/service/transaction_ids/mod.rs
@@ -1,7 +1,14 @@
-use std::{collections::HashMap, sync::Arc};
+use std::{
+	collections::HashMap,
+	sync::{
+		Arc,
+		atomic::{AtomicU64, Ordering},
+	},
+	time::{Duration, SystemTime},
+};
 
 use async_trait::async_trait;
-use conduwuit::{Error, Result, SyncRwLock, debug, debug_warn, warn};
+use conduwuit::{Error, Result, SyncRwLock, debug_warn, warn};
 use database::{Handle, Map};
 use ruma::{
 	DeviceId, OwnedServerName, OwnedTransactionId, TransactionId, UserId,
@@ -12,16 +19,58 @@ use ruma::{
 };
 use tokio::sync::watch::{Receiver, Sender};
 
+use crate::{Dep, config};
+
 pub type TxnKey = (OwnedServerName, OwnedTransactionId);
 pub type WrappedTransactionResponse = Option<send_transaction_message::v1::Response>;
-pub type ActiveTransactionsMap = HashMap<TxnKey, Receiver<WrappedTransactionResponse>>;
+
+/// Minimum interval between cache cleanup runs.
+/// Exists to prevent thrashing when the cache is full of things that can't be
+/// cleared
+const CLEANUP_INTERVAL_SECS: u64 = 30;
+
+#[derive(Clone, Debug)]
+pub struct CachedTxnResponse {
+	pub response: send_transaction_message::v1::Response,
+	pub created: SystemTime,
+}
+
+/// Internal state for a federation transaction.
+/// Either actively being processed or completed and cached.
+#[derive(Clone)]
+enum TxnState {
+	/// Transaction is currently being processed.
+	Active(Receiver<WrappedTransactionResponse>),
+
+	/// Transaction completed and response is cached.
+	Cached(CachedTxnResponse),
+}
+
+/// Result of atomically checking or starting a federation transaction.
+pub enum FederationTxnState {
+	/// Transaction already completed and cached
+	Cached(send_transaction_message::v1::Response),
+
+	/// Transaction is currently being processed by another request.
+	/// Wait on this receiver for the result.
+	Active(Receiver<WrappedTransactionResponse>),
+
+	/// This caller should process the transaction (first to request it).
+	Started {
+		receiver: Receiver<WrappedTransactionResponse>,
+		sender: Sender<WrappedTransactionResponse>,
+	},
+}
 
 pub struct Service {
+	services: Services,
 	db: Data,
-	servername_txnid_response_cache:
-		Arc<SyncRwLock<HashMap<TxnKey, send_transaction_message::v1::Response>>>,
-	servername_txnid_active: Arc<SyncRwLock<ActiveTransactionsMap>>,
-	max_active_txns: usize,
+	federation_txn_state: Arc<SyncRwLock<HashMap<TxnKey, TxnState>>>,
+	last_cleanup: AtomicU64,
+}
+
+struct Services {
+	config: Dep<config::Service>,
 }
 
 struct Data {
@@ -32,30 +81,35 @@ struct Data {
 impl crate::Service for Service {
 	fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
 		Ok(Arc::new(Self {
+			services: Services {
+				config: args.depend::<config::Service>("config"),
+			},
 			db: Data {
 				userdevicetxnid_response: args.db["userdevicetxnid_response"].clone(),
 			},
-			servername_txnid_response_cache: Arc::new(SyncRwLock::new(HashMap::new())),
-			servername_txnid_active: Arc::new(SyncRwLock::new(HashMap::new())),
-			max_active_txns: args
-				.depend::<crate::config::Service>("config")
-				.max_concurrent_inbound_transactions,
+			federation_txn_state: Arc::new(SyncRwLock::new(HashMap::new())),
+			last_cleanup: AtomicU64::new(0),
 		}))
 	}
 
 	async fn clear_cache(&self) {
-		let mut state = self.servername_txnid_response_cache.write();
-		state.clear();
+		let mut state = self.federation_txn_state.write();
+		// Only clear cached entries, preserve active transactions
+		state.retain(|_, v| matches!(v, TxnState::Active(_)));
 	}
 
 	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
 }
 
 impl Service {
+	/// Returns the count of currently active (in-progress) transactions.
 	#[must_use]
 	pub fn txn_active_handle_count(&self) -> usize {
-		let state = self.servername_txnid_active.read();
-		state.len()
+		let state = self.federation_txn_state.read();
+		state
+			.values()
+			.filter(|v| matches!(v, TxnState::Active(_)))
+			.count()
 	}
 
 	pub fn add_client_txnid(
@@ -84,80 +138,172 @@ impl Service {
 		self.db.userdevicetxnid_response.qry(&key).await
 	}
 
-	/// Fetches a receiver channel for the given transaction, if any exists.
-	/// If the given txn is not active, None is returned.
-	#[must_use]
-	pub fn get_active_federation_txn(
-		&self,
-		key: &TxnKey,
-	) -> Option<Receiver<WrappedTransactionResponse>> {
-		let state = self.servername_txnid_active.read();
-		state.get(key).cloned()
-	}
+	/// Atomically gets a cached response, joins an active transaction, or
+	/// starts a new one.
+	pub fn get_or_start_federation_txn(&self, key: TxnKey) -> Result<FederationTxnState> {
+		// Only one upgradable lock can be held at a time, and there aren't any
+		// read-only locks, so no point being upgradable
+		let mut state = self.federation_txn_state.write();
 
-	/// Starts a new inbound transaction handler, returning the appropriate
-	/// sender to broadcast the response via.
-	///
-	/// If the given key is already active, a rate-limited response is returned.
-	pub fn start_federation_txn(
-		&self,
-		key: TxnKey,
-	) -> Result<Sender<WrappedTransactionResponse>> {
-		let mut state = self.servername_txnid_active.write();
-		if state.get(&key).is_some() {
-			debug!(
-				origin = ?key.0,
-				id = ?key.1,
-				"Origin re-sent already running transaction"
-			);
-			Err(Error::BadRequest(
-				LimitExceeded { retry_after: None },
-				"Transaction is already being handled",
-			))
-		} else if state.keys().any(|k| k.0 == key.0) {
+		// Check existing state for this key
+		if let Some(txn_state) = state.get(&key) {
+			return Ok(match txn_state {
+				| TxnState::Cached(cached) => FederationTxnState::Cached(cached.response.clone()),
+				| TxnState::Active(receiver) => FederationTxnState::Active(receiver.clone()),
+			});
+		}
+
+		// Check if another transaction from this origin is already running
+		let has_active_from_origin = state
+			.iter()
+			.any(|(k, v)| k.0 == key.0 && matches!(v, TxnState::Active(_)));
+
+		if has_active_from_origin {
 			debug_warn!(
 				origin = ?key.0,
 				"Got concurrent transaction request from an origin with an active transaction"
 			);
-			Err(Error::BadRequest(
+			return Err(Error::BadRequest(
 				LimitExceeded { retry_after: None },
 				"Still processing another transaction from this origin",
-			))
-		} else if state.len() >= self.max_active_txns {
+			));
+		}
+
+		let max_active_txns = self.services.config.max_concurrent_inbound_transactions;
+
+		// Check if we're at capacity
+		if state.len() >= max_active_txns
+			&& let active_count = state
+				.values()
+				.filter(|v| matches!(v, TxnState::Active(_)))
+				.count() && active_count >= max_active_txns
+		{
 			warn!(
-				active = state.len(),
-				max = self.max_active_txns,
+				active = active_count,
+				max = max_active_txns,
 				"Server is overloaded, dropping incoming transaction"
 			);
-			Err(Error::BadRequest(
+			return Err(Error::BadRequest(
 				LimitExceeded { retry_after: None },
 				"Server is overloaded, try again later",
-			))
-		} else {
-			let (tx, rx) = tokio::sync::watch::channel(None);
-			state.insert(key, rx);
-			Ok(tx)
+			));
+		}
+
+		// Start new transaction
+		let (sender, receiver) = tokio::sync::watch::channel(None);
+		state.insert(key, TxnState::Active(receiver.clone()));
+
+		Ok(FederationTxnState::Started { receiver, sender })
+	}
+
+	/// Finishes a transaction by transitioning it from active to cached state.
+	/// Additionally may trigger cleanup of old entries.
+	pub fn finish_federation_txn(
+		&self,
+		key: TxnKey,
+		sender: Sender<Option<send_transaction_message::v1::Response>>,
+		response: send_transaction_message::v1::Response,
+	) {
+		// Check if cleanup might be needed before acquiring the lock
+		let should_try_cleanup = self.should_try_cleanup();
+
+		let mut state = self.federation_txn_state.write();
+
+		// Explicitly set cached first so there is no gap where receivers get a closed
+		// channel
+		state.insert(
+			key,
+			TxnState::Cached(CachedTxnResponse {
+				response: response.clone(),
+				created: SystemTime::now(),
+			}),
+		);
+
+		// Inserting above dropped the map's own receiver, so every waiter may already
+		// be gone (e.g. it gave up waiting). `send` returns an error when no receiver
+		// is left; `send_replace` stores the value no matter how many remain.
+		sender.send_replace(Some(response));
+
+		// explicitly close
+		drop(sender);
+
+		// This task is dangling, we can try clean caches now
+		if should_try_cleanup {
+			self.cleanup_entries_locked(&mut state);
 		}
 	}
 
-	/// Finishes a transaction, removing it from the active txns registry.
-	pub fn finish_federation_txn(&self, key: &TxnKey) {
-		let mut state = self.servername_txnid_active.write();
+	/// Removes the entry for `key` without caching a response, so a later request
+	/// with the same key is treated as a new transaction.
+	pub fn remove_federation_txn(&self, key: &TxnKey) {
+		let mut state = self.federation_txn_state.write();
 		state.remove(key);
 	}
 
-	/// Gets a cached transaction response, if the given key has a value.
-	#[must_use]
-	pub fn get_cached_txn(&self, key: &TxnKey) -> Option<send_transaction_message::v1::Response> {
-		let state = self.servername_txnid_response_cache.read();
-		state.get(key).cloned()
+	/// Checks if enough time has passed since the last cleanup to consider
+	/// running another. Updates the last cleanup time if returning true.
+	fn should_try_cleanup(&self) -> bool {
+		let now = SystemTime::now()
+			.duration_since(SystemTime::UNIX_EPOCH)
+			.expect("SystemTime before UNIX_EPOCH")
+			.as_secs();
+		let last = self.last_cleanup.load(Ordering::Relaxed);
+
+		if now.saturating_sub(last) >= CLEANUP_INTERVAL_SECS {
+			// CAS: only update if no one else has updated it since we read
+			self.last_cleanup
+				.compare_exchange(last, now, Ordering::Relaxed, Ordering::Relaxed)
+				.is_ok()
+		} else {
+			false
+		}
 	}
 
-	/// Sets a cached transaction response. The existing key will be overwritten
-	/// if it exists.
-	pub fn set_cached_txn(&self, key: TxnKey, response: send_transaction_message::v1::Response) {
-		let mut state = self.servername_txnid_response_cache.write();
-		// TODO: time-to-live?
-		state.insert(key, response);
+	/// Cleans up cached entries based on age and count limits.
+	///
+	/// First removes all cached entries older than the configured max age.
+	/// Then, if the cache still exceeds the max entry count, removes the oldest
+	/// cached entries until the count is within limits.
+	///
+	/// Must be called with write lock held on the state map.
+	fn cleanup_entries_locked(&self, state: &mut HashMap<TxnKey, TxnState>) {
+		let max_age_secs = self.services.config.transaction_id_cache_max_age_secs;
+		let max_entries = self.services.config.transaction_id_cache_max_entries;
+
+		// First pass: remove all cached entries older than max age
+		let cutoff = SystemTime::now()
+			.checked_sub(Duration::from_secs(max_age_secs))
+			.unwrap_or(SystemTime::UNIX_EPOCH);
+
+		state.retain(|_, v| match v {
+			| TxnState::Active(_) => true, // Never remove active transactions
+			| TxnState::Cached(cached) => cached.created > cutoff,
+		});
+
+		// Count cached entries
+		let cached_count = state
+			.values()
+			.filter(|v| matches!(v, TxnState::Cached(_)))
+			.count();
+
+		// Second pass: if still over max entries, remove oldest cached entries
+		if cached_count > max_entries {
+			let excess = cached_count.saturating_sub(max_entries);
+
+			// Collect cached entries sorted by age (oldest first)
+			let mut cached_entries: Vec<_> = state
+				.iter()
+				.filter_map(|(k, v)| match v {
+					| TxnState::Cached(cached) => Some((k.clone(), cached.created)),
+					| TxnState::Active(_) => None,
+				})
+				.collect();
+			cached_entries.sort_unstable_by_key(|(_, created)| *created);
+
+			// Remove the oldest cached entries to get under the limit
+			for (key, _) in cached_entries.into_iter().take(excess) {
+				state.remove(&key);
+			}
+		}
 	}
 }