diff --git a/src/service/transaction_ids/mod.rs b/src/service/transaction_ids/mod.rs index 9707dbb9..1c78396f 100644 --- a/src/service/transaction_ids/mod.rs +++ b/src/service/transaction_ids/mod.rs @@ -1,28 +1,34 @@ use std::{collections::HashMap, sync::Arc}; -use conduwuit::{Result, SyncRwLock}; +use async_trait::async_trait; +use conduwuit::{Error, Result, SyncRwLock}; use database::{Handle, Map}; use ruma::{ DeviceId, OwnedServerName, OwnedTransactionId, TransactionId, UserId, - api::federation::transactions::send_transaction_message, + api::{ + client::error::ErrorKind::LimitExceeded, + federation::transactions::send_transaction_message, + }, }; use tokio::sync::watch::{Receiver, Sender}; pub type TxnKey = (OwnedServerName, OwnedTransactionId); -pub type TxnChanType = (TxnKey, send_transaction_message::v1::Response); -pub type ActiveTxnsMapType = HashMap, Receiver)>; +pub type WrappedTransactionResponse = Option; +pub type ActiveTransactionsMap = HashMap>; pub struct Service { db: Data, - pub servername_txnid_response_cache: + servername_txnid_response_cache: Arc>>, - pub servername_txnid_active: Arc>, + servername_txnid_active: Arc>, + max_active_txns: usize, } struct Data { userdevicetxnid_response: Arc, } +#[async_trait] impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { @@ -31,9 +37,15 @@ impl crate::Service for Service { }, servername_txnid_response_cache: Arc::new(SyncRwLock::new(HashMap::new())), servername_txnid_active: Arc::new(SyncRwLock::new(HashMap::new())), + max_active_txns: 50, // TODO: fetch from config })) } + async fn clear_cache(&self) { + let mut state = self.servername_txnid_response_cache.write(); + state.clear(); + } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } @@ -63,4 +75,67 @@ impl Service { let key = (user_id, device_id, txn_id); self.db.userdevicetxnid_response.qry(&key).await } + + /// Fetches a receiver channel for the given transaction, if any exists. + /// If the given txn is not active, None is returned. + #[must_use] + pub fn get_active_federation_txn( + &self, + key: &TxnKey, + ) -> Option> { + let state = self.servername_txnid_active.read(); + state.get(key).cloned() + } + + /// Starts a new inbound transaction handler, returning the appropriate + /// sender to broadcast the response via. + /// + /// If the given key is already active, a rate-limited response is returned. + pub fn start_federation_txn( + &self, + key: TxnKey, + ) -> Result> { + let mut state = self.servername_txnid_active.write(); + if state.get(&key).is_some() { + Err(Error::BadRequest( + LimitExceeded { retry_after: None }, + "Transaction is already being handled", + )) + } else if state.keys().any(|k| k.0 == key.0) { + Err(Error::BadRequest( + LimitExceeded { retry_after: None }, + "Still processing another transaction from this origin", + )) + } else if state.len() >= self.max_active_txns { + Err(Error::BadRequest( + LimitExceeded { retry_after: None }, + "Server is overloaded, try again later", + )) + } else { + let (tx, rx) = tokio::sync::watch::channel(None); + state.insert(key, rx); + Ok(tx) + } + } + + /// Finishes a transaction, removing it from the active txns registry. + pub fn finish_federation_txn(&self, key: &TxnKey) { + let mut state = self.servername_txnid_active.write(); + state.remove(key); + } + + /// Gets a cached transaction response, if the given key has a value. + #[must_use] + pub fn get_cached_txn(&self, key: &TxnKey) -> Option { + let state = self.servername_txnid_response_cache.read(); + state.get(key).cloned() + } + + /// Sets a cached transaction response. The existing key will be overwritten + /// if it exists. + pub fn set_cached_txn(&self, key: TxnKey, response: send_transaction_message::v1::Response) { + let mut state = self.servername_txnid_response_cache.write(); + // TODO: time-to-live? + state.insert(key, response); + } }