Compare commits

...
Sign in to create a new pull request.

2 commits

3 changed files with 68 additions and 4 deletions

View file

@ -4,7 +4,7 @@ use axum_extra::{
headers::{Authorization, authorization::Bearer}, headers::{Authorization, authorization::Bearer},
typed_header::TypedHeaderRejectionReason, typed_header::TypedHeaderRejectionReason,
}; };
use conduwuit::{Err, Error, Result, debug_error, err, warn}; use conduwuit::{Err, Error, Result, debug_error, debug_info, err, warn};
use futures::{ use futures::{
TryFutureExt, TryFutureExt,
future::{ future::{
@ -329,6 +329,14 @@ async fn auth_server(
return Err!(Request(Forbidden("Failed to verify X-Matrix signatures."))); return Err!(Request(Forbidden("Failed to verify X-Matrix signatures.")));
} }
if services.sending.server_is_offline(destination).await {
debug_info!(?destination, "server returned from being offline");
services
.sending
.mark_server_online(destination, false)
.await;
}
Ok(Auth { Ok(Auth {
origin: origin.to_owned().into(), origin: origin.to_owned().into(),
sender_user: None, sender_user: None,

View file

@ -5,6 +5,7 @@ mod dest;
mod sender; mod sender;
use std::{ use std::{
collections::HashSet,
fmt::Debug, fmt::Debug,
hash::{DefaultHasher, Hash, Hasher}, hash::{DefaultHasher, Hash, Hasher},
iter::once, iter::once,
@ -19,8 +20,8 @@ use conduwuit::{
warn, warn,
}; };
use futures::{FutureExt, Stream, StreamExt}; use futures::{FutureExt, Stream, StreamExt};
use ruma::{RoomId, ServerName, UserId, api::OutgoingRequest}; use ruma::{OwnedServerName, RoomId, ServerName, UserId, api::OutgoingRequest};
use tokio::{task, task::JoinSet}; use tokio::{sync::RwLock, task, task::JoinSet};
use self::data::Data; use self::data::Data;
pub use self::{ pub use self::{
@ -37,6 +38,7 @@ pub struct Service {
server: Arc<Server>, server: Arc<Server>,
services: Services, services: Services,
channels: Vec<(loole::Sender<Msg>, loole::Receiver<Msg>)>, channels: Vec<(loole::Sender<Msg>, loole::Receiver<Msg>)>,
pub offline_servers: RwLock<HashSet<OwnedServerName>>,
} }
struct Services { struct Services {
@ -52,6 +54,7 @@ struct Services {
account_data: Dep<account_data::Service>, account_data: Dep<account_data::Service>,
appservice: Dep<crate::appservice::Service>, appservice: Dep<crate::appservice::Service>,
pusher: Dep<pusher::Service>, pusher: Dep<pusher::Service>,
resolver: Dep<crate::resolver::Service>,
federation: Dep<federation::Service>, federation: Dep<federation::Service>,
} }
@ -96,9 +99,11 @@ impl crate::Service for Service {
account_data: args.depend::<account_data::Service>("account_data"), account_data: args.depend::<account_data::Service>("account_data"),
appservice: args.depend::<crate::appservice::Service>("appservice"), appservice: args.depend::<crate::appservice::Service>("appservice"),
pusher: args.depend::<pusher::Service>("pusher"), pusher: args.depend::<pusher::Service>("pusher"),
resolver: args.depend::<crate::resolver::Service>("resolver"),
federation: args.depend::<federation::Service>("federation"), federation: args.depend::<federation::Service>("federation"),
}, },
channels: (0..num_senders).map(|_| loole::unbounded()).collect(), channels: (0..num_senders).map(|_| loole::unbounded()).collect(),
offline_servers: RwLock::new(HashSet::new()),
})) }))
} }
@ -146,6 +151,8 @@ impl crate::Service for Service {
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
fn unconstrained(&self) -> bool { true } fn unconstrained(&self) -> bool { true }
/// Empties the set of servers currently marked as offline.
async fn clear_cache(&self) {
    let mut offline = self.offline_servers.write().await;
    offline.clear();
}
} }
impl Service { impl Service {
@ -379,6 +386,39 @@ impl Service {
let chans = self.channels.len().max(1); let chans = self.channels.len().max(1);
hash.overflowing_rem(chans).0 hash.overflowing_rem(chans).0
} }
/// Records `server` as offline by inserting it into `offline_servers`.
///
/// The write lock on the set is held only for the insert. Inserting a
/// server that is already present leaves the set unchanged.
pub async fn mark_server_offline(&self, server: OwnedServerName) {
    let mut offline = self.offline_servers.write().await;
    offline.insert(server);
}
/// Removes `server` from the offline set and, when it had been marked
/// offline and `skip_flush` is `false`, deletes its cached resolver entries
/// and dispatches a flush message for that federation destination.
///
/// With `skip_flush` set to `true` only the offline marker is cleared. A
/// failure to dispatch the flush message is logged and otherwise ignored.
pub async fn mark_server_online(&self, server: &ServerName, skip_flush: bool) {
    // The write guard is a temporary of this statement, so the lock is
    // released before any of the work below runs.
    let was_offline = self.offline_servers.write().await.remove(server);
    if skip_flush || !was_offline {
        return;
    }

    // Forget the cached destination and override for this server.
    self.services.resolver.cache.del_destination(server);
    self.services.resolver.cache.del_override(server);

    let flush = Msg {
        dest: Destination::Federation(server.to_owned()),
        event: SendingEvent::Flush,
        queue_id: Vec::new(),
    };
    if let Err(e) = self.dispatch(flush) {
        error!(
            ?server,
            ?e,
            "failed to dispatch flush message for server coming back online"
        );
    }
}
/// Reports whether `server` is present in the set of offline servers.
pub async fn server_is_offline(&self, server: &ServerName) -> bool {
    let offline = self.offline_servers.read().await;
    offline.contains(server)
}
} }
fn num_senders(args: &crate::Args<'_>) -> usize { fn num_senders(args: &crate::Args<'_>) -> usize {

View file

@ -9,6 +9,7 @@ use std::{
}; };
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use conduwuit::debug_warn;
use conduwuit_core::{ use conduwuit_core::{
Error, Event, Result, debug, err, error, Error, Event, Result, debug, err, error,
result::LogErr, result::LogErr,
@ -135,7 +136,13 @@ impl Service {
) { ) {
match response { match response {
| Ok(dest) => self.handle_response_ok(&dest, futures, statuses).await, | Ok(dest) => self.handle_response_ok(&dest, futures, statuses).await,
| Err((dest, e)) => Self::handle_response_err(dest, statuses, &e), | Err((dest, e)) => {
Self::handle_response_err(dest.clone(), statuses, &e);
if let Destination::Federation(server_name) = dest {
debug_warn!(?server_name, "marking server offline due to error: {e:?}");
self.mark_server_offline(server_name).await;
}
},
} }
} }
@ -180,6 +187,12 @@ impl Service {
} else { } else {
statuses.remove(dest); statuses.remove(dest);
} }
if let Destination::Federation(server_name) = dest {
self.mark_server_online(server_name, true).await;
// We skip the flush here because we were already able to contact
// the server, and have queued any pending events, and the
// resolver cache will be fine.
}
} }
#[allow(clippy::needless_pass_by_ref_mut)] #[allow(clippy::needless_pass_by_ref_mut)]
@ -190,6 +203,9 @@ impl Service {
futures: &mut SendingFutures<'a>, futures: &mut SendingFutures<'a>,
statuses: &mut CurTransactionStatus, statuses: &mut CurTransactionStatus,
) { ) {
if msg.event == SendingEvent::Flush {
statuses.remove(&msg.dest);
}
let iv = vec![(msg.queue_id, msg.event)]; let iv = vec![(msg.queue_id, msg.event)];
if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses).await { if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses).await {
if !events.is_empty() { if !events.is_empty() {