From 8e33f9a7d08a863795190ce3088b87e2908c410f Mon Sep 17 00:00:00 2001 From: Jade Ellis Date: Thu, 18 Dec 2025 21:10:14 +0000 Subject: [PATCH] refactor: Improve code style for bundled aggregations --- .../pdu_metadata/bundled_aggregations.rs | 245 ++++++++---------- 1 file changed, 108 insertions(+), 137 deletions(-) diff --git a/src/service/rooms/pdu_metadata/bundled_aggregations.rs b/src/service/rooms/pdu_metadata/bundled_aggregations.rs index 5f3d1af6..570f54fd 100644 --- a/src/service/rooms/pdu_metadata/bundled_aggregations.rs +++ b/src/service/rooms/pdu_metadata/bundled_aggregations.rs @@ -33,6 +33,7 @@ impl super::Service { Direction::Backward, ) .await; + // The relations database code still handles the basic unsigned data // We don't want to recursively fetch relations @@ -43,65 +44,58 @@ impl super::Service { return Ok(None); } - // Get the original event for validation of replacement events - let original_event = self.services.timeline.get_pdu(event_id).await?; + // Partition relations by type + let (replace_events, reference_events): (Vec<_>, Vec<_>) = relations + .iter() + .filter_map(|relation| { + let pdu = &relation.1; + let content = pdu.get_content_as_value(); - let mut replace_events = Vec::with_capacity(relations.len()); - let mut reference_events = Vec::with_capacity(relations.len()); - - for relation in &relations { - let pdu = &relation.1; - - let content = pdu.get_content_as_value(); - if let Some(relates_to) = content.get("m.relates_to") { - // We don't check that the event relates back, because we assume the database is - // good. - if let Some(rel_type) = relates_to.get("rel_type") { - match rel_type.as_str() { - | Some("m.replace") => { - // Only consider valid replacements - if Self::is_valid_replacement_event(&original_event, pdu).await? { - replace_events.push(relation); - } - }, - | Some("m.reference") => { - reference_events.push(relation); - }, - | _ => { - // Ignore other relation types for now - // Threads are in the database but not handled here - // Other types are not specified AFAICT. - }, - } + content + .get("m.relates_to") + .and_then(|relates_to| relates_to.get("rel_type")) + .and_then(|rel_type| rel_type.as_str()) + .and_then(|rel_type_str| match rel_type_str { + | "m.replace" => Some(RelationType::Replace(relation)), + | "m.reference" => Some(RelationType::Reference(relation)), + | _ => None, /* Ignore other relation types (threads are in DB but not + * handled here) */ + }) + }) + .fold((Vec::new(), Vec::new()), |(mut replaces, mut references), rel_type| { + match rel_type { + | RelationType::Replace(r) => replaces.push(r), + | RelationType::Reference(r) => references.push(r), } - } - } + (replaces, references) + }); // If no relations to bundle, return None if replace_events.is_empty() && reference_events.is_empty() { return Ok(None); } - let mut bundled = BundledMessageLikeRelations::new(); + let mut bundled = BundledMessageLikeRelations::>::new(); - // Handle m.replace relations - find the most recent one + // Handle m.replace relations - find the most recent valid one (lazy load + // original event) if !replace_events.is_empty() { - let most_recent_replacement = Self::find_most_recent_replacement(&replace_events)?; + let original_event = self.services.timeline.get_pdu(event_id).await?; - // Convert the replacement event to the bundled format - if let Some(replacement_pdu) = most_recent_replacement { - // According to the Matrix spec, we should include the full event as raw JSON - let replacement_json = serde_json::to_string(replacement_pdu) - .map_err(|e| err!(Database("Failed to serialize replacement event: {e}")))?; - let raw_value = serde_json::value::RawValue::from_string(replacement_json) - .map_err(|e| err!(Database("Failed to create RawValue: {e}")))?; - bundled.replace = Some(Box::new(raw_value)); + if let Some(replacement) = + Self::find_most_recent_valid_replacement(&original_event, &replace_events).await? + { + bundled.replace = Some(Self::serialize_replacement(replacement)?); } } // Handle m.reference relations - collect event IDs if !reference_events.is_empty() { - let reference_chunk = Self::build_reference_chunk(&reference_events)?; + let reference_chunk: Vec<_> = reference_events + .into_iter() + .map(|relation| BundledReference::new(relation.1.event_id().to_owned())) + .collect(); + if !reference_chunk.is_empty() { bundled.reference = Some(Box::new(ReferenceChunk::new(reference_chunk))); } @@ -112,67 +106,48 @@ impl super::Service { Ok(Some(bundled)) } - /// Build reference chunk for m.reference bundled aggregations - fn build_reference_chunk( - reference_events: &[&PdusIterItem], - ) -> Result> { - let mut chunk = Vec::with_capacity(reference_events.len()); + /// Serialize a replacement event to the bundled format + fn serialize_replacement(pdu: &PduEvent) -> Result>> { + let replacement_json = serde_json::to_string(pdu) + .map_err(|e| err!(Database("Failed to serialize replacement event: {e}")))?; - for relation in reference_events { - let pdu = &relation.1; + let raw_value = serde_json::value::RawValue::from_string(replacement_json) + .map_err(|e| err!(Database("Failed to create RawValue: {e}")))?; - let reference_entry = BundledReference::new(pdu.event_id().to_owned()); - chunk.push(reference_entry); - } - - // Don't sort, order is unspecified - - Ok(chunk) + Ok(Box::new(raw_value)) } - /// Find the most recent replacement event based on origin_server_ts and - /// lexicographic event_id ordering - fn find_most_recent_replacement<'a>( - replacement_events: &'a [&'a PdusIterItem], + /// Find the most recent valid replacement event based on origin_server_ts + /// and lexicographic event_id ordering + async fn find_most_recent_valid_replacement<'a>( + original_event: &PduEvent, + replacement_events: &[&'a PdusIterItem], ) -> Result> { - if replacement_events.is_empty() { - return Ok(None); - } + // Filter valid replacements and find the maximum in a single pass + let mut result: Option<&PduEvent> = None; - let mut most_recent: Option<&PduEvent> = None; - - // Jank, is there a better way to do this? for relation in replacement_events { let pdu = &relation.1; - match most_recent { - | None => { - most_recent = Some(pdu); - }, - | Some(current_most_recent) => { - // Compare by origin_server_ts first - match pdu - .origin_server_ts() - .cmp(¤t_most_recent.origin_server_ts()) - { - | std::cmp::Ordering::Greater => { - most_recent = Some(pdu); - }, - | std::cmp::Ordering::Equal => { - // If timestamps are equal, use lexicographic ordering of event_id - if pdu.event_id() > current_most_recent.event_id() { - most_recent = Some(pdu); - } - }, - | std::cmp::Ordering::Less => { - // Keep current most recent - }, + // Validate replacement + if !Self::is_valid_replacement_event(original_event, pdu).await? { + continue; + } + + result = Some(match result { + | None => pdu, + | Some(current) => { + // Compare by origin_server_ts first, then event_id lexicographically + match pdu.origin_server_ts().cmp(¤t.origin_server_ts()) { + | std::cmp::Ordering::Greater => pdu, + | std::cmp::Ordering::Equal if pdu.event_id() > current.event_id() => pdu, + | _ => current, } }, - } + }); } - Ok(most_recent) + Ok(result) } /// Adds bundled aggregations to a PDU's unsigned field @@ -200,8 +175,7 @@ impl super::Service { Ok(()) } - /// Helper method to add bundled aggregations to a PDU's unsigned - /// field + /// Helper method to add bundled aggregations to a PDU's unsigned field fn add_bundled_aggregations_to_unsigned( pdu: &mut PduEvent, aggregations_json: serde_json::Value, @@ -225,9 +199,7 @@ impl super::Service { .ok_or_else(|| err!(Database("m.relations is not an object")))?; if let JsonValue::Object(aggregations_map) = aggregations_json { - for (rel_type, aggregation) in aggregations_map { - relations.insert(rel_type, aggregation); - } + relations.extend(aggregations_map); } pdu.unsigned = Some(to_raw_value(&unsigned)?); @@ -242,48 +214,47 @@ impl super::Service { original_event: &PduEvent, replacement_event: &PduEvent, ) -> Result { - // 1. Same room_id - if original_event.room_id() != replacement_event.room_id() { - return Ok(false); - } - - // 2. Same sender - if original_event.sender() != replacement_event.sender() { - return Ok(false); - } - - // 3. Same type - if original_event.event_type() != replacement_event.event_type() { - return Ok(false); - } - - // 4. Neither event should have a state_key property - if original_event.state_key().is_some() || replacement_event.state_key().is_some() { - return Ok(false); - } - - // 5. Original event must not have rel_type of m.replace - let original_content = original_event.get_content_as_value(); - if let Some(relates_to) = original_content.get("m.relates_to") { - if let Some(rel_type) = relates_to.get("rel_type") { - if rel_type.as_str() == Some("m.replace") { - return Ok(false); - } - } - } - - // 6. Replacement event must have m.new_content property - // Skip this check for encrypted events, as m.new_content would be inside the - // encrypted payload - if replacement_event.event_type() != &ruma::events::TimelineEventType::RoomEncrypted { - let replacement_content = replacement_event.get_content_as_value(); - if replacement_content.get("m.new_content").is_none() { - return Ok(false); - } - } - - Ok(true) + Ok( + // 1. Same room_id + original_event.room_id() == replacement_event.room_id() + // 2. Same sender + && original_event.sender() == replacement_event.sender() + // 3. Same type + && original_event.event_type() == replacement_event.event_type() + // 4. Neither event should have a state_key property + && original_event.state_key().is_none() + && replacement_event.state_key().is_none() + // 5. Original event must not have rel_type of m.replace + && !Self::is_replacement_event(original_event) + // 6. Replacement event must have m.new_content property (skip for encrypted) + && Self::has_new_content_or_encrypted(replacement_event), + ) } + + /// Check if an event is itself a replacement + #[inline] + fn is_replacement_event(event: &PduEvent) -> bool { + event + .get_content_as_value() + .get("m.relates_to") + .and_then(|relates_to| relates_to.get("rel_type")) + .and_then(|rel_type| rel_type.as_str()) + .is_some_and(|rel_type| rel_type == "m.replace") + } + + /// Check if event has m.new_content or is encrypted (where m.new_content + /// would be in the encrypted payload) + #[inline] + fn has_new_content_or_encrypted(event: &PduEvent) -> bool { + event.event_type() == &ruma::events::TimelineEventType::RoomEncrypted + || event.get_content_as_value().get("m.new_content").is_some() + } +} + +/// Helper enum for partitioning relations +enum RelationType<'a> { + Replace(&'a PdusIterItem), + Reference(&'a PdusIterItem), } #[cfg(test)]