feat: Add helper functions for federation channels

This commit is contained in:
nexy7574 2026-02-21 00:35:48 +00:00 committed by timedout
parent 21a97cdd0b
commit 2f9956ddca
No known key found for this signature in database
GPG key ID: 0FA334385D0B689F

View file

@ -1,28 +1,34 @@
use std::{collections::HashMap, sync::Arc};
use conduwuit::{Result, SyncRwLock};
use async_trait::async_trait;
use conduwuit::{Error, Result, SyncRwLock};
use database::{Handle, Map};
use ruma::{
DeviceId, OwnedServerName, OwnedTransactionId, TransactionId, UserId,
api::federation::transactions::send_transaction_message,
api::{
client::error::ErrorKind::LimitExceeded,
federation::transactions::send_transaction_message,
},
};
use tokio::sync::watch::{Receiver, Sender};
pub type TxnKey = (OwnedServerName, OwnedTransactionId);
pub type TxnChanType = (TxnKey, send_transaction_message::v1::Response);
pub type ActiveTxnsMapType = HashMap<TxnKey, (Sender<TxnChanType>, Receiver<TxnChanType>)>;
pub type WrappedTransactionResponse = Option<send_transaction_message::v1::Response>;
pub type ActiveTransactionsMap = HashMap<TxnKey, Receiver<WrappedTransactionResponse>>;
pub struct Service {
db: Data,
pub servername_txnid_response_cache:
servername_txnid_response_cache:
Arc<SyncRwLock<HashMap<TxnKey, send_transaction_message::v1::Response>>>,
pub servername_txnid_active: Arc<SyncRwLock<ActiveTxnsMapType>>,
servername_txnid_active: Arc<SyncRwLock<ActiveTransactionsMap>>,
max_active_txns: usize,
}
struct Data {
userdevicetxnid_response: Arc<Map>,
}
#[async_trait]
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
@ -31,9 +37,15 @@ impl crate::Service for Service {
},
servername_txnid_response_cache: Arc::new(SyncRwLock::new(HashMap::new())),
servername_txnid_active: Arc::new(SyncRwLock::new(HashMap::new())),
max_active_txns: 50, // TODO: fetch from config
}))
}
async fn clear_cache(&self) {
let mut state = self.servername_txnid_response_cache.write();
state.clear();
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
@ -63,4 +75,67 @@ impl Service {
let key = (user_id, device_id, txn_id);
self.db.userdevicetxnid_response.qry(&key).await
}
/// Fetches a receiver channel for the given transaction, if any exists.
/// If the given txn is not active, None is returned.
#[must_use]
pub fn get_active_federation_txn(
&self,
key: &TxnKey,
) -> Option<Receiver<WrappedTransactionResponse>> {
let state = self.servername_txnid_active.read();
state.get(key).cloned()
}
/// Starts a new inbound transaction handler, returning the appropriate
/// sender to broadcast the response via.
///
/// If the given key is already active, a rate-limited response is returned.
pub fn start_federation_txn(
&self,
key: TxnKey,
) -> Result<Sender<WrappedTransactionResponse>> {
let mut state = self.servername_txnid_active.write();
if state.get(&key).is_some() {
Err(Error::BadRequest(
LimitExceeded { retry_after: None },
"Transaction is already being handled",
))
} else if state.keys().any(|k| k.0 == key.0) {
Err(Error::BadRequest(
LimitExceeded { retry_after: None },
"Still processing another transaction from this origin",
))
} else if state.len() >= self.max_active_txns {
Err(Error::BadRequest(
LimitExceeded { retry_after: None },
"Server is overloaded, try again later",
))
} else {
let (tx, rx) = tokio::sync::watch::channel(None);
state.insert(key, rx);
Ok(tx)
}
}
/// Finishes a transaction, removing it from the active txns registry.
pub fn finish_federation_txn(&self, key: &TxnKey) {
let mut state = self.servername_txnid_active.write();
state.remove(key);
}
/// Gets a cached transaction response, if the given key has a value.
#[must_use]
pub fn get_cached_txn(&self, key: &TxnKey) -> Option<send_transaction_message::v1::Response> {
let state = self.servername_txnid_response_cache.read();
state.get(key).cloned()
}
/// Sets a cached transaction response. The existing key will be overwritten
/// if it exists.
pub fn set_cached_txn(&self, key: TxnKey, response: send_transaction_message::v1::Response) {
let mut state = self.servername_txnid_response_cache.write();
// TODO: time-to-live?
state.insert(key, response);
}
}