diff --git a/src/core/matrix/state_res/mod.rs b/src/core/matrix/state_res/mod.rs index 11c754e9..0b2c9e20 100644 --- a/src/core/matrix/state_res/mod.rs +++ b/src/core/matrix/state_res/mod.rs @@ -101,40 +101,40 @@ where debug!(version = ?stateres_version, "State resolution starting"); // Split non-conflicting and conflicting state - let (clean, conflicting) = separate(state_sets.into_iter()); + let (unconflicted, conflicting) = separate(state_sets.into_iter()); - debug!(count = clean.len(), "non-conflicting events"); - trace!(map = ?clean, "non-conflicting events"); + debug!(count = unconflicted.len(), "non-conflicting events"); + trace!(map = ?unconflicted, "non-conflicting events"); if conflicting.is_empty() { debug!("no conflicting state found"); - return Ok(clean); + return Ok(unconflicted); } debug!(count = conflicting.len(), "conflicting events"); trace!(map = ?conflicting, "conflicting events"); - let conflicted_state_subgraph: HashSet<_> = match stateres_version { - | StateResolutionVersion::V2_1 => - calculate_conflicted_subgraph(&conflicting, event_fetch) + let (conflicted_state_subgraph, initial_state) = + if stateres_version == StateResolutionVersion::V2_1 { + let csg = calculate_conflicted_subgraph(&conflicting, event_fetch) .await .ok_or_else(|| { Error::InvalidPdu("Failed to calculate conflicted subgraph".to_owned()) - })?, - | _ => HashSet::new(), - }; - debug!(count = conflicted_state_subgraph.len(), "conflicted subgraph"); - trace!(set = ?conflicted_state_subgraph, "conflicted subgraph"); - - let conflicting_values = conflicting.into_values().flatten().stream(); + })?; + debug!(count = csg.len(), "conflicted subgraph"); + trace!(set = ?csg, "conflicted subgraph"); + (csg, HashMap::new()) + } else { + (HashSet::new(), unconflicted.clone()) + }; // `all_conflicted` contains unique items // synapse says `full_set = {eid for eid in full_conflicted_set if eid in // event_map}` // Hydra: Also consider the conflicted state subgraph let all_conflicted: HashSet<_> = get_auth_chain_diff(auth_chain_sets) - .chain(conflicting_values) - .chain(conflicted_state_subgraph.into_iter().stream()) + .chain(conflicting.into_values().flatten().stream()) .broad_filter_map(async |id| event_exists(id.clone()).await.then_some(id)) + .chain(conflicted_state_subgraph.into_iter().stream()) .collect() .await; @@ -171,7 +171,7 @@ where &room_version, &stateres_version, sorted_control_levels.iter().stream().map(AsRef::as_ref), - clean.clone(), + initial_state, &event_fetch, ) .await?; @@ -201,7 +201,7 @@ where let power_levels_ty_sk = (StateEventType::RoomPowerLevels, StateKey::new()); let power_event = resolved_control.get(&power_levels_ty_sk); - debug!(event_id = ?power_event, "power event"); + trace!(event_id = ?power_event, "power event"); let sorted_left_events = mainline_sort(&events_to_resolve, power_event.cloned(), &event_fetch).await?; @@ -212,13 +212,13 @@ where &room_version, &stateres_version, sorted_left_events.iter().stream().map(AsRef::as_ref), - resolved_control.clone(), // The control events are added to the final resolved state + resolved_control, // The control events are added to the final resolved state &event_fetch, ) .await?; // Ensure unconflicting state is in the final state - resolved_state.extend(clean); + resolved_state.extend(unconflicted); debug!("state resolution finished"); trace!( map = ?resolved_state, "final resolved state" );