Compare commits

...
Sign in to create a new pull request.

3 commits

Author SHA1 Message Date
timedout
6c96945b0a
chore: Add news fragment 2026-03-04 05:42:19 +00:00
timedout
6f103939df
feat: Update policy server implementation to be closer to stable MSC4284
Untested
2026-03-04 05:36:19 +00:00
timedout
b133955582
chore: Bump ruwuma to update PS definitions 2026-03-04 03:19:00 +00:00
8 changed files with 245 additions and 174 deletions

29
Cargo.lock generated
View file

@ -1113,6 +1113,7 @@ dependencies = [
"conduwuit_core", "conduwuit_core",
"conduwuit_database", "conduwuit_database",
"const-str", "const-str",
"ed25519-dalek",
"either", "either",
"futures", "futures",
"hickory-resolver", "hickory-resolver",
@ -1221,7 +1222,7 @@ dependencies = [
[[package]] [[package]]
name = "continuwuity-admin-api" name = "continuwuity-admin-api"
version = "0.1.0" version = "0.1.0"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=bb12ed288a31a23aa11b10ba0fad22b7f985eb88#bb12ed288a31a23aa11b10ba0fad22b7f985eb88" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=6c65b295d2c109ab31165c2db016097f3e74d02e#6c65b295d2c109ab31165c2db016097f3e74d02e"
dependencies = [ dependencies = [
"ruma-common", "ruma-common",
"serde", "serde",
@ -1600,7 +1601,7 @@ dependencies = [
[[package]] [[package]]
name = "draupnir-antispam" name = "draupnir-antispam"
version = "0.1.0" version = "0.1.0"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=bb12ed288a31a23aa11b10ba0fad22b7f985eb88#bb12ed288a31a23aa11b10ba0fad22b7f985eb88" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=6c65b295d2c109ab31165c2db016097f3e74d02e#6c65b295d2c109ab31165c2db016097f3e74d02e"
dependencies = [ dependencies = [
"ruma-common", "ruma-common",
"serde", "serde",
@ -3002,7 +3003,7 @@ checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79"
[[package]] [[package]]
name = "meowlnir-antispam" name = "meowlnir-antispam"
version = "0.1.0" version = "0.1.0"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=bb12ed288a31a23aa11b10ba0fad22b7f985eb88#bb12ed288a31a23aa11b10ba0fad22b7f985eb88" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=6c65b295d2c109ab31165c2db016097f3e74d02e#6c65b295d2c109ab31165c2db016097f3e74d02e"
dependencies = [ dependencies = [
"ruma-common", "ruma-common",
"serde", "serde",
@ -4096,7 +4097,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma" name = "ruma"
version = "0.10.1" version = "0.10.1"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=bb12ed288a31a23aa11b10ba0fad22b7f985eb88#bb12ed288a31a23aa11b10ba0fad22b7f985eb88" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=6c65b295d2c109ab31165c2db016097f3e74d02e#6c65b295d2c109ab31165c2db016097f3e74d02e"
dependencies = [ dependencies = [
"assign", "assign",
"continuwuity-admin-api", "continuwuity-admin-api",
@ -4119,7 +4120,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-appservice-api" name = "ruma-appservice-api"
version = "0.10.0" version = "0.10.0"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=bb12ed288a31a23aa11b10ba0fad22b7f985eb88#bb12ed288a31a23aa11b10ba0fad22b7f985eb88" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=6c65b295d2c109ab31165c2db016097f3e74d02e#6c65b295d2c109ab31165c2db016097f3e74d02e"
dependencies = [ dependencies = [
"js_int", "js_int",
"ruma-common", "ruma-common",
@ -4131,7 +4132,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-client-api" name = "ruma-client-api"
version = "0.18.0" version = "0.18.0"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=bb12ed288a31a23aa11b10ba0fad22b7f985eb88#bb12ed288a31a23aa11b10ba0fad22b7f985eb88" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=6c65b295d2c109ab31165c2db016097f3e74d02e#6c65b295d2c109ab31165c2db016097f3e74d02e"
dependencies = [ dependencies = [
"as_variant", "as_variant",
"assign", "assign",
@ -4154,7 +4155,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-common" name = "ruma-common"
version = "0.13.0" version = "0.13.0"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=bb12ed288a31a23aa11b10ba0fad22b7f985eb88#bb12ed288a31a23aa11b10ba0fad22b7f985eb88" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=6c65b295d2c109ab31165c2db016097f3e74d02e#6c65b295d2c109ab31165c2db016097f3e74d02e"
dependencies = [ dependencies = [
"as_variant", "as_variant",
"base64 0.22.1", "base64 0.22.1",
@ -4186,7 +4187,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-events" name = "ruma-events"
version = "0.28.1" version = "0.28.1"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=bb12ed288a31a23aa11b10ba0fad22b7f985eb88#bb12ed288a31a23aa11b10ba0fad22b7f985eb88" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=6c65b295d2c109ab31165c2db016097f3e74d02e#6c65b295d2c109ab31165c2db016097f3e74d02e"
dependencies = [ dependencies = [
"as_variant", "as_variant",
"indexmap", "indexmap",
@ -4211,7 +4212,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-federation-api" name = "ruma-federation-api"
version = "0.9.0" version = "0.9.0"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=bb12ed288a31a23aa11b10ba0fad22b7f985eb88#bb12ed288a31a23aa11b10ba0fad22b7f985eb88" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=6c65b295d2c109ab31165c2db016097f3e74d02e#6c65b295d2c109ab31165c2db016097f3e74d02e"
dependencies = [ dependencies = [
"bytes", "bytes",
"headers", "headers",
@ -4233,7 +4234,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-identifiers-validation" name = "ruma-identifiers-validation"
version = "0.9.5" version = "0.9.5"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=bb12ed288a31a23aa11b10ba0fad22b7f985eb88#bb12ed288a31a23aa11b10ba0fad22b7f985eb88" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=6c65b295d2c109ab31165c2db016097f3e74d02e#6c65b295d2c109ab31165c2db016097f3e74d02e"
dependencies = [ dependencies = [
"js_int", "js_int",
"thiserror 2.0.18", "thiserror 2.0.18",
@ -4242,7 +4243,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-identity-service-api" name = "ruma-identity-service-api"
version = "0.9.0" version = "0.9.0"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=bb12ed288a31a23aa11b10ba0fad22b7f985eb88#bb12ed288a31a23aa11b10ba0fad22b7f985eb88" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=6c65b295d2c109ab31165c2db016097f3e74d02e#6c65b295d2c109ab31165c2db016097f3e74d02e"
dependencies = [ dependencies = [
"js_int", "js_int",
"ruma-common", "ruma-common",
@ -4252,7 +4253,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-macros" name = "ruma-macros"
version = "0.13.0" version = "0.13.0"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=bb12ed288a31a23aa11b10ba0fad22b7f985eb88#bb12ed288a31a23aa11b10ba0fad22b7f985eb88" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=6c65b295d2c109ab31165c2db016097f3e74d02e#6c65b295d2c109ab31165c2db016097f3e74d02e"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"proc-macro-crate", "proc-macro-crate",
@ -4267,7 +4268,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-push-gateway-api" name = "ruma-push-gateway-api"
version = "0.9.0" version = "0.9.0"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=bb12ed288a31a23aa11b10ba0fad22b7f985eb88#bb12ed288a31a23aa11b10ba0fad22b7f985eb88" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=6c65b295d2c109ab31165c2db016097f3e74d02e#6c65b295d2c109ab31165c2db016097f3e74d02e"
dependencies = [ dependencies = [
"js_int", "js_int",
"ruma-common", "ruma-common",
@ -4279,7 +4280,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-signatures" name = "ruma-signatures"
version = "0.15.0" version = "0.15.0"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=bb12ed288a31a23aa11b10ba0fad22b7f985eb88#bb12ed288a31a23aa11b10ba0fad22b7f985eb88" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=6c65b295d2c109ab31165c2db016097f3e74d02e#6c65b295d2c109ab31165c2db016097f3e74d02e"
dependencies = [ dependencies = [
"base64 0.22.1", "base64 0.22.1",
"ed25519-dalek", "ed25519-dalek",

View file

@ -344,7 +344,7 @@ version = "0.1.2"
[workspace.dependencies.ruma] [workspace.dependencies.ruma]
git = "https://forgejo.ellis.link/continuwuation/ruwuma" git = "https://forgejo.ellis.link/continuwuation/ruwuma"
#branch = "conduwuit-changes" #branch = "conduwuit-changes"
rev = "bb12ed288a31a23aa11b10ba0fad22b7f985eb88" rev = "6c65b295d2c109ab31165c2db016097f3e74d02e"
features = [ features = [
"compat", "compat",
"rand", "rand",

1
changelog.d/1487.feature Normal file
View file

@ -0,0 +1 @@
Updated [MSC4284: Policy Servers](https://github.com/matrix-org/matrix-spec-proposals/pull/4284) implementation to support the newly stabilised proposal. Contributed by @nex.

View file

@ -4,7 +4,9 @@ mod panic;
mod response; mod response;
mod serde; mod serde;
use std::{any::Any, borrow::Cow, convert::Infallible, sync::PoisonError}; use std::{any::Any, borrow::Cow, convert::Infallible, sync::PoisonError, time::Duration};
use ruma::api::client::error::{ErrorKind, RetryAfter::Delay};
pub use self::{err::visit, log::*}; pub use self::{err::visit, log::*};
@ -91,7 +93,7 @@ pub enum Error {
#[error("Arithmetic operation failed: {0}")] #[error("Arithmetic operation failed: {0}")]
Arithmetic(Cow<'static, str>), Arithmetic(Cow<'static, str>),
#[error("{0}: {1}")] #[error("{0}: {1}")]
BadRequest(ruma::api::client::error::ErrorKind, &'static str), //TODO: remove BadRequest(ErrorKind, &'static str), //TODO: remove
#[error("{0}")] #[error("{0}")]
BadServerResponse(Cow<'static, str>), BadServerResponse(Cow<'static, str>),
#[error(transparent)] #[error(transparent)]
@ -121,7 +123,7 @@ pub enum Error {
#[error("from {0}: {1}")] #[error("from {0}: {1}")]
Redaction(ruma::OwnedServerName, ruma::canonical_json::RedactionError), Redaction(ruma::OwnedServerName, ruma::canonical_json::RedactionError),
#[error("{0}: {1}")] #[error("{0}: {1}")]
Request(ruma::api::client::error::ErrorKind, Cow<'static, str>, http::StatusCode), Request(ErrorKind, Cow<'static, str>, http::StatusCode),
#[error(transparent)] #[error(transparent)]
Ruma(#[from] ruma::api::client::error::Error), Ruma(#[from] ruma::api::client::error::Error),
#[error(transparent)] #[error(transparent)]
@ -166,7 +168,7 @@ impl Error {
/// Returns the Matrix error code / error kind /// Returns the Matrix error code / error kind
#[inline] #[inline]
pub fn kind(&self) -> ruma::api::client::error::ErrorKind { pub fn kind(&self) -> ErrorKind {
use ruma::api::client::error::ErrorKind::{FeatureDisabled, Unknown}; use ruma::api::client::error::ErrorKind::{FeatureDisabled, Unknown};
match self { match self {
@ -201,6 +203,16 @@ impl Error {
/// Result where Ok(None) is instead Err(e) if e.is_not_found(). /// Result where Ok(None) is instead Err(e) if e.is_not_found().
#[inline] #[inline]
pub fn is_not_found(&self) -> bool { self.status_code() == http::StatusCode::NOT_FOUND } pub fn is_not_found(&self) -> bool { self.status_code() == http::StatusCode::NOT_FOUND }
pub fn retry_after(&self) -> Option<Duration> {
match self {
| Self::BadRequest(
ErrorKind::LimitExceeded { retry_after: Some(Delay(retry_after)) },
..,
) => Some(*retry_after),
| _ => None,
}
}
} }
impl std::fmt::Debug for Error { impl std::fmt::Debug for Error {

View file

@ -123,6 +123,7 @@ blurhash.workspace = true
blurhash.optional = true blurhash.optional = true
recaptcha-verify = { version = "0.1.5", default-features = false } recaptcha-verify = { version = "0.1.5", default-features = false }
yansi.workspace = true yansi.workspace = true
ed25519-dalek = "2.2.0"
[target.'cfg(all(unix, target_os = "linux"))'.dependencies] [target.'cfg(all(unix, target_os = "linux"))'.dependencies]
sd-notify.workspace = true sd-notify.workspace = true

View file

@ -6,18 +6,63 @@
use std::{collections::BTreeMap, time::Duration}; use std::{collections::BTreeMap, time::Duration};
use conduwuit::{ use conduwuit::{
Err, Event, PduEvent, Result, debug, debug_error, debug_info, debug_warn, implement, trace, Err, Error, Event, PduEvent, Result, debug, debug_error, debug_info, error, implement, info,
warn, trace, warn,
}; };
use ed25519_dalek::{Signature, Verifier, VerifyingKey};
use ruma::{ use ruma::{
CanonicalJsonObject, CanonicalJsonValue, KeyId, RoomId, ServerName, SigningKeyId, CanonicalJsonObject, CanonicalJsonValue, KeyId, RoomId, ServerName,
api::federation::room::{ api::federation::room::policy_sign::unstable::Request as PolicySignRequest,
policy_check::unstable::Request as PolicyCheckRequest,
policy_sign::unstable::Request as PolicySignRequest,
},
events::{StateEventType, room::policy::RoomPolicyEventContent}, events::{StateEventType, room::policy::RoomPolicyEventContent},
serde::{Base64, base64::UrlSafe},
signatures::canonical_json,
}; };
use serde_json::value::RawValue; use serde_json::value::RawValue;
use tokio::time::sleep;
pub(super) fn verify_policy_signature(
via: &ServerName,
ps_key: &Base64<UrlSafe, Vec<u8>>,
pdu_json: &CanonicalJsonObject,
) -> bool {
let signature = pdu_json
.get("signatures")
.and_then(|sigs| sigs.as_object())
.and_then(|sigs_map| sigs_map.get(via.as_str()))
.and_then(|sigs_for_server| sigs_for_server.as_object())
.and_then(|sigs_for_server_map| sigs_for_server_map.get("ed25519:policy_server"))
.and_then(|sig| sig.as_str())
.and_then(|sig_str| Base64::<UrlSafe, Vec<u8>>::parse(sig_str).ok())
.and_then(|sig_b64| {
Signature::from_slice(sig_b64.as_bytes())
.map(Some)
.unwrap_or(None)
});
let vk = match VerifyingKey::try_from(ps_key.as_bytes()) {
| Ok(vk) => vk,
| Err(e) => {
debug!(
error=%e,
"Failed to parse policy server public key; cannot verify signature"
);
return false;
},
};
let cj = match canonical_json(pdu_json.clone()) {
| Ok(cj) => cj,
| Err(e) => {
debug!(
error=%e,
"Failed to convert event JSON to canonical form; cannot verify policy server signature"
);
return false;
},
};
match signature {
| Some(ref sig) => vk.verify(cj.as_bytes(), sig).is_ok(),
| None => false,
}
}
/// Asks a remote policy server if the event is allowed. /// Asks a remote policy server if the event is allowed.
/// ///
@ -31,29 +76,24 @@ use serde_json::value::RawValue;
/// contacted for whatever reason, Err(e) is returned, which generally is a /// contacted for whatever reason, Err(e) is returned, which generally is a
/// fail-open operation. /// fail-open operation.
#[implement(super::Service)] #[implement(super::Service)]
#[tracing::instrument(skip(self, pdu, pdu_json, room_id), level = "info")] #[tracing::instrument(skip(self, pdu, pdu_json), level = "info")]
pub async fn ask_policy_server( pub async fn policy_server_allows_event(
&self, &self,
pdu: &PduEvent, pdu: &PduEvent,
pdu_json: &mut CanonicalJsonObject, pdu_json: &mut CanonicalJsonObject,
room_id: &RoomId, room_id: &RoomId,
incoming: bool, incoming: bool,
) -> Result<bool> { ) -> Result<()> {
if !self.services.server.config.enable_msc4284_policy_servers {
trace!("policy server checking is disabled");
return Ok(true); // don't ever contact policy servers
}
if *pdu.event_type() == StateEventType::RoomPolicy.into() { if *pdu.event_type() == StateEventType::RoomPolicy.into() {
debug!( debug!(
room_id = %room_id, room_id = %room_id,
event_type = ?pdu.event_type(), event_type = ?pdu.event_type(),
"Skipping spam check for policy server meta-event" "Skipping spam check for policy server meta-event"
); );
return Ok(true); return Ok(());
} }
let Ok(policyserver) = self let Ok(ps) = self
.services .services
.state_accessor .state_accessor
.room_state_get_content(room_id, &StateEventType::RoomPolicy, "") .room_state_get_content(room_id, &StateEventType::RoomPolicy, "")
@ -65,128 +105,144 @@ pub async fn ask_policy_server(
}) })
.map(|c: RoomPolicyEventContent| c) .map(|c: RoomPolicyEventContent| c)
else { else {
debug!("room has no policy server configured"); debug!("room has no policy server configured, skipping spam check");
return Ok(true); return Ok(());
}; };
if self.services.server.config.policy_server_check_own_events let ps_key = match ps.effective_key() {
&& !incoming | Ok(key) => key,
&& policyserver.public_key.is_none() | Err(e) => {
{ debug!(
// don't contact policy servers for locally generated events, but only when the error=%e,
// policy server does not require signatures "room has a policy server configured, but no valid public keys; skipping spam check"
trace!("won't contact policy server for locally generated event"); );
return Ok(true); return Ok(());
}
let via = match policyserver.via {
| Some(ref via) => ServerName::parse(via)?,
| None => {
trace!("No policy server configured for room {room_id}");
return Ok(true);
}, },
}; };
let Some(via) = ps
.via
.as_ref()
.and_then(|via| ServerName::parse(via).map(Some).unwrap_or(None))
else {
trace!("No via configured for room policy server, skipping spam check");
return Ok(());
};
if via.is_empty() { if via.is_empty() {
trace!("Policy server is empty for room {room_id}, skipping spam check"); trace!("Policy server is empty for room {room_id}, skipping spam check");
return Ok(true); return Ok(());
} }
if !self.services.state_cache.server_in_room(via, room_id).await { if !self.services.state_cache.server_in_room(via, room_id).await {
debug!( debug!(
via = %via, via = %via,
"Policy server is not in the room, skipping spam check" "Policy server is not in the room, skipping spam check"
); );
return Ok(true); return Ok(());
} }
if incoming {
// Verify the signature instead of calling a check
if verify_policy_signature(via, &ps_key, pdu_json) {
debug!(
via = %via,
"Event is incoming and has a valid policy server signature"
);
return Ok(());
}
debug_info!(
via = %via,
"Event is incoming but does not have a valid policy server signature; asking policy \
server to sign it now"
);
}
let outgoing = self let outgoing = self
.services .services
.sending .sending
.convert_to_outgoing_federation_event(pdu_json.clone()) .convert_to_outgoing_federation_event(pdu_json.clone())
.await; .await;
if policyserver.public_key.is_some() {
if !incoming { info!(
debug_info!(
via = %via,
outgoing = ?pdu_json,
"Getting policy server signature on event"
);
return self
.fetch_policy_server_signature(pdu, pdu_json, via, outgoing, room_id)
.await;
}
// for incoming events, is it signed by <via> with the key
// "ed25519:policy_server"?
if let Some(CanonicalJsonValue::Object(sigs)) = pdu_json.get("signatures") {
if let Some(CanonicalJsonValue::Object(server_sigs)) = sigs.get(via.as_str()) {
let wanted_key_id: &KeyId<ruma::SigningKeyAlgorithm, ruma::Base64PublicKey> =
SigningKeyId::parse("ed25519:policy_server")?;
if let Some(CanonicalJsonValue::String(_sig_value)) =
server_sigs.get(wanted_key_id.as_str())
{
// TODO: verify signature
}
}
}
debug!(
"Event is not local and has no policy server signature, performing legacy spam check"
);
}
debug_info!(
via = %via, via = %via,
"Checking event for spam with policy server via legacy check" "Asking policy server to sign event"
); );
let response = tokio::time::timeout( self.fetch_policy_server_signature(pdu, pdu_json, via, outgoing, room_id, ps_key, 0)
Duration::from_secs(self.services.server.config.policy_server_request_timeout), .await
self.services }
.sending #[allow(clippy::too_many_arguments)]
.send_federation_request(via, PolicyCheckRequest { #[implement(super::Service)]
event_id: pdu.event_id().to_owned(), async fn handle_policy_server_error(
pdu: Some(outgoing), &self,
}), error: Error,
) pdu: &PduEvent,
.await; pdu_json: &mut CanonicalJsonObject,
let response = match response { via: &ServerName,
| Ok(Ok(response)) => { outgoing: Box<RawValue>,
debug!("Response from policy server: {:?}", response); room_id: &RoomId,
response policy_server_key: Base64<UrlSafe, Vec<u8>>,
}, retries: u8,
| Ok(Err(e)) => { timeout: Duration,
) -> Result<()> {
if let Some(retry_after) = error.retry_after() {
if retries >= 3 {
warn!( warn!(
via = %via, via = %via,
event_id = %pdu.event_id(), event_id = %pdu.event_id(),
room_id = %room_id, room_id = %room_id,
"Failed to contact policy server: {e}" retries,
"Policy server rate-limited us too many times; giving up"
); );
// Network or policy server errors are treated as non-fatal: event is allowed by return Err(error); // Error should be passed to c2s
// default. }
return Err(e); let saturated = retry_after.min(timeout);
}, // ^ don't wait more than 60 seconds
| Err(elapsed) => {
warn!(
%via,
event_id = %pdu.event_id(),
%room_id,
%elapsed,
"Policy server request timed out after 10 seconds"
);
return Err!("Request to policy server timed out");
},
};
trace!("Recommendation from policy server was {}", response.recommendation);
if response.recommendation == "spam" {
warn!( warn!(
via = %via, via = %via,
event_id = %pdu.event_id(), event_id = %pdu.event_id(),
room_id = %room_id, room_id = %room_id,
"Event was marked as spam by policy server", retry_after = %saturated.as_secs(),
retries,
"Policy server rate-limited us; retrying after {retry_after:?}"
);
// TODO: select between this sleep and shutdown signal
sleep(saturated).await;
return Box::pin(self.fetch_policy_server_signature(
pdu,
pdu_json,
via,
outgoing,
room_id,
policy_server_key,
retries.saturating_add(1),
))
.await;
}
if error.status_code().is_client_error() {
warn!(
via = %via,
event_id = %pdu.event_id(),
room_id = %room_id,
error = ?error,
"Policy server marked the event as spam"
);
} else {
info!(
via = %via,
event_id = %pdu.event_id(),
room_id = %room_id,
error = ?error,
"Failed to contact policy server"
); );
return Ok(false);
} }
Ok(true) Err(error)
} }
/// Asks a remote policy server for a signature on this event. /// Asks a remote policy server for a signature on this event.
/// If the policy server signs this event, the original data is mutated. /// If the policy server signs this event, the original data is mutated.
#[allow(clippy::too_many_arguments)]
#[implement(super::Service)] #[implement(super::Service)]
#[tracing::instrument(skip_all, fields(event_id=%pdu.event_id(), via=%via), level = "info")] #[tracing::instrument(skip_all, fields(event_id=%pdu.event_id(), via=%via), level = "info")]
pub async fn fetch_policy_server_signature( pub async fn fetch_policy_server_signature(
@ -196,13 +252,16 @@ pub async fn fetch_policy_server_signature(
via: &ServerName, via: &ServerName,
outgoing: Box<RawValue>, outgoing: Box<RawValue>,
room_id: &RoomId, room_id: &RoomId,
) -> Result<bool> { policy_server_key: Base64<UrlSafe, Vec<u8>>,
retries: u8,
) -> Result<()> {
let timeout = Duration::from_secs(self.services.server.config.policy_server_request_timeout);
debug!("Requesting policy server signature"); debug!("Requesting policy server signature");
let response = tokio::time::timeout( let response = tokio::time::timeout(
Duration::from_secs(self.services.server.config.policy_server_request_timeout), timeout,
self.services self.services
.sending .sending
.send_federation_request(via, PolicySignRequest { pdu: outgoing }), .send_federation_request(via, PolicySignRequest { pdu: outgoing.clone() }),
) )
.await; .await;
@ -212,15 +271,19 @@ pub async fn fetch_policy_server_signature(
response response
}, },
| Ok(Err(e)) => { | Ok(Err(e)) => {
warn!( return self
via = %via, .handle_policy_server_error(
event_id = %pdu.event_id(), e,
room_id = %room_id, pdu,
"Failed to contact policy server: {e}" pdu_json,
); via,
// Network or policy server errors are treated as non-fatal: event is allowed by outgoing,
// default. room_id,
return Err(e); policy_server_key,
retries,
timeout,
)
.await;
}, },
| Err(elapsed) => { | Err(elapsed) => {
warn!( warn!(
@ -228,34 +291,34 @@ pub async fn fetch_policy_server_signature(
event_id = %pdu.event_id(), event_id = %pdu.event_id(),
%room_id, %room_id,
%elapsed, %elapsed,
"Policy server request timed out after 10 seconds" "Policy server signature request timed out"
); );
return Err!("Request to policy server timed out"); return Err!(Request(Forbidden("Policy server did not respond in time")));
}, },
}; };
if response.signatures.is_none() {
debug!("Policy server refused to sign event"); if !response.signatures.contains_key(via) {
return Ok(false); error!(
}
let sigs: ruma::Signatures<ruma::OwnedServerName, ruma::ServerSigningKeyVersion> =
response.signatures.unwrap();
if !sigs.contains_key(via) {
debug_warn!(
"Policy server returned signatures, but did not include the expected server name \ "Policy server returned signatures, but did not include the expected server name \
'{}': {:?}", '{}': {:?}",
via, via, response.signatures
sigs
); );
return Ok(false); return Err!(BadServerResponse(
"Policy server did not include expected server name in signatures"
));
} }
let keypairs = sigs.get(via).unwrap(); let keypairs = response.signatures.get(via).unwrap();
// TODO: need to be able to verify other algorithms
let wanted_key_id = KeyId::parse("ed25519:policy_server")?; let wanted_key_id = KeyId::parse("ed25519:policy_server")?;
if !keypairs.contains_key(wanted_key_id) { if !keypairs.contains_key(wanted_key_id) {
debug_warn!( error!(
signatures = ?response.signatures,
"Policy server returned signature, but did not use the key ID \ "Policy server returned signature, but did not use the key ID \
'ed25519:policy_server'." 'ed25519:policy_server'."
); );
return Ok(false); return Err!(BadServerResponse(
"Policy server signed the event, but did not use the expected key ID"
));
} }
let signatures_entry = pdu_json let signatures_entry = pdu_json
.entry("signatures".to_owned()) .entry("signatures".to_owned())
@ -273,12 +336,12 @@ pub async fn fetch_policy_server_signature(
); );
}, },
| Some(_) => { | Some(_) => {
debug_warn!( // This should never happen
unreachable!(
"Existing `signatures[{}]` field is not an object; cannot insert policy \ "Existing `signatures[{}]` field is not an object; cannot insert policy \
signature", signature",
via via
); );
return Ok(false);
}, },
| None => { | None => {
let mut inner = BTreeMap::new(); let mut inner = BTreeMap::new();
@ -293,11 +356,12 @@ pub async fn fetch_policy_server_signature(
signatures_map.insert(via.as_str().to_owned(), CanonicalJsonValue::Object(inner)); signatures_map.insert(via.as_str().to_owned(), CanonicalJsonValue::Object(inner));
}, },
} }
// TODO: verify signature value was made with the policy_server_key
// rather than the expected key.
} else { } else {
debug_warn!( unreachable!(
"Existing `signatures` field is not an object; cannot insert policy signature" "Existing `signatures` field is not an object; cannot insert policy signature"
); );
return Ok(false);
} }
Ok(true) Ok(())
} }

View file

@ -256,7 +256,7 @@ where
if incoming_pdu.state_key.is_none() { if incoming_pdu.state_key.is_none() {
debug!(event_id = %incoming_pdu.event_id, "Checking policy server for event"); debug!(event_id = %incoming_pdu.event_id, "Checking policy server for event");
match self match self
.ask_policy_server( .policy_server_allows_event(
&incoming_pdu, &incoming_pdu,
&mut incoming_pdu.to_canonical_object(), &mut incoming_pdu.to_canonical_object(),
room_id, room_id,
@ -264,9 +264,10 @@ where
) )
.await .await
{ {
| Ok(false) => { | Err(e) => {
warn!( warn!(
event_id = %incoming_pdu.event_id, event_id = %incoming_pdu.event_id,
error = %e,
"Event has been marked as spam by policy server" "Event has been marked as spam by policy server"
); );
soft_fail = true; soft_fail = true;

View file

@ -9,7 +9,6 @@ use conduwuit_core::{
state_res::{self, RoomVersion}, state_res::{self, RoomVersion},
}, },
utils::{self, IterStream, ReadyExt, stream::TryIgnore}, utils::{self, IterStream, ReadyExt, stream::TryIgnore},
warn,
}; };
use futures::{StreamExt, TryStreamExt, future, future::ready}; use futures::{StreamExt, TryStreamExt, future, future::ready};
use ruma::{ use ruma::{
@ -298,23 +297,15 @@ pub async fn create_hash_and_sign_event(
"Checking event in room {} with policy server", "Checking event in room {} with policy server",
pdu.room_id.as_ref().map_or("None", |id| id.as_str()) pdu.room_id.as_ref().map_or("None", |id| id.as_str())
); );
match self self.services
.services
.event_handler .event_handler
.ask_policy_server(&pdu, &mut pdu_json, pdu.room_id().expect("has room ID"), false) .policy_server_allows_event(
.await &pdu,
{ &mut pdu_json,
| Ok(true) => {}, pdu.room_id().expect("has room ID"),
| Ok(false) => { false,
return Err!(Request(Forbidden(debug_warn!( )
"Policy server marked this event as spam" .await?;
))));
},
| Err(e) => {
// fail open
warn!("Failed to check event with policy server: {e}");
},
}
} }
// Generate short event id // Generate short event id