refactor: Improve code style for bundled aggregations
This commit is contained in:
parent
8d3e4eba99
commit
8e33f9a7d0
1 changed files with 108 additions and 137 deletions
|
|
@ -33,6 +33,7 @@ impl super::Service {
|
|||
Direction::Backward,
|
||||
)
|
||||
.await;
|
||||
|
||||
// The relations database code still handles the basic unsigned data
|
||||
// We don't want to recursively fetch relations
|
||||
|
||||
|
|
@ -43,65 +44,58 @@ impl super::Service {
|
|||
return Ok(None);
|
||||
}
|
||||
|
||||
// Get the original event for validation of replacement events
|
||||
let original_event = self.services.timeline.get_pdu(event_id).await?;
|
||||
// Partition relations by type
|
||||
let (replace_events, reference_events): (Vec<_>, Vec<_>) = relations
|
||||
.iter()
|
||||
.filter_map(|relation| {
|
||||
let pdu = &relation.1;
|
||||
let content = pdu.get_content_as_value();
|
||||
|
||||
let mut replace_events = Vec::with_capacity(relations.len());
|
||||
let mut reference_events = Vec::with_capacity(relations.len());
|
||||
|
||||
for relation in &relations {
|
||||
let pdu = &relation.1;
|
||||
|
||||
let content = pdu.get_content_as_value();
|
||||
if let Some(relates_to) = content.get("m.relates_to") {
|
||||
// We don't check that the event relates back, because we assume the database is
|
||||
// good.
|
||||
if let Some(rel_type) = relates_to.get("rel_type") {
|
||||
match rel_type.as_str() {
|
||||
| Some("m.replace") => {
|
||||
// Only consider valid replacements
|
||||
if Self::is_valid_replacement_event(&original_event, pdu).await? {
|
||||
replace_events.push(relation);
|
||||
}
|
||||
},
|
||||
| Some("m.reference") => {
|
||||
reference_events.push(relation);
|
||||
},
|
||||
| _ => {
|
||||
// Ignore other relation types for now
|
||||
// Threads are in the database but not handled here
|
||||
// Other types are not specified AFAICT.
|
||||
},
|
||||
}
|
||||
content
|
||||
.get("m.relates_to")
|
||||
.and_then(|relates_to| relates_to.get("rel_type"))
|
||||
.and_then(|rel_type| rel_type.as_str())
|
||||
.and_then(|rel_type_str| match rel_type_str {
|
||||
| "m.replace" => Some(RelationType::Replace(relation)),
|
||||
| "m.reference" => Some(RelationType::Reference(relation)),
|
||||
| _ => None, /* Ignore other relation types (threads are in DB but not
|
||||
* handled here) */
|
||||
})
|
||||
})
|
||||
.fold((Vec::new(), Vec::new()), |(mut replaces, mut references), rel_type| {
|
||||
match rel_type {
|
||||
| RelationType::Replace(r) => replaces.push(r),
|
||||
| RelationType::Reference(r) => references.push(r),
|
||||
}
|
||||
}
|
||||
}
|
||||
(replaces, references)
|
||||
});
|
||||
|
||||
// If no relations to bundle, return None
|
||||
if replace_events.is_empty() && reference_events.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut bundled = BundledMessageLikeRelations::new();
|
||||
let mut bundled = BundledMessageLikeRelations::<Box<serde_json::value::RawValue>>::new();
|
||||
|
||||
// Handle m.replace relations - find the most recent one
|
||||
// Handle m.replace relations - find the most recent valid one (lazy load
|
||||
// original event)
|
||||
if !replace_events.is_empty() {
|
||||
let most_recent_replacement = Self::find_most_recent_replacement(&replace_events)?;
|
||||
let original_event = self.services.timeline.get_pdu(event_id).await?;
|
||||
|
||||
// Convert the replacement event to the bundled format
|
||||
if let Some(replacement_pdu) = most_recent_replacement {
|
||||
// According to the Matrix spec, we should include the full event as raw JSON
|
||||
let replacement_json = serde_json::to_string(replacement_pdu)
|
||||
.map_err(|e| err!(Database("Failed to serialize replacement event: {e}")))?;
|
||||
let raw_value = serde_json::value::RawValue::from_string(replacement_json)
|
||||
.map_err(|e| err!(Database("Failed to create RawValue: {e}")))?;
|
||||
bundled.replace = Some(Box::new(raw_value));
|
||||
if let Some(replacement) =
|
||||
Self::find_most_recent_valid_replacement(&original_event, &replace_events).await?
|
||||
{
|
||||
bundled.replace = Some(Self::serialize_replacement(replacement)?);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle m.reference relations - collect event IDs
|
||||
if !reference_events.is_empty() {
|
||||
let reference_chunk = Self::build_reference_chunk(&reference_events)?;
|
||||
let reference_chunk: Vec<_> = reference_events
|
||||
.into_iter()
|
||||
.map(|relation| BundledReference::new(relation.1.event_id().to_owned()))
|
||||
.collect();
|
||||
|
||||
if !reference_chunk.is_empty() {
|
||||
bundled.reference = Some(Box::new(ReferenceChunk::new(reference_chunk)));
|
||||
}
|
||||
|
|
@ -112,67 +106,48 @@ impl super::Service {
|
|||
Ok(Some(bundled))
|
||||
}
|
||||
|
||||
/// Build reference chunk for m.reference bundled aggregations
|
||||
fn build_reference_chunk(
|
||||
reference_events: &[&PdusIterItem],
|
||||
) -> Result<Vec<BundledReference>> {
|
||||
let mut chunk = Vec::with_capacity(reference_events.len());
|
||||
/// Serialize a replacement event to the bundled format
|
||||
fn serialize_replacement(pdu: &PduEvent) -> Result<Box<Box<serde_json::value::RawValue>>> {
|
||||
let replacement_json = serde_json::to_string(pdu)
|
||||
.map_err(|e| err!(Database("Failed to serialize replacement event: {e}")))?;
|
||||
|
||||
for relation in reference_events {
|
||||
let pdu = &relation.1;
|
||||
let raw_value = serde_json::value::RawValue::from_string(replacement_json)
|
||||
.map_err(|e| err!(Database("Failed to create RawValue: {e}")))?;
|
||||
|
||||
let reference_entry = BundledReference::new(pdu.event_id().to_owned());
|
||||
chunk.push(reference_entry);
|
||||
}
|
||||
|
||||
// Don't sort, order is unspecified
|
||||
|
||||
Ok(chunk)
|
||||
Ok(Box::new(raw_value))
|
||||
}
|
||||
|
||||
/// Find the most recent replacement event based on origin_server_ts and
|
||||
/// lexicographic event_id ordering
|
||||
fn find_most_recent_replacement<'a>(
|
||||
replacement_events: &'a [&'a PdusIterItem],
|
||||
/// Find the most recent valid replacement event based on origin_server_ts
|
||||
/// and lexicographic event_id ordering
|
||||
async fn find_most_recent_valid_replacement<'a>(
|
||||
original_event: &PduEvent,
|
||||
replacement_events: &[&'a PdusIterItem],
|
||||
) -> Result<Option<&'a PduEvent>> {
|
||||
if replacement_events.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
// Filter valid replacements and find the maximum in a single pass
|
||||
let mut result: Option<&PduEvent> = None;
|
||||
|
||||
let mut most_recent: Option<&PduEvent> = None;
|
||||
|
||||
// Jank, is there a better way to do this?
|
||||
for relation in replacement_events {
|
||||
let pdu = &relation.1;
|
||||
|
||||
match most_recent {
|
||||
| None => {
|
||||
most_recent = Some(pdu);
|
||||
},
|
||||
| Some(current_most_recent) => {
|
||||
// Compare by origin_server_ts first
|
||||
match pdu
|
||||
.origin_server_ts()
|
||||
.cmp(¤t_most_recent.origin_server_ts())
|
||||
{
|
||||
| std::cmp::Ordering::Greater => {
|
||||
most_recent = Some(pdu);
|
||||
},
|
||||
| std::cmp::Ordering::Equal => {
|
||||
// If timestamps are equal, use lexicographic ordering of event_id
|
||||
if pdu.event_id() > current_most_recent.event_id() {
|
||||
most_recent = Some(pdu);
|
||||
}
|
||||
},
|
||||
| std::cmp::Ordering::Less => {
|
||||
// Keep current most recent
|
||||
},
|
||||
// Validate replacement
|
||||
if !Self::is_valid_replacement_event(original_event, pdu).await? {
|
||||
continue;
|
||||
}
|
||||
|
||||
result = Some(match result {
|
||||
| None => pdu,
|
||||
| Some(current) => {
|
||||
// Compare by origin_server_ts first, then event_id lexicographically
|
||||
match pdu.origin_server_ts().cmp(¤t.origin_server_ts()) {
|
||||
| std::cmp::Ordering::Greater => pdu,
|
||||
| std::cmp::Ordering::Equal if pdu.event_id() > current.event_id() => pdu,
|
||||
| _ => current,
|
||||
}
|
||||
},
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Ok(most_recent)
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Adds bundled aggregations to a PDU's unsigned field
|
||||
|
|
@ -200,8 +175,7 @@ impl super::Service {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Helper method to add bundled aggregations to a PDU's unsigned
|
||||
/// field
|
||||
/// Helper method to add bundled aggregations to a PDU's unsigned field
|
||||
fn add_bundled_aggregations_to_unsigned(
|
||||
pdu: &mut PduEvent,
|
||||
aggregations_json: serde_json::Value,
|
||||
|
|
@ -225,9 +199,7 @@ impl super::Service {
|
|||
.ok_or_else(|| err!(Database("m.relations is not an object")))?;
|
||||
|
||||
if let JsonValue::Object(aggregations_map) = aggregations_json {
|
||||
for (rel_type, aggregation) in aggregations_map {
|
||||
relations.insert(rel_type, aggregation);
|
||||
}
|
||||
relations.extend(aggregations_map);
|
||||
}
|
||||
|
||||
pdu.unsigned = Some(to_raw_value(&unsigned)?);
|
||||
|
|
@ -242,48 +214,47 @@ impl super::Service {
|
|||
original_event: &PduEvent,
|
||||
replacement_event: &PduEvent,
|
||||
) -> Result<bool> {
|
||||
// 1. Same room_id
|
||||
if original_event.room_id() != replacement_event.room_id() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
// 2. Same sender
|
||||
if original_event.sender() != replacement_event.sender() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
// 3. Same type
|
||||
if original_event.event_type() != replacement_event.event_type() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
// 4. Neither event should have a state_key property
|
||||
if original_event.state_key().is_some() || replacement_event.state_key().is_some() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
// 5. Original event must not have rel_type of m.replace
|
||||
let original_content = original_event.get_content_as_value();
|
||||
if let Some(relates_to) = original_content.get("m.relates_to") {
|
||||
if let Some(rel_type) = relates_to.get("rel_type") {
|
||||
if rel_type.as_str() == Some("m.replace") {
|
||||
return Ok(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 6. Replacement event must have m.new_content property
|
||||
// Skip this check for encrypted events, as m.new_content would be inside the
|
||||
// encrypted payload
|
||||
if replacement_event.event_type() != &ruma::events::TimelineEventType::RoomEncrypted {
|
||||
let replacement_content = replacement_event.get_content_as_value();
|
||||
if replacement_content.get("m.new_content").is_none() {
|
||||
return Ok(false);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(true)
|
||||
Ok(
|
||||
// 1. Same room_id
|
||||
original_event.room_id() == replacement_event.room_id()
|
||||
// 2. Same sender
|
||||
&& original_event.sender() == replacement_event.sender()
|
||||
// 3. Same type
|
||||
&& original_event.event_type() == replacement_event.event_type()
|
||||
// 4. Neither event should have a state_key property
|
||||
&& original_event.state_key().is_none()
|
||||
&& replacement_event.state_key().is_none()
|
||||
// 5. Original event must not have rel_type of m.replace
|
||||
&& !Self::is_replacement_event(original_event)
|
||||
// 6. Replacement event must have m.new_content property (skip for encrypted)
|
||||
&& Self::has_new_content_or_encrypted(replacement_event),
|
||||
)
|
||||
}
|
||||
|
||||
/// Check if an event is itself a replacement
|
||||
#[inline]
|
||||
fn is_replacement_event(event: &PduEvent) -> bool {
|
||||
event
|
||||
.get_content_as_value()
|
||||
.get("m.relates_to")
|
||||
.and_then(|relates_to| relates_to.get("rel_type"))
|
||||
.and_then(|rel_type| rel_type.as_str())
|
||||
.is_some_and(|rel_type| rel_type == "m.replace")
|
||||
}
|
||||
|
||||
/// Check if event has m.new_content or is encrypted (where m.new_content
|
||||
/// would be in the encrypted payload)
|
||||
#[inline]
|
||||
fn has_new_content_or_encrypted(event: &PduEvent) -> bool {
|
||||
event.event_type() == &ruma::events::TimelineEventType::RoomEncrypted
|
||||
|| event.get_content_as_value().get("m.new_content").is_some()
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper enum for partitioning relations
|
||||
enum RelationType<'a> {
|
||||
Replace(&'a PdusIterItem),
|
||||
Reference(&'a PdusIterItem),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue