feat: Do more refactoring

This commit is contained in:
timedout 2026-01-20 01:27:05 +00:00
parent bd404e808c
commit c69e7c7d1b
No known key found for this signature in database
GPG key ID: 0FA334385D0B689F
8 changed files with 628 additions and 1538 deletions

View file

@ -1,31 +1,107 @@
//! Auth checks relevant to any event's `auth_events`.
//!
//! See: https://spec.matrix.org/v1.16/rooms/v12/#authorization-rules
use std::{
collections::{HashMap, HashSet},
future::Future,
};
use std::collections::{HashMap, HashSet};
use ruma::{EventId, OwnedEventId, RoomId, events::StateEventType};
use ruma::{
EventId, OwnedEventId, RoomId, UserId,
events::{
StateEventType, TimelineEventType,
room::member::{MembershipState, RoomMemberEventContent, ThirdPartyInvite},
},
};
use crate::{Event, EventTypeExt, Pdu, RoomVersion, matrix::StateKey, state_res::Error, warn};
// Checks for duplicate auth events in the `auth_events` field of an event.
// Note: the caller should already have all of the auth events fetched.
//
// If there are multiple auth events of the same type and state key, this
// returns an error. Otherwise, it returns a map of (type, state_key) to the
// corresponding auth event.
pub async fn check_duplicate_auth_events<E, Fut>(
/// For the given event `kind` what are the relevant auth events that are needed
/// to authenticate this `content`.
///
/// # Errors
///
/// This function will return an error if the supplied `content` is not a JSON
/// object.
pub fn auth_types_for_event(
room_version: &RoomVersion,
event_type: &TimelineEventType,
state_key: Option<&StateKey>,
sender: &UserId,
member_content: Option<RoomMemberEventContent>,
) -> serde_json::Result<Vec<(StateEventType, StateKey)>> {
if event_type == &TimelineEventType::RoomCreate {
// Create events never have auth events
return Ok(vec![]);
}
let mut auth_types = if room_version.room_ids_as_hashes {
vec![
StateEventType::RoomMember.with_state_key(sender.as_str()),
StateEventType::RoomPowerLevels.with_state_key(""),
]
} else {
// For room versions that do not use room IDs as hashes, include the
// RoomCreate event as an auth event.
vec![
StateEventType::RoomMember.with_state_key(sender.as_str()),
StateEventType::RoomPowerLevels.with_state_key(""),
StateEventType::RoomCreate.with_state_key(""),
]
};
if event_type == &TimelineEventType::RoomMember {
let member_content =
member_content.expect("member_content must be provided for RoomMember events");
// Include the target's membership (if available)
auth_types.push((
StateEventType::RoomMember,
state_key
.expect("state_key must be provided for RoomMember events")
.to_owned(),
));
if matches!(
member_content.membership,
MembershipState::Join | MembershipState::Invite | MembershipState::Knock
) {
// Include the join rules
auth_types.push(StateEventType::RoomJoinRules.with_state_key(""));
}
if matches!(member_content.membership, MembershipState::Invite) {
// If this is an invite, include the third party invite if it exists
if let Some(ThirdPartyInvite { signed, .. }) = member_content.third_party_invite {
auth_types
.push(StateEventType::RoomThirdPartyInvite.with_state_key(signed.token));
}
}
if matches!(member_content.membership, MembershipState::Join)
&& room_version.restricted_join_rules
{
// If this is a restricted join, include the authorizing user's membership
if let Some(authorizing_user) = member_content.join_authorized_via_users_server {
auth_types
.push(StateEventType::RoomMember.with_state_key(authorizing_user.as_str()));
}
}
}
Ok(auth_types)
}
/// Checks for duplicate auth events in the `auth_events` field of an event.
/// Note: the caller should already have all of the auth events fetched.
///
/// If there are multiple auth events of the same type and state key, this
/// returns an error. Otherwise, it returns a map of (type, state_key) to the
/// corresponding auth event.
pub async fn check_duplicate_auth_events<FE>(
auth_events: &[OwnedEventId],
fetch_event: impl Fn(&EventId) -> Fut + Send,
) -> Result<HashMap<(StateEventType, StateKey), E>, Error>
fetch_event: FE,
) -> Result<HashMap<(StateEventType, StateKey), Pdu>, Error>
where
Fut: Future<Output = Result<Option<E>, Error>> + Send,
E: Event + Send + Sync,
for<'a> &'a E: Event + Send,
FE: AsyncFn(&EventId) -> Result<Option<Pdu>, Error>,
{
let mut seen: HashMap<(StateEventType, StateKey), E> = HashMap::new();
let mut seen: HashMap<(StateEventType, StateKey), Pdu> = HashMap::new();
// Considering all of the event's auth events:
for auth_event_id in auth_events {
@ -79,23 +155,15 @@ pub fn check_unnecessary_auth_events(
// Checks that all provided auth events were not rejected previously.
//
// TODO: this is currently a no-op and always returns Ok(()).
pub fn check_all_auth_events_accepted<E>(
_auth_events: &HashMap<(StateEventType, StateKey), E>,
) -> Result<(), Error>
where
E: Event + Send + Sync,
for<'a> &'a E: Event + Send,
{
pub fn check_all_auth_events_accepted(
_auth_events: &HashMap<(StateEventType, StateKey), Pdu>,
) -> Result<(), Error> {
Ok(())
}
// Checks that all auth events are from the same room as the event being
// validated.
pub fn check_auth_same_room<E>(auth_events: &Vec<E>, room_id: &RoomId) -> bool
where
E: Event + Send + Sync,
for<'a> &'a E: Event + Send,
{
pub fn check_auth_same_room(auth_events: &Vec<Pdu>, room_id: &RoomId) -> bool {
for auth_event in auth_events {
if let Some(auth_room_id) = &auth_event.room_id() {
if auth_room_id.as_str() != room_id.as_str() {
@ -115,17 +183,15 @@ where
true
}
// Performs all auth event checks for the given event.
pub async fn check_auth_events<E, Fut>(
/// Performs all auth event checks for the given event.
pub async fn check_auth_events<FE>(
event: &Pdu,
room_id: &RoomId,
room_version: &RoomVersion,
fetch_event: impl Fn(&EventId) -> Fut + Send,
) -> Result<HashMap<(StateEventType, StateKey), E>, Error>
fetch_event: &FE,
) -> Result<HashMap<(StateEventType, StateKey), Pdu>, Error>
where
Fut: Future<Output = Result<Option<E>, Error>> + Send,
E: Event + Send + Sync,
for<'a> &'a E: Event + Send,
FE: AsyncFn(&EventId) -> Result<Option<Pdu>, Error>,
{
// If there are duplicate entries for a given type and state_key pair, reject.
let auth_events_map = check_duplicate_auth_events(&event.auth_events, fetch_event).await?;
@ -135,12 +201,19 @@ where
// If there are entries whose type and state_key dont match those specified by
// the auth events selection algorithm described in the server specification,
// reject.
let expected_auth_events = crate::state_res::auth_types_for_event(
event.kind(),
event.sender(),
event.state_key(),
event.content(),
let member_event_content = match event.kind() {
| TimelineEventType::RoomMember =>
Some(event.get_content::<RoomMemberEventContent>().map_err(|e| {
Error::InvalidPdu(format!("Failed to parse m.room.member content: {}", e))
})?),
| _ => None,
};
let expected_auth_events = auth_types_for_event(
room_version,
event.kind(),
event.state_key.as_ref(),
event.sender(),
member_event_content,
)?;
if let Err(e) = check_unnecessary_auth_events(&auth_events_set, &expected_auth_events) {
return Err(e);
@ -154,7 +227,7 @@ where
// If any event in auth_events has a room_id which does not match that of the
// event being authorised, reject.
let auth_event_refs: Vec<E> = auth_events_map.values().cloned().collect();
let auth_event_refs: Vec<Pdu> = auth_events_map.values().cloned().collect();
if !check_auth_same_room(&auth_event_refs, room_id) {
return Err(Error::InvalidPdu(
"One or more auth events are from a different room".to_owned(),

View file

@ -61,6 +61,7 @@ where
Ok(vec![create_event.sender().to_owned()])
} else {
// Have to check the event content
#[allow(deprecated)]
if let Some(creator) = content.creator {
Ok(vec![creator])
} else {

File diff suppressed because it is too large Load diff

View file

@ -14,7 +14,6 @@ use ruma::{
serde::Base64,
signatures::{PublicKeyMap, PublicKeySet, verify_json},
};
use serde::Deserializer;
use crate::{
Event, EventTypeExt, Pdu, RoomVersion,
@ -200,7 +199,123 @@ where
}
}
async fn check_invite_event<FE, FS>(
/// Checks a third-party invite is valid.
async fn check_third_party_invite(
target_current_membership: PartialMembershipObject,
raw_third_party_invite: &serde_json::Value,
target: &UserId,
event: &Pdu,
fetch_state: impl AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
) -> Result<(), Error> {
// 4.1.1: If target user is banned, reject.
if target_current_membership
.membership
.is_some_and(|m| m == "ban")
{
return Err(Error::AuthConditionFailed("invite target is banned".to_owned()));
}
// 4.1.2: If content.third_party_invite does not have a signed property, reject.
let signed = raw_third_party_invite.get("signed").ok_or_else(|| {
Error::AuthConditionFailed(
"invite event third_party_invite missing signed property".to_owned(),
)
})?;
// 4.2.3: If signed does not have mxid and token properties, reject.
let mxid = signed.get("mxid").and_then(|v| v.as_str()).ok_or_else(|| {
Error::AuthConditionFailed(
"invite event third_party_invite signed missing/invalid mxid property".to_owned(),
)
})?;
let token = signed
.get("token")
.and_then(|v| v.as_str())
.ok_or_else(|| {
Error::AuthConditionFailed(
"invite event third_party_invite signed missing token property".to_owned(),
)
})?;
// 4.2.4: If mxid does not match state_key, reject.
if mxid != target.as_str() {
return Err(Error::AuthConditionFailed(
"invite event third_party_invite signed mxid does not match state_key".to_owned(),
));
}
// 4.2.5: If there is no m.room.third_party_invite event in the room
// state matching the token, reject.
let Some(third_party_invite_event) =
fetch_state(StateEventType::RoomThirdPartyInvite.with_state_key(token)).await?
else {
return Err(Error::AuthConditionFailed(
"invite event third_party_invite token has no matching m.room.third_party_invite"
.to_owned(),
));
};
// 4.2.6: If sender does not match sender of the m.room.third_party_invite,
// reject.
if third_party_invite_event.sender() != event.sender() {
return Err(Error::AuthConditionFailed(
"invite event sender does not match m.room.third_party_invite sender".to_owned(),
));
}
// 4.2.7: If any signature in signed matches any public key in the
// m.room.third_party_invite event, allow. The public keys are in
// content of m.room.third_party_invite as:
// 1. A single public key in the public_key property.
// 2. A list of public keys in the public_keys property.
let tpi_content = third_party_invite_event
.get_content::<RoomThirdPartyInviteEventContent>()
.or_else(|_| {
Err(Error::InvalidPdu(
"m.room.third_party_invite event has invalid content".to_owned(),
))
})?;
let mut public_keys = tpi_content.public_keys.unwrap_or_default();
public_keys.push(PublicKey {
public_key: tpi_content.public_key,
key_validity_url: None,
});
let signatures = signed
.get("signatures")
.and_then(|v| v.as_object())
.ok_or_else(|| {
Error::InvalidPdu(
"invite event third_party_invite signed missing/invalid signatures".to_owned(),
)
})?;
let mut public_key_map = PublicKeyMap::new();
for (server_name, sig_map) in signatures {
let mut pk_set = PublicKeySet::new();
if let Some(sig_map) = sig_map.as_object() {
for (key_id, sig) in sig_map {
let sig_b64 = Base64::parse(sig.as_str().ok_or(Error::InvalidPdu(
"invite event third_party_invite signature is not a string".to_owned(),
))?)
.map_err(|_| {
Error::InvalidPdu(
"invite event third_party_invite signature is not valid Base64"
.to_owned(),
)
})?;
pk_set.insert(key_id.clone(), sig_b64);
}
}
public_key_map.insert(server_name.clone(), pk_set);
}
verify_json(
&public_key_map,
to_canonical_object(signed).expect("signed was already validated"),
)
.map_err(|e| {
Error::AuthConditionFailed(format!(
"invite event third_party_invite signature verification failed: {e}"
))
})?;
// If there was no error, there was a valid signature, so allow.
Ok(())
}
async fn check_invite_event<FS>(
room_version: &RoomVersion,
event: &Pdu,
membership: &PartialMembershipObject,
@ -208,120 +323,20 @@ async fn check_invite_event<FE, FS>(
fetch_state: &FS,
) -> Result<(), Error>
where
FE: AsyncFn(&EventId) -> Result<Option<Pdu>, Error>,
FS: AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
{
let target_current_membership = fetch_membership(fetch_state, target).await?;
// 4.1: If content has a third_party_invite property:
if let Some(raw_third_party_invite) = &membership.third_party_invite {
// 4.1.1: If target user is banned, reject.
if target_current_membership
.membership
.is_some_and(|m| m == "ban")
{
return Err(Error::AuthConditionFailed("invite target is banned".to_owned()));
}
// 4.1.2: If content.third_party_invite does not have a signed property, reject.
let signed = raw_third_party_invite.get("signed").ok_or_else(|| {
Error::AuthConditionFailed(
"invite event third_party_invite missing signed property".to_owned(),
)
})?;
// 4.2.3: If signed does not have mxid and token properties, reject.
let mxid = signed.get("mxid").and_then(|v| v.as_str()).ok_or_else(|| {
Error::AuthConditionFailed(
"invite event third_party_invite signed missing/invalid mxid property".to_owned(),
)
})?;
let token = signed
.get("token")
.and_then(|v| v.as_str())
.ok_or_else(|| {
Error::AuthConditionFailed(
"invite event third_party_invite signed missing token property".to_owned(),
)
})?;
// 4.2.4: If mxid does not match state_key, reject.
if mxid != target.as_str() {
return Err(Error::AuthConditionFailed(
"invite event third_party_invite signed mxid does not match state_key".to_owned(),
));
}
// 4.2.5: If there is no m.room.third_party_invite event in the room
// state matching the token, reject.
let Some(third_party_invite_event) =
fetch_state(StateEventType::RoomThirdPartyInvite.with_state_key(token)).await?
else {
return Err(Error::AuthConditionFailed(
"invite event third_party_invite token has no matching m.room.third_party_invite"
.to_owned(),
));
};
// 4.2.6: If sender does not match sender of the m.room.third_party_invite,
// reject.
if third_party_invite_event.sender() != event.sender() {
return Err(Error::AuthConditionFailed(
"invite event sender does not match m.room.third_party_invite sender".to_owned(),
));
}
// 4.2.7: If any signature in signed matches any public key in the
// m.room.third_party_invite event, allow. The public keys are in
// content of m.room.third_party_invite as:
// 1. A single public key in the public_key property.
// 2. A list of public keys in the public_keys property.
let tpi_content = third_party_invite_event
.get_content::<RoomThirdPartyInviteEventContent>()
.or_else(|_| {
Err(Error::InvalidPdu(
"m.room.third_party_invite event has invalid content".to_owned(),
))
})?;
let mut public_keys = tpi_content.public_keys.unwrap_or_default();
public_keys.push(PublicKey {
public_key: tpi_content.public_key,
key_validity_url: None,
});
let signatures = signed
.get("signatures")
.and_then(|v| v.as_object())
.ok_or_else(|| {
Error::InvalidPdu(
"invite event third_party_invite signed missing/invalid signatures"
.to_owned(),
)
})?;
let mut public_key_map = PublicKeyMap::new();
for (server_name, sig_map) in signatures {
let mut pk_set = PublicKeySet::new();
if let Some(sig_map) = sig_map.as_object() {
for (key_id, sig) in sig_map {
let sig_b64 = Base64::parse(sig.as_str().ok_or(Error::InvalidPdu(
"invite event third_party_invite signature is not a string".to_owned(),
))?)
.map_err(|_| {
Error::InvalidPdu(
"invite event third_party_invite signature is not valid Base64"
.to_owned(),
)
})?;
pk_set.insert(key_id.clone(), sig_b64);
}
}
public_key_map.insert(server_name.clone(), pk_set);
}
verify_json(
&public_key_map,
to_canonical_object(signed).expect("signed was already validated"),
return check_third_party_invite(
target_current_membership,
raw_third_party_invite,
target,
event,
fetch_state,
)
.map_err(|e| {
Error::AuthConditionFailed(format!(
"invite event third_party_invite signature verification failed: {e}"
))
})?;
// If there was no error, there was a valid signature, so allow.
return Ok(());
.await;
}
// 4.2: If the senders current membership state is not join, reject.
@ -354,7 +369,7 @@ where
}
pub async fn check_member_event<FE, FS>(
room_version: RoomVersion,
room_version: &RoomVersion,
event: &Pdu,
fetch_event: FE,
fetch_state: FS,
@ -395,10 +410,10 @@ where
match content.membership.as_deref().unwrap() {
| "join" =>
check_join_event(&room_version, event, &content, &target, &fetch_event, &fetch_state)
check_join_event(room_version, event, &content, &target, &fetch_event, &fetch_state)
.await?,
| "invite" =>
check_invite_event(&room_version, event, &content, &target, &fetch_state).await?,
check_invite_event(room_version, event, &content, &target, &fetch_state).await?,
| _ => {
todo!()
},

View file

@ -3,3 +3,4 @@ mod context;
pub mod create_event;
pub mod iterative_auth_checks;
pub mod member_event;
mod power_levels;

View file

@ -0,0 +1,157 @@
use ruma::{OwnedUserId, events::room::power_levels::RoomPowerLevelsEventContent};
use crate::{
Event, Pdu, RoomVersion,
state_res::{Error, event_auth::context::UserPower},
};
/// Verifies that a m.room.power_levels event is well-formed according to the
/// Matrix specification.
///
/// Creators must contain the m.room.create sender and any additional creators.
pub async fn check_power_levels(
room_version: &RoomVersion,
event: &Pdu,
current_power_levels: Option<&RoomPowerLevelsEventContent>,
creators: Vec<OwnedUserId>,
) -> Result<(), Error> {
let content = event
.get_content::<RoomPowerLevelsEventContent>()
.map_err(|e| {
Error::InvalidPdu(format!("m.room.power_levels event has invalid content: {}", e))
})?;
// If any of the properties users_default, events_default, state_default, ban,
// redact, kick, or invite in content are present and not an integer, reject.
//
// If either of the properties events or notifications in content are present
// and not an object with values that are integers, reject.
//
// NOTE: Deserialisation fails if this is not the case, so we don't need to
// check these here.
// If the users property in content is not an object with keys that are valid
// user IDs with values that are integers (or a string that is an integer),
// reject.
while let Some(user_id) = content.users.keys().next() {
// NOTE: Deserialisation fails if the power level is not an integer, so we don't
// need to check that here.
if let Err(e) = user_id.validate_historical() {
return Err(Error::InvalidPdu(format!(
"m.room.power_levels event has invalid user ID in users map: {}",
e
)));
}
// Since v12, If the users property in content contains the sender of the
// m.room.create event or any of the additional_creators array (if present)
// from the content of the m.room.create event, reject.
if room_version.explicitly_privilege_room_creators && creators.contains(user_id) {
return Err(Error::InvalidPdu(
"m.room.power_levels event users map contains a room creator".to_string(),
));
}
}
// If there is no previous m.room.power_levels event in the room, allow.
if current_power_levels.is_none() {
return Ok(());
}
let current_power_levels = current_power_levels.unwrap();
// For the properties users_default, events_default, state_default, ban, redact,
// kick, invite check if they were added, changed or removed. For each found
// alteration:
// If the current value is higher than the senders current power level, reject.
// If the new value is higher than the senders current power level, reject.
let sender = event.sender();
let rank = if room_version.explicitly_privilege_room_creators {
if creators.contains(&sender.to_owned()) {
UserPower::Creator
} else {
UserPower::Standard
}
} else {
UserPower::Standard
};
let sender_pl = current_power_levels
.users
.get(sender)
.unwrap_or(&current_power_levels.users_default);
if rank != UserPower::Creator {
let checks = [
("users_default", current_power_levels.users_default, content.users_default),
("events_default", current_power_levels.events_default, content.events_default),
("state_default", current_power_levels.state_default, content.state_default),
("ban", current_power_levels.ban, content.ban),
("redact", current_power_levels.redact, content.redact),
("kick", current_power_levels.kick, content.kick),
("invite", current_power_levels.invite, content.invite),
];
for (name, old_value, new_value) in checks.iter() {
if old_value != new_value {
if *old_value > *sender_pl {
return Err(Error::AuthConditionFailed(format!(
"sender cannot change level for {}",
name
)));
}
if *new_value > *sender_pl {
return Err(Error::AuthConditionFailed(format!(
"sender cannot raise level for {} to {}",
name, new_value
)));
}
}
}
// For each entry being changed in, or removed from, the events
// property:
// If the current value is greater than the senders current power level,
// reject.
for (event_type, new_value) in content.events.iter() {
let old_value = current_power_levels.events.get(event_type);
if old_value != Some(new_value) {
let old_pl = old_value.unwrap_or(&current_power_levels.events_default);
if *old_pl > *sender_pl {
return Err(Error::AuthConditionFailed(format!(
"sender cannot change event level for {}",
event_type
)));
}
if *new_value > *sender_pl {
return Err(Error::AuthConditionFailed(format!(
"sender cannot raise event level for {} to {}",
event_type, new_value
)));
}
}
}
// For each entry being changed in, or removed from, the events or
// notifications properties:
// If the current value is greater than the senders current power
// level, reject.
// If the new value is greater than the senders current power level,
// reject.
// TODO after making ruwuma's notifications value a BTreeMap
// For each entry being added to, or changed in, the users property:
// If the new value is greater than the senders current power level, reject.
for (user_id, new_value) in content.users.iter() {
let old_value = current_power_levels.users.get(user_id);
if old_value != Some(new_value) {
if *new_value > *sender_pl {
return Err(Error::AuthConditionFailed(format!(
"sender cannot raise user level for {} to {}",
user_id, new_value
)));
}
}
}
}
Ok(())
}

View file

@ -15,7 +15,8 @@ use std::{
hash::{BuildHasher, Hash},
};
use futures::{Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, future};
use futures::{Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
use itertools::Itertools;
use ruma::{
EventId, Int, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomVersionId,
events::{
@ -28,14 +29,13 @@ use serde_json::from_str as from_json_str;
pub(crate) use self::error::Error;
use self::power_levels::PowerLevelsContentFields;
pub use self::{
event_auth::iterative_auth_checks::{auth_types_for_event, iterative_auth_check},
room_version::RoomVersion,
};
pub use self::{event_auth::iterative_auth_checks::auth_check, room_version::RoomVersion};
use crate::{
debug, debug_error, err,
Pdu, debug, err, error as log_error,
matrix::{Event, StateKey},
state_res::room_version::StateResolutionVersion,
state_res::{
event_auth::auth_events::auth_types_for_event, room_version::StateResolutionVersion,
},
trace,
utils::stream::{BroadbandExt, IterStream, ReadyExt, TryBroadbandExt, WidebandExt},
warn,
@ -72,23 +72,19 @@ type Result<T, E = Error> = crate::Result<T, E>;
/// event is part of the same room.
//#[tracing::instrument(level = "debug", skip(state_sets, auth_chain_sets,
//#[tracing::instrument(level event_fetch))]
pub async fn resolve<'a, Pdu, Sets, SetIter, Hasher, Fetch, FetchFut, Exists, ExistsFut>(
pub async fn resolve<'a, Sets, SetIter, Hasher, FE, Exists>(
room_version: &RoomVersionId,
state_sets: Sets,
auth_chain_sets: &'a [HashSet<OwnedEventId, Hasher>],
event_fetch: &Fetch,
event_fetch: &FE,
event_exists: &Exists,
) -> Result<StateMap<OwnedEventId>>
where
Fetch: Fn(OwnedEventId) -> FetchFut + Sync,
FetchFut: Future<Output = Option<Pdu>> + Send,
Exists: Fn(OwnedEventId) -> ExistsFut + Sync,
ExistsFut: Future<Output = bool> + Send,
FE: AsyncFn(OwnedEventId) -> Result<Option<Pdu>> + Sync,
Exists: AsyncFn(OwnedEventId) -> bool + Sync,
Sets: IntoIterator<IntoIter = SetIter> + Send,
SetIter: Iterator<Item = &'a StateMap<OwnedEventId>> + Clone + Send,
Hasher: BuildHasher + Send + Sync,
Pdu: Event + Clone + Send + Sync,
for<'b> &'b Pdu: Event + Send,
{
use RoomVersionId::*;
let stateres_version = match room_version {
@ -166,7 +162,7 @@ where
// Sequentially auth check each control event.
let resolved_control = iterative_auth_check(
&room_version,
sorted_control_levels.iter().stream().map(AsRef::as_ref),
sorted_control_levels.iter().stream().map(ToOwned::to_owned),
initial_state,
&event_fetch,
)
@ -206,7 +202,7 @@ where
let mut resolved_state = iterative_auth_check(
&room_version,
sorted_left_events.iter().stream().map(AsRef::as_ref),
sorted_left_events.iter().stream(),
resolved_control, // The control events are added to the final resolved state
&event_fetch,
)
@ -270,14 +266,12 @@ where
}
/// Calculate the conflicted subgraph
async fn calculate_conflicted_subgraph<F, Fut, E>(
async fn calculate_conflicted_subgraph<FE>(
conflicted: &StateMap<Vec<OwnedEventId>>,
fetch_event: &F,
fetch_event: &FE,
) -> Option<HashSet<OwnedEventId>>
where
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send + Sync,
FE: AsyncFn(OwnedEventId) -> Result<Option<Pdu>> + Sync,
{
let conflicted_events: HashSet<_> = conflicted.values().flatten().cloned().collect();
let mut subgraph: HashSet<OwnedEventId> = HashSet::new();
@ -309,7 +303,17 @@ where
continue;
}
trace!(event_id = event_id.as_str(), "fetching event for its auth events");
let evt = fetch_event(event_id.clone()).await;
let evt = fetch_event(event_id.clone())
.await
.inspect_err(|e| {
log_error!(
"error fetching event {} for conflicted state subgraph: {}",
event_id,
e
)
})
.ok()
.flatten();
if evt.is_none() {
err!("could not fetch event {} to calculate conflicted subgraph", event_id);
path.pop();
@ -356,15 +360,13 @@ where
/// The power level is negative because a higher power level is equated to an
/// earlier (further back in time) origin server timestamp.
#[tracing::instrument(level = "debug", skip_all)]
async fn reverse_topological_power_sort<E, F, Fut>(
async fn reverse_topological_power_sort<FE>(
events_to_sort: Vec<OwnedEventId>,
auth_diff: &HashSet<OwnedEventId>,
fetch_event: &F,
fetch_event: &FE,
) -> Result<Vec<OwnedEventId>>
where
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send + Sync,
FE: AsyncFn(OwnedEventId) -> Result<Option<Pdu>> + Sync,
{
debug!("reverse topological sort of power events");
@ -402,7 +404,7 @@ where
.ok_or_else(|| Error::NotFound(String::new()))?;
let ev = fetch_event(event_id)
.await
.await?
.ok_or_else(|| Error::NotFound(String::new()))?;
Ok((pl, ev.origin_server_ts()))
@ -541,18 +543,13 @@ where
/// Do NOT use this any where but topological sort, we find the power level for
/// the eventId at the eventId's generation (we walk backwards to `EventId`s
/// most recent previous power level event).
async fn get_power_level_for_sender<E, F, Fut>(
event_id: &EventId,
fetch_event: &F,
) -> serde_json::Result<Int>
async fn get_power_level_for_sender<FE>(event_id: &EventId, fetch_event: &FE) -> Result<Int>
where
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
FE: AsyncFn(OwnedEventId) -> Result<Option<Pdu>> + Sync,
{
debug!("fetch event ({event_id}) senders power level");
let event = fetch_event(event_id.to_owned()).await;
let event = fetch_event(event_id.to_owned()).await?;
let auth_events = event.as_ref().map(Event::auth_events);
@ -591,27 +588,23 @@ where
/// the the `fetch_event` closure and verify each event using the
/// `event_auth::auth_check` function.
#[tracing::instrument(level = "trace", skip_all)]
async fn iterative_auth_check<'a, E, F, Fut, S>(
async fn iterative_auth_check<FE, S>(
room_version: &RoomVersion,
events_to_check: S,
unconflicted_state: StateMap<OwnedEventId>,
fetch_event: &F,
fetch_event: &FE,
) -> Result<StateMap<OwnedEventId>>
where
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
S: Stream<Item = &'a EventId> + Send + 'a,
E: Event + Clone + Send + Sync,
for<'b> &'b E: Event + Send,
FE: AsyncFn(&EventId) -> Result<Option<Pdu>, Error> + Sync + Send,
S: Stream<Item = OwnedEventId> + Send,
{
debug!("starting iterative auth check");
let events_to_check: Vec<_> = events_to_check
.map(Result::Ok)
.broad_and_then(async |event_id| {
fetch_event(event_id.to_owned())
.await
.ok_or_else(|| Error::NotFound(format!("Failed to find {event_id}")))
.map(Ok::<OwnedEventId, Error>)
.broad_and_then(async |event_id| match fetch_event(&event_id).await {
| Ok(Some(e)) => Ok(e),
| _ => Err(Error::NotFound(format!("could not find {event_id}")))?,
})
.try_collect()
.boxed()
@ -624,16 +617,20 @@ where
let auth_event_ids: HashSet<OwnedEventId> = events_to_check
.iter()
.flat_map(|event: &E| event.auth_events().map(ToOwned::to_owned))
.flat_map(|event: &Pdu| event.auth_events().map(ToOwned::to_owned))
.collect();
trace!(set = ?auth_event_ids, "auth event IDs to fetch");
let auth_events: HashMap<OwnedEventId, E> = auth_event_ids
let auth_events: HashMap<OwnedEventId, Pdu> = auth_event_ids
.into_iter()
.stream()
.broad_filter_map(fetch_event)
.map(|auth_event| (auth_event.event_id().to_owned(), auth_event))
.broad_filter_map(async |event_id| {
fetch_event(&event_id)
.await
.map(|ev_opt| ev_opt.map(|ev| (event_id.clone(), ev)))
.unwrap_or_default()
})
.collect()
.boxed()
.await;
@ -652,29 +649,23 @@ where
.state_key()
.ok_or_else(|| Error::InvalidPdu("State event had no state key".to_owned()))?;
let member_event_content = match event.kind() {
| TimelineEventType::RoomMember =>
Some(event.get_content::<RoomMemberEventContent>().map_err(|e| {
Error::InvalidPdu(format!("Failed to parse m.room.member content: {}", e))
})?),
| _ => None,
};
let auth_types = auth_types_for_event(
event.event_type(),
event.sender(),
Some(state_key),
event.content(),
room_version,
event.kind(),
event.state_key().map(StateKey::from_str).as_ref(),
event.sender(),
member_event_content,
)?;
trace!(list = ?auth_types, event_id = event.event_id().as_str(), "auth types for event");
let mut auth_state = StateMap::new();
if room_version.room_ids_as_hashes {
trace!("room version uses hashed IDs, manually fetching create event");
let create_event_id_raw = event.room_id_or_hash().as_str().replace('!', "$");
let create_event_id = EventId::parse(&create_event_id_raw).map_err(|e| {
Error::InvalidPdu(format!(
"Failed to parse create event ID from room ID/hash: {e}"
))
})?;
let create_event = fetch_event(create_event_id.into())
.await
.ok_or_else(|| Error::NotFound("Failed to find create event".into()))?;
auth_state.insert(create_event.event_type().with_state_key(""), create_event);
}
let mut auth_state = StateMap::with_capacity(event.auth_events.len());
for aid in event.auth_events() {
if let Some(ev) = auth_events.get(aid) {
//TODO: synapse checks "rejected_reason" which is most likely related to
@ -700,7 +691,13 @@ where
if let Some(event) = auth_events.get(ev_id) {
Some((key, event.clone()))
} else {
Some((key, fetch_event(ev_id.clone()).await?))
match fetch_event(ev_id).await {
| Ok(Some(event)) => Some((key, event)),
| _ => {
warn!(event_id = ev_id.as_str(), "unable to fetch auth event");
None
},
}
}
})
.ready_for_each(|(key, event)| {
@ -712,30 +709,16 @@ where
debug!(event_id = event.event_id().as_str(), "Running auth checks");
// The key for this is (eventType + a state_key of the signed token not sender)
// so search for it
let current_third_party = auth_state.iter().find_map(|(_, pdu)| {
(*pdu.event_type() == TimelineEventType::RoomThirdPartyInvite).then_some(pdu)
});
let fetch_state = |ty: &StateEventType, key: &str| {
future::ready(
auth_state
.get(&ty.with_state_key(key))
.map(ToOwned::to_owned),
)
let fetch_state = async |t: (StateEventType, StateKey)| {
Ok(auth_state
.get(&t.0.with_state_key(t.1.as_str()))
.map(ToOwned::to_owned))
};
let auth_result = iterative_auth_check(
room_version,
&event,
current_third_party,
fetch_state,
&fetch_state(&StateEventType::RoomCreate, "")
.await
.expect("create event must exist"),
)
.await;
let create_event = fetch_state((StateEventType::RoomCreate, StateKey::new())).await?;
let auth_result =
auth_check(room_version, &event, fetch_event, &fetch_state, create_event.as_ref())
.await;
match auth_result {
| Ok(true) => {
@ -755,7 +738,7 @@ where
warn!("event {} failed the authentication check", event.event_id());
},
| Err(e) => {
debug_error!("event {} failed the authentication check: {e}", event.event_id());
log_error!("event {} failed the authentication check: {e}", event.event_id());
return Err(e);
},
}
@ -774,15 +757,13 @@ where
/// after the most recent are depth 0, the events before (with the first power
/// level as a parent) will be marked as depth 1. depth 1 is "older" than depth
/// 0.
async fn mainline_sort<E, F, Fut>(
async fn mainline_sort<FE>(
to_sort: &[OwnedEventId],
resolved_power_level: Option<OwnedEventId>,
fetch_event: &F,
fetch_event: &FE,
) -> Result<Vec<OwnedEventId>>
where
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Clone + Send + Sync,
FE: AsyncFn(OwnedEventId) -> Result<Option<Pdu>> + Sync,
{
debug!("mainline sort of events");
@ -797,13 +778,13 @@ where
mainline.push(p.clone());
let event = fetch_event(p.clone())
.await
.await?
.ok_or_else(|| Error::NotFound(format!("Failed to find {p}")))?;
pl = None;
for aid in event.auth_events() {
let ev = fetch_event(aid.to_owned())
.await
.await?
.ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?;
if is_type_and_key(&ev, &TimelineEventType::RoomPowerLevels, "") {
@ -824,7 +805,11 @@ where
.iter()
.stream()
.broad_filter_map(async |ev_id| {
fetch_event(ev_id.clone()).await.map(|event| (event, ev_id))
fetch_event(ev_id.clone())
.await
.ok()
.flatten()
.map(|event| (event, ev_id))
})
.broad_filter_map(|(event, ev_id)| {
get_mainline_depth(Some(event.clone()), &mainline_map, fetch_event)
@ -846,15 +831,13 @@ where
/// Get the mainline depth from the `mainline_map` or finds a power_level event
/// that has an associated mainline depth.
async fn get_mainline_depth<E, F, Fut>(
mut event: Option<E>,
async fn get_mainline_depth<FE>(
mut event: Option<Pdu>,
mainline_map: &HashMap<OwnedEventId, usize>,
fetch_event: &F,
fetch_event: &FE,
) -> Result<usize>
where
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send + Sync,
FE: AsyncFn(OwnedEventId) -> Result<Option<Pdu>> + Sync,
{
while let Some(sort_ev) = event {
debug!(event_id = sort_ev.event_id().as_str(), "mainline");
@ -867,7 +850,7 @@ where
event = None;
for aid in sort_ev.auth_events() {
let aev = fetch_event(aid.to_owned())
.await
.await?
.ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?;
if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") {
@ -880,20 +863,18 @@ where
Ok(0)
}
async fn add_event_and_auth_chain_to_graph<E, F, Fut>(
async fn add_event_and_auth_chain_to_graph<FE>(
graph: &mut HashMap<OwnedEventId, HashSet<OwnedEventId>>,
event_id: OwnedEventId,
auth_diff: &HashSet<OwnedEventId>,
fetch_event: &F,
fetch_event: &FE,
) where
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send + Sync,
FE: AsyncFn(OwnedEventId) -> Result<Option<Pdu>> + Sync,
{
let mut state = vec![event_id];
while let Some(eid) = state.pop() {
graph.entry(eid.clone()).or_default();
let event = fetch_event(eid.clone()).await;
let event = fetch_event(eid.clone()).await.ok().flatten();
let auth_events = event.as_ref().map(Event::auth_events).into_iter().flatten();
// Prefer the store to event as the store filters dedups the events
@ -912,14 +893,12 @@ async fn add_event_and_auth_chain_to_graph<E, F, Fut>(
}
}
async fn is_power_event_id<E, F, Fut>(event_id: &EventId, fetch: &F) -> bool
async fn is_power_event_id<FE>(event_id: &EventId, fetch: &FE) -> bool
where
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
FE: AsyncFn(OwnedEventId) -> Result<Option<Pdu>> + Sync,
{
match fetch(event_id.to_owned()).await.as_ref() {
| Some(state) => is_power_event(state),
| Ok(Some(state)) => is_power_event(state),
| _ => false,
}
}
@ -976,6 +955,7 @@ where
mod tests {
use std::collections::{HashMap, HashSet};
use itertools::Itertools;
use maplit::{hashmap, hashset};
use rand::seq::SliceRandom;
use ruma::{
@ -1031,7 +1011,7 @@ mod tests {
.await
.unwrap();
let resolved_power = super::iterative_auth_check(
let resolved_power = super::auth_check(
&RoomVersion::V6,
sorted_power_events.iter().map(AsRef::as_ref).stream(),
HashMap::new(), // unconflicted events

View file

@ -236,7 +236,7 @@ pub async fn create_hash_and_sign_event(
| _ => create_pdu.as_ref().unwrap().as_pdu(),
};
let auth_check = state_res::iterative_auth_check(
let auth_check = state_res::auth_check(
&room_version,
&pdu,
None, // TODO: third_party_invite