diff --git a/src/api/client/context.rs b/src/api/client/context.rs index 38fe28b2..5775b500 100644 --- a/src/api/client/context.rs +++ b/src/api/client/context.rs @@ -91,7 +91,14 @@ pub(crate) async fn get_context_route( .ignore_err() .then(async |mut pdu| { pdu.1.set_unsigned(Some(sender_user)); - // TODO: bundled aggregations + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1) + .await + { + debug_warn!("Failed to add bundled aggregations: {e}"); + } pdu }) .ready_filter_map(|item| event_filter(item, filter)) @@ -107,7 +114,14 @@ pub(crate) async fn get_context_route( .ignore_err() .then(async |mut pdu| { pdu.1.set_unsigned(Some(sender_user)); - // TODO: bundled aggregations + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1) + .await + { + debug_warn!("Failed to add bundled aggregations: {e}"); + } pdu }) .ready_filter_map(|item| event_filter(item, filter)) diff --git a/src/api/client/message.rs b/src/api/client/message.rs index 5a0b65b9..89db84d7 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -1,7 +1,7 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; use conduwuit::{ - Err, Result, at, + Err, Result, at, debug_warn, matrix::{ event::{Event, Matches}, pdu::PduCount, @@ -142,7 +142,14 @@ pub(crate) async fn get_message_events_route( .take(limit) .then(async |mut pdu| { pdu.1.set_unsigned(Some(sender_user)); - // TODO: bundled aggregations + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1) + .await + { + debug_warn!("Failed to add bundled aggregations: {e}"); + } pdu }) .collect() diff --git a/src/api/client/relations.rs b/src/api/client/relations.rs index 3586631c..eace0051 100644 --- a/src/api/client/relations.rs +++ b/src/api/client/relations.rs @@ -1,6 +1,6 @@ use axum::extract::State; use conduwuit::{ - Result, at, + Result, at, 
debug_warn, matrix::{Event, event::RelationTypeEqual, pdu::PduCount}, utils::{IterStream, ReadyExt, result::FlatOk, stream::WidebandExt}, }; @@ -154,6 +154,17 @@ async fn paginate_relations_with_filter( .ready_take_while(|(count, _)| Some(*count) != to) .wide_filter_map(|item| visibility_filter(services, sender_user, item)) .take(limit) + .then(async |mut pdu| { + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1) + .await + { + debug_warn!("Failed to add bundled aggregations to relation: {e}"); + } + pdu + }) .collect() .await; diff --git a/src/api/client/room/event.rs b/src/api/client/room/event.rs index baafafd6..6aca48a5 100644 --- a/src/api/client/room/event.rs +++ b/src/api/client/room/event.rs @@ -1,5 +1,5 @@ use axum::extract::State; -use conduwuit::{Err, Event, Result, err}; +use conduwuit::{Err, Event, Result, debug_warn, err}; use futures::{FutureExt, TryFutureExt, future::try_join}; use ruma::api::client::room::get_room_event; @@ -33,6 +33,15 @@ pub(crate) async fn get_room_event_route( return Err!(Request(Forbidden("You don't have permission to view this event."))); } + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(body.sender_user(), &mut event) + .await + { + debug_warn!("Failed to add bundled aggregations to event: {e}"); + } + event.set_unsigned(body.sender_user.as_deref()); Ok(get_room_event::v3::Response { event: event.into_format() }) diff --git a/src/api/client/room/initial_sync.rs b/src/api/client/room/initial_sync.rs index e5d8a8e8..7d919d60 100644 --- a/src/api/client/room/initial_sync.rs +++ b/src/api/client/room/initial_sync.rs @@ -1,6 +1,6 @@ use axum::extract::State; use conduwuit::{ - Err, Event, Result, at, + Err, Event, Result, at, debug_warn, utils::{BoolExt, stream::TryTools}, }; use futures::{FutureExt, TryStreamExt, future::try_join4}; @@ -50,7 +50,16 @@ pub(crate) async fn room_initial_sync_route( .try_take(limit) .and_then(async 
|mut pdu| { pdu.1.set_unsigned(body.sender_user.as_deref()); - // TODO: bundled aggregations + if let Some(sender_user) = body.sender_user.as_deref() { + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1) + .await + { + debug_warn!("Failed to add bundled aggregations: {e}"); + } + } Ok(pdu) }) .try_collect::>(); diff --git a/src/api/client/search.rs b/src/api/client/search.rs index cc745694..794a6731 100644 --- a/src/api/client/search.rs +++ b/src/api/client/search.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use axum::extract::State; use conduwuit::{ - Err, Result, at, is_true, + Err, Result, at, debug_warn, is_true, matrix::Event, result::FlatOk, utils::{IterStream, stream::ReadyExt}, @@ -144,6 +144,17 @@ async fn category_room_events( .map(at!(2)) .flatten() .stream() + .then(|mut pdu| async { + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(sender_user, &mut pdu) + .await + { + debug_warn!("Failed to add bundled aggregations to search result: {e}"); + } + pdu + }) .map(Event::into_format) .map(|result| SearchResult { rank: None, diff --git a/src/api/client/sync/mod.rs b/src/api/client/sync/mod.rs index 83ff34fd..218d5ae0 100644 --- a/src/api/client/sync/mod.rs +++ b/src/api/client/sync/mod.rs @@ -4,7 +4,7 @@ mod v5; use std::collections::VecDeque; use conduwuit::{ - Event, PduCount, Result, err, + Event, PduCount, Result, debug_warn, err, matrix::pdu::PduEvent, ref_at, trace, utils::stream::{BroadbandExt, ReadyExt, TryIgnore}, @@ -73,12 +73,22 @@ async fn load_timeline( .timeline .pdus_rev(room_id, ending_count.map(|count| count.saturating_add(1))) .ignore_err() + .ready_take_while(move |&(pducount, _)| pducount > starting_count) .map(move |mut pdu| { pdu.1.set_unsigned(Some(sender_user)); - // TODO: bundled aggregations pdu }) - .ready_take_while(move |&(pducount, _)| pducount > starting_count) + .then(async move |mut pdu| { + if let Err(e) = services + 
.rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1) + .await + { + debug_warn!("Failed to add bundled aggregations: {e}"); + } + pdu + }) .boxed() }, | None => { @@ -91,7 +101,17 @@ async fn load_timeline( .ignore_err() .map(move |mut pdu| { pdu.1.set_unsigned(Some(sender_user)); - // TODO: bundled aggregations + pdu + }) + .then(async move |mut pdu| { + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1) + .await + { + debug_warn!("Failed to add bundled aggregations: {e}"); + } pdu }) .boxed() diff --git a/src/api/client/threads.rs b/src/api/client/threads.rs index 19cca5f7..32da2adb 100644 --- a/src/api/client/threads.rs +++ b/src/api/client/threads.rs @@ -1,6 +1,6 @@ use axum::extract::State; use conduwuit::{ - Result, at, + Result, at, debug_warn, matrix::{ Event, pdu::{PduCount, PduEvent}, @@ -31,7 +31,6 @@ pub(crate) async fn get_threads_route( .transpose()? .unwrap_or_else(PduCount::max); - // TODO: bundled aggregation // TODO: user_can_see_event and set_unsigned should be at the same level / // function, so unsigned is only set for seen events. 
let threads: Vec<(PduCount, PduEvent)> = services @@ -48,6 +47,17 @@ pub(crate) async fn get_threads_route( .await .then_some((count, pdu)) }) + .then(|(count, mut pdu)| async move { + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(body.sender_user(), &mut pdu) + .await + { + debug_warn!("Failed to add bundled aggregations to thread: {e}"); + } + (count, pdu) + }) .collect() .await; diff --git a/src/service/rooms/pdu_metadata/bundled_aggregations.rs b/src/service/rooms/pdu_metadata/bundled_aggregations.rs new file mode 100644 index 00000000..10309858 --- /dev/null +++ b/src/service/rooms/pdu_metadata/bundled_aggregations.rs @@ -0,0 +1,394 @@ +use conduwuit::{Event, PduEvent, Result, err}; +use ruma::{ + EventId, RoomId, UserId, + api::Direction, + events::relation::{BundledMessageLikeRelations, BundledReference, ReferenceChunk}, +}; + +use crate::rooms::timeline::PdusIterItem; + +const MAX_BUNDLED_RELATIONS: usize = 50; + +impl super::Service { + /// Gets bundled aggregations for an event according to the Matrix + /// specification. + /// - m.replace relations are bundled to include the most recent replacement + /// event. + /// - m.reference relations are bundled to include a chunk of event IDs. + #[tracing::instrument(skip(self), level = "debug")] + pub async fn get_bundled_aggregations( + &self, + user_id: &UserId, + room_id: &RoomId, + event_id: &EventId, + ) -> Result<Option<BundledMessageLikeRelations<Box<serde_json::value::RawValue>>>> { + let relations = self + .get_relations( + user_id, + room_id, + event_id, + conduwuit::PduCount::max(), + MAX_BUNDLED_RELATIONS, + 0, + Direction::Backward, + ) + .await; + // The relations database code still handles the basic unsigned data + // We don't want to recursively fetch relations + + // TODO: Event visibility check + // TODO: ignored users? 
+ + if relations.is_empty() { + return Ok(None); + } + + let mut replace_events = Vec::with_capacity(relations.len().min(10)); // Most events have few replacements + let mut reference_events = Vec::with_capacity(relations.len()); + + for relation in &relations { + let pdu = &relation.1; + + let content = pdu.get_content_as_value(); + if let Some(relates_to) = content.get("m.relates_to") { + // We don't check that the event relates back, because we assume the database is + // good. + if let Some(rel_type) = relates_to.get("rel_type") { + match rel_type.as_str() { + | Some("m.replace") => { + replace_events.push(relation); + }, + | Some("m.reference") => { + reference_events.push(relation); + }, + | _ => { + // Ignore other relation types for now + // Threads are in the database but not handled here + // Other types are not specified AFAICT. + }, + } + } + } + } + + // If no relations to bundle, return None + if replace_events.is_empty() && reference_events.is_empty() { + return Ok(None); + } + + let mut bundled = BundledMessageLikeRelations::new(); + + // Handle m.replace relations - find the most recent one + if !replace_events.is_empty() { + let most_recent_replacement = Self::find_most_recent_replacement(&replace_events)?; + + // Convert the replacement event to the bundled format + if let Some(replacement_pdu) = most_recent_replacement { + // According to the Matrix spec, we should include the full event as raw JSON + let replacement_json = serde_json::to_string(replacement_pdu) + .map_err(|e| err!(Database("Failed to serialize replacement event: {e}")))?; + let raw_value = serde_json::value::RawValue::from_string(replacement_json) + .map_err(|e| err!(Database("Failed to create RawValue: {e}")))?; + bundled.replace = Some(Box::new(raw_value)); + } + } + + // Handle m.reference relations - collect event IDs + if !reference_events.is_empty() { + let reference_chunk = Self::build_reference_chunk(&reference_events)?; + if !reference_chunk.is_empty() { + 
bundled.reference = Some(Box::new(ReferenceChunk::new(reference_chunk))); + } + } + + // TODO: Handle other relation types (m.annotation, etc.) when specified + + Ok(Some(bundled)) + } + + /// Build reference chunk for m.reference bundled aggregations + fn build_reference_chunk( + reference_events: &[&PdusIterItem], + ) -> Result<Vec<BundledReference>> { + let mut chunk = Vec::with_capacity(reference_events.len()); + + for relation in reference_events { + let pdu = &relation.1; + + let reference_entry = BundledReference::new(pdu.event_id().to_owned()); + chunk.push(reference_entry); + } + + // Don't sort, order is unspecified + + Ok(chunk) + } + + /// Find the most recent replacement event based on origin_server_ts and + /// lexicographic event_id ordering + fn find_most_recent_replacement<'a>( + replacement_events: &'a [&'a PdusIterItem], + ) -> Result<Option<&'a PduEvent>> { + if replacement_events.is_empty() { + return Ok(None); + } + + let mut most_recent: Option<&PduEvent> = None; + + // Jank, is there a better way to do this? 
+ for relation in replacement_events { + let pdu = &relation.1; + + match most_recent { + | None => { + most_recent = Some(pdu); + }, + | Some(current_most_recent) => { + // Compare by origin_server_ts first + match pdu + .origin_server_ts() + .cmp(&current_most_recent.origin_server_ts()) + { + | std::cmp::Ordering::Greater => { + most_recent = Some(pdu); + }, + | std::cmp::Ordering::Equal => { + // If timestamps are equal, use lexicographic ordering of event_id + if pdu.event_id() > current_most_recent.event_id() { + most_recent = Some(pdu); + } + }, + | std::cmp::Ordering::Less => { + // Keep current most recent + }, + } + }, + } + } + + Ok(most_recent) + } + + /// Adds bundled aggregations to a PDU's unsigned field + #[tracing::instrument(skip(self, pdu), level = "debug")] + pub async fn add_bundled_aggregations_to_pdu( + &self, + user_id: &UserId, + pdu: &mut PduEvent, + ) -> Result<()> { + if pdu.is_redacted() { + return Ok(()); + } + + let bundled_aggregations = self + .get_bundled_aggregations(user_id, &pdu.room_id_or_hash(), pdu.event_id()) + .await?; + + if let Some(aggregations) = bundled_aggregations { + let aggregations_json = serde_json::to_value(aggregations) + .map_err(|e| err!(Database("Failed to serialize bundled aggregations: {e}")))?; + + Self::add_bundled_aggregations_to_unsigned(pdu, aggregations_json)?; + } + + Ok(()) + } + + /// Helper method to add bundled aggregations to a PDU's unsigned + /// field + fn add_bundled_aggregations_to_unsigned( + pdu: &mut PduEvent, + aggregations_json: serde_json::Value, + ) -> Result<()> { + use serde_json::{ + Map, Value as JsonValue, + value::{RawValue as RawJsonValue, to_raw_value}, + }; + + let mut unsigned: Map<String, JsonValue> = pdu + .unsigned + .as_deref() + .map(RawJsonValue::get) + .map_or_else(|| Ok(Map::new()), serde_json::from_str) + .map_err(|e| err!(Database("Invalid unsigned in pdu event: {e}")))?; + + let relations = unsigned + .entry("m.relations") + .or_insert_with(|| JsonValue::Object(Map::new())) + 
.as_object_mut() + .ok_or_else(|| err!(Database("m.relations is not an object")))?; + + if let JsonValue::Object(aggregations_map) = aggregations_json { + for (rel_type, aggregation) in aggregations_map { + relations.insert(rel_type, aggregation); + } + } + + pdu.unsigned = Some(to_raw_value(&unsigned)?); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use conduwuit_core::pdu::{EventHash, PduEvent}; + use ruma::{UInt, events::TimelineEventType, owned_event_id, owned_room_id, owned_user_id}; + use serde_json::{Value as JsonValue, json, value::to_raw_value}; + + fn create_test_pdu(unsigned_content: Option<JsonValue>) -> PduEvent { + PduEvent { + event_id: owned_event_id!("$test:example.com"), + room_id: Some(owned_room_id!("!test:example.com")), + sender: owned_user_id!("@test:example.com"), + origin_server_ts: UInt::try_from(1_234_567_890_u64).unwrap(), + kind: TimelineEventType::RoomMessage, + content: to_raw_value(&json!({"msgtype": "m.text", "body": "test"})).unwrap(), + state_key: None, + prev_events: vec![], + depth: UInt::from(1_u32), + auth_events: vec![], + redacts: None, + unsigned: unsigned_content.map(|content| to_raw_value(&content).unwrap()), + hashes: EventHash { sha256: "test_hash".to_owned() }, + signatures: None, + origin: None, + } + } + + fn create_bundled_aggregations() -> JsonValue { + json!({ + "m.replace": { + "event_id": "$replace:example.com", + "origin_server_ts": 1_234_567_890, + "sender": "@replacer:example.com" + }, + "m.reference": { + "count": 5, + "chunk": [ + "$ref1:example.com", + "$ref2:example.com" + ] + } + }) + } + + #[test] + fn test_add_bundled_aggregations_to_unsigned_no_existing_unsigned() { + let mut pdu = create_test_pdu(None); + let aggregations = create_bundled_aggregations(); + + let result = super::super::Service::add_bundled_aggregations_to_unsigned( + &mut pdu, + aggregations.clone(), + ); + assert!(result.is_ok(), "Should succeed when no unsigned field exists"); + + assert!(pdu.unsigned.is_some(), "Unsigned field should be 
created"); + + let unsigned_str = pdu.unsigned.as_ref().unwrap().get(); + let unsigned: JsonValue = serde_json::from_str(unsigned_str).unwrap(); + + assert!(unsigned.get("m.relations").is_some(), "m.relations should exist"); + assert_eq!( + unsigned["m.relations"], aggregations, + "Relations should match the aggregations" + ); + } + + #[test] + fn test_add_bundled_aggregations_to_unsigned_overwrite_same_relation_type() { + let existing_unsigned = json!({ + "m.relations": { + "m.replace": { + "event_id": "$old_replace:example.com", + "origin_server_ts": 1_111_111_111, + "sender": "@old_replacer:example.com" + } + } + }); + + let mut pdu = create_test_pdu(Some(existing_unsigned)); + let new_aggregations = create_bundled_aggregations(); + + let result = super::super::Service::add_bundled_aggregations_to_unsigned( + &mut pdu, + new_aggregations.clone(), + ); + assert!(result.is_ok(), "Should succeed when overwriting same relation type"); + + let unsigned_str = pdu.unsigned.as_ref().unwrap().get(); + let unsigned: JsonValue = serde_json::from_str(unsigned_str).unwrap(); + + let relations = &unsigned["m.relations"]; + + assert_eq!( + relations["m.replace"], new_aggregations["m.replace"], + "m.replace should be updated" + ); + assert_eq!( + relations["m.replace"]["event_id"], "$replace:example.com", + "Should have new event_id" + ); + + assert!(relations.get("m.reference").is_some(), "New m.reference should be added"); + } + + #[test] + fn test_add_bundled_aggregations_to_unsigned_preserve_other_unsigned_fields() { + // Test case: Other unsigned fields should be preserved + let existing_unsigned = json!({ + "age": 98765, + "prev_content": {"msgtype": "m.text", "body": "old message"}, + "redacted_because": {"event_id": "$redaction:example.com"}, + "m.relations": { + "m.annotation": {"count": 1} + } + }); + + let mut pdu = create_test_pdu(Some(existing_unsigned)); + let new_aggregations = json!({ + "m.replace": {"event_id": "$new:example.com"} + }); + + let result = 
super::super::Service::add_bundled_aggregations_to_unsigned( + &mut pdu, + new_aggregations, + ); + assert!(result.is_ok(), "Should succeed while preserving other fields"); + + let unsigned_str = pdu.unsigned.as_ref().unwrap().get(); + let unsigned: JsonValue = serde_json::from_str(unsigned_str).unwrap(); + + // Verify all existing fields are preserved + assert_eq!(unsigned["age"], 98765, "age should be preserved"); + assert!(unsigned.get("prev_content").is_some(), "prev_content should be preserved"); + assert!( + unsigned.get("redacted_because").is_some(), + "redacted_because should be preserved" + ); + + // Verify relations were merged correctly + let relations = &unsigned["m.relations"]; + assert!( + relations.get("m.annotation").is_some(), + "Existing m.annotation should be preserved" + ); + assert!(relations.get("m.replace").is_some(), "New m.replace should be added"); + } + + #[test] + fn test_add_bundled_aggregations_to_unsigned_invalid_existing_unsigned() { + // Test case: Invalid JSON in existing unsigned should result in error + let mut pdu = create_test_pdu(None); + // Manually set invalid unsigned data + pdu.unsigned = Some(to_raw_value(&"invalid json").unwrap()); + + let aggregations = create_bundled_aggregations(); + let result = + super::super::Service::add_bundled_aggregations_to_unsigned(&mut pdu, aggregations); + + assert!(result.is_err(), "fails when existing unsigned is invalid"); + // Should we ignore the error and overwrite anyway? 
+ } +} diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index 6a765f13..11a69e05 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -14,10 +14,11 @@ use futures::{Stream, StreamExt}; use ruma::{EventId, RoomId, UserId, api::Direction}; use crate::{ - Dep, rooms, + Dep, rooms::{ + self, short::{ShortEventId, ShortRoomId}, - timeline::{PduId, RawPduId}, + timeline::{PduId, PdusIterItem, RawPduId}, }, }; @@ -59,7 +60,7 @@ impl Data { target: ShortEventId, from: PduCount, dir: Direction, - ) -> impl Stream + Send + 'a { + ) -> impl Stream + Send + 'a { // Query from exact position then filter excludes it (saturating_inc could skip // events at min/max boundaries) let from_unsigned = from.into_unsigned(); diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index c8e863fa..e1f351a3 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -1,15 +1,16 @@ +mod bundled_aggregations; mod data; use std::sync::Arc; -use conduwuit::{ - Result, - matrix::{Event, PduCount}, -}; +use conduwuit::{Result, matrix::PduCount}; use futures::{StreamExt, future::try_join}; use ruma::{EventId, RoomId, UserId, api::Direction}; use self::data::Data; -use crate::{Dep, rooms}; +use crate::{ + Dep, + rooms::{self, timeline::PdusIterItem}, +}; pub struct Service { services: Services, @@ -56,7 +57,7 @@ impl Service { limit: usize, max_depth: u8, dir: Direction, - ) -> Vec<(PduCount, impl Event)> { + ) -> Vec { let room_id = self.services.short.get_shortroomid(room_id); let target = self.services.timeline.get_pdu_count(target); diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index ac59cd81..259c8889 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use conduwuit::{ - PduCount, Result, + PduCount, PduEvent, Result, 
arrayvec::ArrayVec, implement, matrix::event::{Event, Matches}, @@ -104,7 +104,7 @@ pub fn deindex_pdu(&self, shortroomid: ShortRoomId, pdu_id: &RawPduId, message_b pub async fn search_pdus<'a>( &'a self, query: &'a RoomQuery<'a>, -) -> Result<(usize, impl Stream> + Send + 'a)> { +) -> Result<(usize, impl Stream + Send + 'a)> { let pdu_ids: Vec<_> = self.search_pdu_ids(query).await?.collect().await; let filter = &query.criteria.filter;