Compare commits

...
Sign in to create a new pull request.

3 commits

Author SHA1 Message Date
timedout
07d35c6112
fix: Deserialisation error in public keys 2026-02-06 22:21:45 +00:00
timedout
2e04ae947f
fix: Signature verification 2026-01-30 00:31:44 +00:00
timedout
c4c1481d78
fix: Pull changes from the event auth refactor to fix third party invites 2026-01-29 22:16:16 +00:00
8 changed files with 273 additions and 490 deletions

1
Cargo.lock generated
View file

@ -1090,6 +1090,7 @@ dependencies = [
"core_affinity", "core_affinity",
"ctor", "ctor",
"cyborgtime", "cyborgtime",
"ed25519-dalek",
"either", "either",
"figment", "figment",
"futures", "futures",

View file

@ -657,7 +657,6 @@ async fn join_room_by_id_helper_remote(
let auth_check = state_res::event_auth::auth_check( let auth_check = state_res::event_auth::auth_check(
&state_res::RoomVersion::new(&room_version_id)?, &state_res::RoomVersion::new(&room_version_id)?,
&parsed_join_pdu, &parsed_join_pdu,
None, // TODO: third party invite
|k, s| state_fetch(k.clone(), s.into()), |k, s| state_fetch(k.clone(), s.into()),
&state_fetch(StateEventType::RoomCreate, "".into()) &state_fetch(StateEventType::RoomCreate, "".into())
.await .await

View file

@ -10,31 +10,31 @@ version.workspace = true
[lib] [lib]
path = "mod.rs" path = "mod.rs"
crate-type = [ crate-type = [
"rlib", "rlib",
# "dylib", # "dylib",
] ]
[features] [features]
brotli_compression = [ brotli_compression = [
"reqwest/brotli", "reqwest/brotli",
] ]
conduwuit_mods = [ conduwuit_mods = [
"dep:libloading" "dep:libloading"
] ]
gzip_compression = [ gzip_compression = [
"reqwest/gzip", "reqwest/gzip",
] ]
hardened_malloc = [ hardened_malloc = [
"dep:hardened_malloc-rs" "dep:hardened_malloc-rs"
] ]
jemalloc = [ jemalloc = [
"dep:tikv-jemalloc-sys", "dep:tikv-jemalloc-sys",
"dep:tikv-jemalloc-ctl", "dep:tikv-jemalloc-ctl",
"dep:tikv-jemallocator", "dep:tikv-jemallocator",
] ]
jemalloc_conf = [] jemalloc_conf = []
jemalloc_prof = [ jemalloc_prof = [
"tikv-jemalloc-sys/profiling", "tikv-jemalloc-sys/profiling",
] ]
jemalloc_stats = [ jemalloc_stats = [
"tikv-jemalloc-sys/stats", "tikv-jemalloc-sys/stats",
@ -43,10 +43,10 @@ jemalloc_stats = [
] ]
perf_measurements = [] perf_measurements = []
release_max_log_level = [ release_max_log_level = [
"tracing/max_level_trace", "tracing/max_level_trace",
"tracing/release_max_level_info", "tracing/release_max_level_info",
"log/max_level_trace", "log/max_level_trace",
"log/release_max_level_info", "log/release_max_level_info",
] ]
sentry_telemetry = [] sentry_telemetry = []
zstd_compression = [ zstd_compression = [
@ -110,6 +110,7 @@ tracing.workspace = true
url.workspace = true url.workspace = true
parking_lot.workspace = true parking_lot.workspace = true
lock_api.workspace = true lock_api.workspace = true
ed25519-dalek = "~2"
[target.'cfg(unix)'.dependencies] [target.'cfg(unix)'.dependencies]
nix.workspace = true nix.workspace = true

View file

@ -1,26 +1,36 @@
use std::{borrow::Borrow, collections::BTreeSet}; use std::{borrow::Borrow, collections::BTreeSet};
use ed25519_dalek::{Verifier, VerifyingKey};
use futures::{ use futures::{
Future, Future,
future::{OptionFuture, join, join3}, future::{OptionFuture, join, join3},
}; };
use itertools::Itertools;
use ruma::{ use ruma::{
Int, OwnedUserId, RoomVersionId, UserId, CanonicalJsonObject, Int, OwnedUserId, RoomVersionId, UserId,
canonical_json::to_canonical_value,
events::room::{ events::room::{
create::RoomCreateEventContent, create::RoomCreateEventContent,
join_rules::{JoinRule, RoomJoinRulesEventContent}, join_rules::{JoinRule, RoomJoinRulesEventContent},
member::{MembershipState, ThirdPartyInvite}, member::{MembershipState, ThirdPartyInvite},
power_levels::RoomPowerLevelsEventContent, power_levels::RoomPowerLevelsEventContent,
third_party_invite::RoomThirdPartyInviteEventContent, third_party_invite::{PublicKey, RoomThirdPartyInviteEventContent},
}, },
int, int,
serde::{Base64, Raw}, serde::{
Base64, Base64DecodeError, Raw,
base64::{Standard, UrlSafe},
},
signatures::{ParseError, VerificationError},
}; };
use serde::{ use serde::{
Deserialize, Deserialize,
de::{Error as _, IgnoredAny}, de::{Error as _, IgnoredAny},
}; };
use serde_json::{from_str as from_json_str, value::RawValue as RawJsonValue}; use serde_json::{
from_str as from_json_str, to_value,
value::{RawValue as RawJsonValue, to_raw_value},
};
use super::{ use super::{
Error, Event, Result, StateEventType, StateKey, TimelineEventType, Error, Event, Result, StateEventType, StateKey, TimelineEventType,
@ -30,7 +40,7 @@ use super::{
}, },
room_version::RoomVersion, room_version::RoomVersion,
}; };
use crate::{debug, error, trace, warn}; use crate::{debug, error, trace, utils::to_canonical_object, warn};
// FIXME: field extracting could be bundled for `content` // FIXME: field extracting could be bundled for `content`
#[derive(Deserialize)] #[derive(Deserialize)]
@ -157,15 +167,14 @@ pub fn auth_types_for_event(
pub async fn auth_check<E, F, Fut>( pub async fn auth_check<E, F, Fut>(
room_version: &RoomVersion, room_version: &RoomVersion,
incoming_event: &E, incoming_event: &E,
current_third_party_invite: Option<&E>,
fetch_state: F, fetch_state: F,
create_event: &E, create_event: &E,
) -> Result<bool, Error> ) -> Result<bool, Error>
where where
F: Fn(&StateEventType, &str) -> Fut + Send, F: Fn(&StateEventType, &str) -> Fut + Send + Sync,
Fut: Future<Output = Option<E>> + Send, Fut: Future<Output = Option<E>> + Send,
E: Event + Send + Sync, E: Event + Send + Sync,
for<'a> &'a E: Event + Send, for<'a> &'a E: Event + Send + Sync,
{ {
debug!( debug!(
event_id = %incoming_event.event_id(), event_id = %incoming_event.event_id(),
@ -415,13 +424,15 @@ where
sender, sender,
sender_member_event.as_ref(), sender_member_event.as_ref(),
incoming_event, incoming_event,
current_third_party_invite,
power_levels_event.as_ref(), power_levels_event.as_ref(),
join_rules_event.as_ref(), join_rules_event.as_ref(),
user_for_join_auth.as_deref(), user_for_join_auth.as_deref(),
&user_for_join_auth_membership, &user_for_join_auth_membership,
&room_create_event, &room_create_event,
)? { &fetch_state,
)
.await?
{
return Ok(false); return Ok(false);
} }
@ -658,23 +669,25 @@ where
/// event and the current State. /// event and the current State.
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
#[allow(clippy::cognitive_complexity)] #[allow(clippy::cognitive_complexity)]
fn valid_membership_change<E>( async fn valid_membership_change<F, Fut, E>(
room_version: &RoomVersion, room_version: &RoomVersion,
target_user: &UserId, target_user: &UserId,
target_user_membership_event: Option<&E>, target_user_membership_event: Option<&E>,
sender: &UserId, sender: &UserId,
sender_membership_event: Option<&E>, sender_membership_event: Option<&E>,
current_event: &E, current_event: &E,
current_third_party_invite: Option<&E>,
power_levels_event: Option<&E>, power_levels_event: Option<&E>,
join_rules_event: Option<&E>, join_rules_event: Option<&E>,
user_for_join_auth: Option<&UserId>, user_for_join_auth: Option<&UserId>,
user_for_join_auth_membership: &MembershipState, user_for_join_auth_membership: &MembershipState,
create_room: &E, create_room: &E,
fetch_state: &F,
) -> Result<bool> ) -> Result<bool>
where where
F: Fn(&StateEventType, &str) -> Fut + Send + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send + Sync, E: Event + Send + Sync,
for<'a> &'a E: Event + Send, for<'a> &'a E: Event + Send + Sync,
{ {
#[derive(Deserialize)] #[derive(Deserialize)]
struct GetThirdPartyInvite { struct GetThirdPartyInvite {
@ -950,68 +963,62 @@ where
| MembershipState::Invite => { | MembershipState::Invite => {
// If content has third_party_invite key // If content has third_party_invite key
trace!("starting target_membership=invite check"); trace!("starting target_membership=invite check");
match third_party_invite.and_then(|i| i.deserialize().ok()) { if let Some(third_party_invite) = third_party_invite {
| Some(tp_id) => let allow = verify_third_party_invite(
if target_user_current_membership == MembershipState::Ban { target_user_current_membership,
warn!(?target_user_membership_event_id, "Can't invite banned user"); &serde_json::to_value(third_party_invite)?,
false target_user,
} else { current_event,
let allow = verify_third_party_invite( fetch_state,
Some(target_user), )
sender, .await;
&tp_id, if !allow {
current_third_party_invite, warn!("Third party invite invalid");
); }
if !allow { return Ok(allow);
warn!("Third party invite invalid");
}
allow
},
| _ =>
if !sender_is_joined {
warn!(
%sender,
?sender_membership_event_id,
?sender_membership,
"sender cannot produce an invite without being joined to the room",
);
false
} else if matches!(
target_user_current_membership,
MembershipState::Join | MembershipState::Ban
) {
warn!(
?target_user_membership_event_id,
?target_user_current_membership,
"cannot invite a user who is banned or already joined",
);
false
} else {
let allow = sender_creator
|| sender_power
.filter(|&p| p >= &power_levels.invite)
.is_some();
if !allow {
warn!(
%sender,
has=?sender_power,
required=?power_levels.invite,
"sender does not have enough power to produce invites",
);
}
trace!(
%sender,
?sender_membership_event_id,
?sender_membership,
?target_user_membership_event_id,
?target_user_current_membership,
sender_pl=?sender_power,
required_pl=?power_levels.invite,
"allowing invite"
);
allow
},
} }
if !sender_is_joined {
warn!(
%sender,
?sender_membership_event_id,
?sender_membership,
"sender cannot produce an invite without being joined to the room",
);
return Ok(false);
} else if matches!(
target_user_current_membership,
MembershipState::Join | MembershipState::Ban
) {
warn!(
?target_user_membership_event_id,
?target_user_current_membership,
"cannot invite a user who is banned or already joined",
);
return Ok(false);
}
let allow = sender_creator
|| sender_power
.filter(|&p| p >= &power_levels.invite)
.is_some();
if !allow {
warn!(
%sender,
has=?sender_power,
required=?power_levels.invite,
"sender does not have enough power to produce invites",
);
}
trace!(
%sender,
?sender_membership_event_id,
?sender_membership,
?target_user_membership_event_id,
?target_user_current_membership,
sender_pl=?sender_power,
required_pl=?power_levels.invite,
"allowing invite"
);
return Ok(allow);
}, },
| MembershipState::Leave => { | MembershipState::Leave => {
let can_unban = if target_user_current_membership == MembershipState::Ban { let can_unban = if target_user_current_membership == MembershipState::Ban {
@ -1499,399 +1506,187 @@ fn get_send_level(
.unwrap_or_else(|| if state_key.is_some() { int!(50) } else { int!(0) }) .unwrap_or_else(|| if state_key.is_some() { int!(50) } else { int!(0) })
} }
fn verify_third_party_invite( fn verify_payload(pk: &[u8], sig: &[u8], c: &[u8]) -> Result<(), ruma::signatures::Error> {
target_user: Option<&UserId>, VerifyingKey::from_bytes(
sender: &UserId, pk.try_into()
tp_id: &ThirdPartyInvite, .map_err(|_| ParseError::PublicKey(ed25519_dalek::SignatureError::new()))?,
current_third_party_invite: Option<&impl Event>, )
) -> bool { .map_err(ParseError::PublicKey)?
// 1. Check for user being banned happens before this is called .verify(c, &sig.try_into().map_err(ParseError::Signature)?)
// checking for mxid and token keys is done by ruma when deserializing .map_err(VerificationError::Signature)
.map_err(ruma::signatures::Error::from)
}
// The state key must match the invitee /// Decodes a base64 string as either URL-safe or standard base64, as per the
if target_user != Some(&tp_id.signed.mxid) { /// spec. It attempts to decode urlsafe first.
fn decode_base64(content: &str) -> Result<Vec<u8>, Base64DecodeError> {
if let Ok(decoded) = Base64::<UrlSafe>::parse(content) {
Ok(decoded.as_bytes().to_vec())
} else {
Base64::<Standard>::parse(content).map(|v| v.as_bytes().to_vec())
}
}
fn get_public_keys(event: &CanonicalJsonObject) -> Vec<Vec<u8>> {
let mut public_keys = Vec::new();
if let Some(public_key) = event.get("public_key").and_then(|v| v.as_str()) {
if let Ok(v) = decode_base64(public_key) {
trace!(
encoded = public_key,
decoded = ?v,
"found public key in public_key property of m.room.third_party_invite event",
);
public_keys.push(v);
} else {
warn!("m.room.third_party_invite event has invalid public_key");
}
}
if let Some(keys) = event.get("public_keys").and_then(|v| v.as_array()) {
for key in keys {
if let Some(key_obj) = key.as_object() {
if let Some(public_key) = key_obj.get("public_key").and_then(|v| v.as_str()) {
if let Ok(v) = decode_base64(public_key) {
trace!(
encoded = public_key,
decoded = ?v,
"found public key in public_keys list of m.room.third_party_invite \
event",
);
public_keys.push(v);
} else {
warn!(
"m.room.third_party_invite event has invalid public_key in \
public_keys list"
);
}
} else {
warn!(
"m.room.third_party_invite event has entry in public_keys list missing \
public_key property"
);
}
} else {
warn!(
"m.room.third_party_invite event has invalid entry in public_keys list, \
expected object"
);
}
}
}
public_keys
}
/// Checks a third-party invite is valid.
async fn verify_third_party_invite<F, Fut, E>(
target_current_membership: MembershipState,
raw_third_party_invite: &serde_json::Value,
target: &UserId,
event: &E,
fetch_state: &F,
) -> bool
where
F: Fn(&StateEventType, &str) -> Fut + Send + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send + Sync,
for<'a> &'a E: Event + Send + Sync,
{
// 4.1.1: If target user is banned, reject.
if target_current_membership == MembershipState::Ban {
warn!("invite target is banned");
return false; return false;
} }
// 4.1.2: If content.third_party_invite does not have a signed property, reject.
let Some(signed) = raw_third_party_invite.get("signed") else {
warn!("invite event third_party_invite missing signed property");
return false;
};
// 4.2.3: If signed does not have mxid and token properties, reject.
let Some(mxid) = signed.get("mxid").and_then(|v| v.as_str()) else {
warn!("invite event third_party_invite signed missing/invalid mxid property");
return false;
};
let Some(token) = signed.get("token").and_then(|v| v.as_str()) else {
warn!("invite event third_party_invite signed missing token property");
return false;
};
// 4.2.4: If mxid does not match state_key, reject.
if mxid != target.as_str() {
warn!("invite event third_party_invite signed mxid does not match state_key");
return false;
}
// 4.2.5: If there is no m.room.third_party_invite event in the room
// state matching the token, reject.
let Some(third_party_invite_event) =
fetch_state(&StateEventType::RoomThirdPartyInvite, token).await
else {
warn!("invite event third_party_invite token has no matching m.room.third_party_invite");
return false;
};
// 4.2.6: If sender does not match sender of the m.room.third_party_invite,
// reject.
if third_party_invite_event.sender() != event.sender() {
warn!("invite event sender does not match m.room.third_party_invite sender");
return false;
}
// 4.2.7: If any signature in signed matches any public key in the
// m.room.third_party_invite event, allow. The public keys are in
// content of m.room.third_party_invite as:
// 1. A single public key in the public_key property.
// 2. A list of public keys in the public_keys property.
debug!(
"Fetching signatures in third-party-invite event {}",
third_party_invite_event.event_id()
);
trace!("third-party-invite event content: {}", third_party_invite_event.content().get());
// If there is no m.room.third_party_invite event in the current room state with let Some(signatures) = signed.get("signatures").and_then(|v| v.as_object()) else {
// state_key matching token, reject warn!("invite event third_party_invite signed missing/invalid signatures");
#[allow(clippy::manual_let_else)] return false;
let current_tpid = match current_third_party_invite {
| Some(id) => id,
| None => return false,
}; };
if current_tpid.state_key() != Some(&tp_id.signed.token) { for pk in get_public_keys(
return false; &to_canonical_object(third_party_invite_event.content())
} .expect("m.room.third_party_invite event content is not a JSON object"),
) {
if sender != current_tpid.sender() { // signatures -> { server_name: { ed25519:N: signature } }
return false; for (server_name, server_sigs) in signatures {
} trace!("Searching for signatures from {}", server_name);
if let Some(server_sigs) = server_sigs.as_object() {
// If any signature in signed matches any public key in the for (key_id, signature_value) in server_sigs {
// m.room.third_party_invite event, allow trace!("Checking signature with key id {}", key_id);
#[allow(clippy::manual_let_else)] if let Some(signature_str) = signature_value.as_str() {
let tpid_ev = if let Ok(signature) = decode_base64(signature_str) {
match from_json_str::<RoomThirdPartyInviteEventContent>(current_tpid.content().get()) { debug!(
| Ok(ev) => ev, %server_name,
| Err(_) => return false, %key_id,
}; "verifying third-party invite signature",
);
#[allow(clippy::manual_let_else)] match verify_payload(
let decoded_invite_token = match Base64::parse(&tp_id.signed.token) { &pk,
| Ok(tok) => tok, &signature,
// FIXME: Log a warning? serde_json::to_string(&to_canonical_value(signed).unwrap())
| Err(_) => return false, .unwrap()
}; .as_bytes(),
) {
// A list of public keys in the public_keys field | Ok(()) => {
for key in tpid_ev.public_keys.unwrap_or_default() { debug!("valid third-party invite signature found");
if key.public_key == decoded_invite_token { return true;
return true; },
| Err(e) => {
warn!(
%server_name,
%key_id,
"invalid third-party invite signature: {e}",
);
},
}
}
}
}
}
} }
} }
// A single public key in the public_key field warn!("no valid signature found for third-party invite");
tpid_ev.public_key == decoded_invite_token false
}
#[cfg(test)]
mod tests {
use ruma::events::{
StateEventType, TimelineEventType,
room::{
join_rules::{
AllowRule, JoinRule, Restricted, RoomJoinRulesEventContent, RoomMembership,
},
member::{MembershipState, RoomMemberEventContent},
},
};
use serde_json::value::to_raw_value as to_raw_json_value;
use crate::{
matrix::{Event, EventTypeExt, Pdu as PduEvent},
state_res::{
RoomVersion, StateMap,
event_auth::valid_membership_change,
test_utils::{
INITIAL_EVENTS, INITIAL_EVENTS_CREATE_ROOM, alice, charlie, ella, event_id,
member_content_ban, member_content_join, room_id, to_pdu_event,
},
},
};
#[test]
fn test_ban_pass() {
let _ = tracing::subscriber::set_default(
tracing_subscriber::fmt().with_test_writer().finish(),
);
let events = INITIAL_EVENTS();
let auth_events = events
.values()
.map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
.collect::<StateMap<_>>();
let requester = to_pdu_event(
"HELLO",
alice(),
TimelineEventType::RoomMember,
Some(charlie().as_str()),
member_content_ban(),
&[],
&["IMC"],
);
let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
let target_user = charlie();
let sender = alice();
assert!(
valid_membership_change(
&RoomVersion::V6,
target_user,
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender,
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester,
None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
None,
&MembershipState::Leave,
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
)
.unwrap()
);
}
#[test]
fn test_join_non_creator() {
let _ = tracing::subscriber::set_default(
tracing_subscriber::fmt().with_test_writer().finish(),
);
let events = INITIAL_EVENTS_CREATE_ROOM();
let auth_events = events
.values()
.map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
.collect::<StateMap<_>>();
let requester = to_pdu_event(
"HELLO",
charlie(),
TimelineEventType::RoomMember,
Some(charlie().as_str()),
member_content_join(),
&["CREATE"],
&["CREATE"],
);
let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
let target_user = charlie();
let sender = charlie();
assert!(
!valid_membership_change(
&RoomVersion::V6,
target_user,
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender,
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester,
None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
None,
&MembershipState::Leave,
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
)
.unwrap()
);
}
#[test]
fn test_join_creator() {
let _ = tracing::subscriber::set_default(
tracing_subscriber::fmt().with_test_writer().finish(),
);
let events = INITIAL_EVENTS_CREATE_ROOM();
let auth_events = events
.values()
.map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
.collect::<StateMap<_>>();
let requester = to_pdu_event(
"HELLO",
alice(),
TimelineEventType::RoomMember,
Some(alice().as_str()),
member_content_join(),
&["CREATE"],
&["CREATE"],
);
let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
let target_user = alice();
let sender = alice();
assert!(
valid_membership_change(
&RoomVersion::V6,
target_user,
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender,
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester,
None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
None,
&MembershipState::Leave,
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
)
.unwrap()
);
}
#[test]
fn test_ban_fail() {
let _ = tracing::subscriber::set_default(
tracing_subscriber::fmt().with_test_writer().finish(),
);
let events = INITIAL_EVENTS();
let auth_events = events
.values()
.map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
.collect::<StateMap<_>>();
let requester = to_pdu_event(
"HELLO",
charlie(),
TimelineEventType::RoomMember,
Some(alice().as_str()),
member_content_ban(),
&[],
&["IMC"],
);
let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
let target_user = alice();
let sender = charlie();
assert!(
!valid_membership_change(
&RoomVersion::V6,
target_user,
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender,
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester,
None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
None,
&MembershipState::Leave,
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
)
.unwrap()
);
}
#[test]
fn test_restricted_join_rule() {
let _ = tracing::subscriber::set_default(
tracing_subscriber::fmt().with_test_writer().finish(),
);
let mut events = INITIAL_EVENTS();
*events.get_mut(&event_id("IJR")).unwrap() = to_pdu_event(
"IJR",
alice(),
TimelineEventType::RoomJoinRules,
Some(""),
to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Restricted(
Restricted::new(vec![AllowRule::RoomMembership(RoomMembership::new(
room_id().to_owned(),
))]),
)))
.unwrap(),
&["CREATE", "IMA", "IPOWER"],
&["IPOWER"],
);
let mut member = RoomMemberEventContent::new(MembershipState::Join);
member.join_authorized_via_users_server = Some(alice().to_owned());
let auth_events = events
.values()
.map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
.collect::<StateMap<_>>();
let requester = to_pdu_event(
"HELLO",
ella(),
TimelineEventType::RoomMember,
Some(ella().as_str()),
to_raw_json_value(&RoomMemberEventContent::new(MembershipState::Join)).unwrap(),
&["CREATE", "IJR", "IPOWER", "new"],
&["new"],
);
let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
let target_user = ella();
let sender = ella();
assert!(
valid_membership_change(
&RoomVersion::V9,
target_user,
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender,
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester,
None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
Some(alice()),
&MembershipState::Join,
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
)
.unwrap()
);
assert!(
!valid_membership_change(
&RoomVersion::V9,
target_user,
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender,
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester,
None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
Some(ella()),
&MembershipState::Leave,
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
)
.unwrap()
);
}
#[test]
fn test_knock() {
let _ = tracing::subscriber::set_default(
tracing_subscriber::fmt().with_test_writer().finish(),
);
let mut events = INITIAL_EVENTS();
*events.get_mut(&event_id("IJR")).unwrap() = to_pdu_event(
"IJR",
alice(),
TimelineEventType::RoomJoinRules,
Some(""),
to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Knock)).unwrap(),
&["CREATE", "IMA", "IPOWER"],
&["IPOWER"],
);
let auth_events = events
.values()
.map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
.collect::<StateMap<_>>();
let requester = to_pdu_event(
"HELLO",
ella(),
TimelineEventType::RoomMember,
Some(ella().as_str()),
to_raw_json_value(&RoomMemberEventContent::new(MembershipState::Knock)).unwrap(),
&[],
&["IMC"],
);
let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
let target_user = ella();
let sender = ella();
assert!(
valid_membership_change(
&RoomVersion::V7,
target_user,
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender,
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester,
None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
None,
&MembershipState::Leave,
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
)
.unwrap()
);
}
} }

View file

@ -717,9 +717,6 @@ where
// The key for this is (eventType + a state_key of the signed token not sender) // The key for this is (eventType + a state_key of the signed token not sender)
// so search for it // so search for it
let current_third_party = auth_state.iter().find_map(|(_, pdu)| {
(*pdu.event_type() == TimelineEventType::RoomThirdPartyInvite).then_some(pdu)
});
let fetch_state = |ty: &StateEventType, key: &str| { let fetch_state = |ty: &StateEventType, key: &str| {
future::ready( future::ready(
@ -732,7 +729,6 @@ where
let auth_result = auth_check( let auth_result = auth_check(
room_version, room_version,
&event, &event,
current_third_party,
fetch_state, fetch_state,
&fetch_state(&StateEventType::RoomCreate, "") &fetch_state(&StateEventType::RoomCreate, "")
.await .await

View file

@ -184,7 +184,6 @@ where
let auth_check = state_res::event_auth::auth_check( let auth_check = state_res::event_auth::auth_check(
&to_room_version(&room_version_id), &to_room_version(&room_version_id),
&pdu_event, &pdu_event,
None, // TODO: third party invite
state_fetch, state_fetch,
create_event.as_pdu(), create_event.as_pdu(),
) )

View file

@ -100,7 +100,6 @@ where
let auth_check = state_res::event_auth::auth_check( let auth_check = state_res::event_auth::auth_check(
&room_version, &room_version,
&incoming_pdu, &incoming_pdu,
None, // TODO: third party invite
|ty, sk| state_fetch(ty.clone(), sk.into()), |ty, sk| state_fetch(ty.clone(), sk.into()),
create_event.as_pdu(), create_event.as_pdu(),
) )
@ -140,7 +139,6 @@ where
let auth_check = state_res::event_auth::auth_check( let auth_check = state_res::event_auth::auth_check(
&room_version, &room_version,
&incoming_pdu, &incoming_pdu,
None, // third-party invite
state_fetch, state_fetch,
create_event.as_pdu(), create_event.as_pdu(),
) )

View file

@ -236,15 +236,9 @@ pub async fn create_hash_and_sign_event(
| _ => create_pdu.as_ref().unwrap().as_pdu(), | _ => create_pdu.as_ref().unwrap().as_pdu(),
}; };
let auth_check = state_res::auth_check( let auth_check = state_res::auth_check(&room_version, &pdu, auth_fetch, create_event)
&room_version, .await
&pdu, .map_err(|e| err!(Request(Forbidden(warn!("Auth check failed: {e:?}")))))?;
None, // TODO: third_party_invite
auth_fetch,
create_event,
)
.await
.map_err(|e| err!(Request(Forbidden(warn!("Auth check failed: {e:?}")))))?;
if !auth_check { if !auth_check {
return Err!(Request(Forbidden("Event is not authorized."))); return Err!(Request(Forbidden("Event is not authorized.")));