From 37888fb67098c7dfba680cea122bc72c3a970406 Mon Sep 17 00:00:00 2001 From: timedout Date: Tue, 3 Mar 2026 19:54:34 +0000 Subject: [PATCH] fix: Limit body read size of remote requests (CWE-409) Reviewed-By: Jade Ellis --- Cargo.lock | 15 +++++++++ Cargo.toml | 1 + clippy.toml | 18 +++++++++-- src/admin/federation/commands.rs | 12 +++++-- src/core/utils/mod.rs | 1 + src/core/utils/response.rs | 51 +++++++++++++++++++++++++++++ src/service/announcements/mod.rs | 4 +-- src/service/federation/execute.rs | 36 +++++++++++++++------ src/service/media/preview.rs | 52 ++++++++++++++++++------------ src/service/media/remote.rs | 13 +++++--- src/service/pusher/mod.rs | 15 +++++++-- src/service/resolver/well_known.rs | 10 +++--- src/service/sending/antispam.rs | 4 +-- src/service/sending/appservice.rs | 14 ++++++-- 14 files changed, 192 insertions(+), 54 deletions(-) create mode 100644 src/core/utils/response.rs diff --git a/Cargo.lock b/Cargo.lock index 027fc558..a3be409b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4055,12 +4055,14 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-rustls", + "tokio-util", "tower", "tower-http", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "webpki-roots", ] @@ -5868,6 +5870,19 @@ dependencies = [ "wasmparser", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "wasmparser" version = "0.244.0" diff --git a/Cargo.toml b/Cargo.toml index 2a9ab25e..1236953e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -144,6 +144,7 @@ features = [ "socks", "hickory-dns", "http2", + "stream", ] [workspace.dependencies.serde] diff --git a/clippy.toml b/clippy.toml index 863759aa..7ee0fd45 100644 --- a/clippy.toml +++ b/clippy.toml @@ -15,6 +15,18 @@ disallowed-macros = [ { path = "log::trace", reason = "use conduwuit_core::trace" }, ] -disallowed-methods = [ - { path = "tokio::spawn", reason = "use and pass conduuwit_core::server::Server::runtime() to spawn from" }, -] +[[disallowed-methods]] +path = "tokio::spawn" +reason = "use and pass conduwuit_core::server::Server::runtime() to spawn from" + +[[disallowed-methods]] +path = "reqwest::Response::bytes" +reason = "bytes is unsafe, use limit_read via the conduwuit_core::utils::LimitReadExt trait instead" + +[[disallowed-methods]] +path = "reqwest::Response::text" +reason = "text is unsafe, use limit_read_text via the conduwuit_core::utils::LimitReadExt trait instead" + +[[disallowed-methods]] +path = "reqwest::Response::json" +reason = "json is unsafe, use limit_read_text via the conduwuit_core::utils::LimitReadExt trait instead" diff --git a/src/admin/federation/commands.rs b/src/admin/federation/commands.rs index 47b29ffb..6668e917 100644 --- a/src/admin/federation/commands.rs +++ b/src/admin/federation/commands.rs @@ -1,6 +1,6 @@ use std::fmt::Write; -use conduwuit::{Err, Result}; +use conduwuit::{Err, Result, utils::response::LimitReadExt}; use futures::StreamExt; use ruma::{OwnedRoomId, OwnedServerName, OwnedUserId}; @@ -55,7 +55,15 @@ pub(super) async fn fetch_support_well_known(&self, server_name: OwnedServerName .send() .await?; - let text = response.text().await?; + let text = response + .limit_read_text( + self.services + .config + .max_request_size + .try_into() + .expect("u64 fits into usize"), + ) + .await?; if text.is_empty() { return Err!("Response text/body is empty."); diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs index be7ccb19..1d62e750 100644 --- a/src/core/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -11,6 +11,7 @@ pub mod json; pub mod math; pub mod mutex_map; pub mod rand; +pub mod response; pub mod result; pub mod set; pub mod stream; diff --git a/src/core/utils/response.rs b/src/core/utils/response.rs new file mode 100644 index 00000000..618c727f --- /dev/null +++ b/src/core/utils/response.rs @@ -0,0 +1,51 @@ +use futures::StreamExt; +use num_traits::ToPrimitive; + +use crate::Err; + +/// Reads the response body while enforcing a maximum size limit to prevent +/// memory exhaustion. +pub async fn limit_read(response: reqwest::Response, max_size: u64) -> crate::Result> { + if response.content_length().is_some_and(|len| len > max_size) { + return Err!(BadServerResponse("Response too large")); + } + let mut data = Vec::new(); + let mut reader = response.bytes_stream(); + + while let Some(chunk) = reader.next().await { + let chunk = chunk?; + data.extend_from_slice(&chunk); + + if data.len() > max_size.to_usize().expect("max_size must fit in usize") { + return Err!(BadServerResponse("Response too large")); + } + } + + Ok(data) +} + +/// Reads the response body as text while enforcing a maximum size limit to +/// prevent memory exhaustion. +pub async fn limit_read_text( + response: reqwest::Response, + max_size: u64, +) -> crate::Result { + let text = String::from_utf8(limit_read(response, max_size).await?)?; + Ok(text) +} + +#[allow(async_fn_in_trait)] +pub trait LimitReadExt { + async fn limit_read(self, max_size: u64) -> crate::Result>; + async fn limit_read_text(self, max_size: u64) -> crate::Result; +} + +impl LimitReadExt for reqwest::Response { + async fn limit_read(self, max_size: u64) -> crate::Result> { + limit_read(self, max_size).await + } + + async fn limit_read_text(self, max_size: u64) -> crate::Result { + limit_read_text(self, max_size).await + } +} diff --git a/src/service/announcements/mod.rs b/src/service/announcements/mod.rs index be5da101..edee89b6 100644 --- a/src/service/announcements/mod.rs +++ b/src/service/announcements/mod.rs @@ -18,7 +18,7 @@ use std::{sync::Arc, time::Duration}; use async_trait::async_trait; -use conduwuit::{Result, Server, debug, error, warn}; +use conduwuit::{Result, Server, debug, error, utils::response::LimitReadExt, warn}; use database::{Deserialized, Map}; use ruma::events::{Mentions, room::message::RoomMessageEventContent}; use serde::Deserialize; @@ -137,7 +137,7 @@ impl Service { .get(CHECK_FOR_ANNOUNCEMENTS_URL) .send() .await? - .text() + .limit_read_text(1024 * 1024) .await?; let response = serde_json::from_str::(&response)?; diff --git a/src/service/federation/execute.rs b/src/service/federation/execute.rs index 2f635503..9ea1d260 100644 --- a/src/service/federation/execute.rs +++ b/src/service/federation/execute.rs @@ -2,8 +2,8 @@ use std::{fmt::Debug, mem}; use bytes::Bytes; use conduwuit::{ - Err, Error, Result, debug, debug::INFO_SPAN_LEVEL, debug_error, debug_warn, err, - error::inspect_debug_log, implement, trace, + Err, Error, Result, debug, debug::INFO_SPAN_LEVEL, debug_error, debug_warn, err, implement, + trace, utils::response::LimitReadExt, }; use http::{HeaderValue, header::AUTHORIZATION}; use ipaddress::IPAddress; @@ -133,7 +133,22 @@ async fn handle_response( where T: OutgoingRequest + Send, { - let response = into_http_response(dest, actual, method, url, response).await?; + const HUGE_ENDPOINTS: [&str; 2] = + ["/_matrix/federation/v2/send_join/", "/_matrix/federation/v2/state/"]; + let size_limit: u64 = if HUGE_ENDPOINTS.iter().any(|e| url.path().starts_with(e)) { + // Some federation endpoints can return huge response bodies, so we'll bump the + // limit for those endpoints specifically. + self.services + .server + .config + .max_request_size + .saturating_mul(10) + } else { + self.services.server.config.max_request_size + } + .try_into() + .expect("size_limit (usize) should fit within a u64"); + let response = into_http_response(dest, actual, method, url, response, size_limit).await?; T::IncomingResponse::try_from_http_response(response) .map_err(|e| err!(BadServerResponse("Server returned bad 200 response: {e:?}"))) @@ -145,6 +160,7 @@ async fn into_http_response( method: &Method, url: &Url, mut response: Response, + max_size: u64, ) -> Result> { let status = response.status(); trace!( @@ -167,14 +183,14 @@ async fn into_http_response( ); trace!("Waiting for response body..."); - let body = response - .bytes() - .await - .inspect_err(inspect_debug_log) - .unwrap_or_else(|_| Vec::new().into()); - let http_response = http_response_builder - .body(body) + .body( + response + .limit_read(max_size) + .await + .unwrap_or_default() + .into(), + ) .expect("reqwest body is valid http body"); debug!("Got {status:?} for {method} {url}"); diff --git a/src/service/media/preview.rs b/src/service/media/preview.rs index 514bc2d5..e5a22e27 100644 --- a/src/service/media/preview.rs +++ b/src/service/media/preview.rs @@ -7,7 +7,7 @@ use std::time::SystemTime; -use conduwuit::{Err, Result, debug, err}; +use conduwuit::{Err, Result, debug, err, utils::response::LimitReadExt}; use conduwuit_core::implement; use ipaddress::IPAddress; use serde::Serialize; @@ -112,8 +112,22 @@ pub async fn download_image(&self, url: &str) -> Result { use image::ImageReader; use ruma::Mxc; - let image = self.services.client.url_preview.get(url).send().await?; - let image = image.bytes().await?; + let image = self + .services + .client + .url_preview + .get(url) + .send() + .await? + .limit_read( + self.services + .server + .config + .max_request_size + .try_into() + .expect("u64 should fit in usize"), + ) + .await?; let mxc = Mxc { server_name: self.services.globals.server_name(), media_id: &random_string(super::MXC_LENGTH), @@ -151,24 +165,20 @@ async fn download_html(&self, url: &str) -> Result { use webpage::HTML; let client = &self.services.client.url_preview; - let mut response = client.get(url).send().await?; - - let mut bytes: Vec = Vec::new(); - while let Some(chunk) = response.chunk().await? { - bytes.extend_from_slice(&chunk); - if bytes.len() > self.services.globals.url_preview_max_spider_size() { - debug!( - "Response body from URL {} exceeds url_preview_max_spider_size ({}), not \ - processing the rest of the response body and assuming our necessary data is in \ - this range.", - url, - self.services.globals.url_preview_max_spider_size() - ); - break; - } - } - let body = String::from_utf8_lossy(&bytes); - let Ok(html) = HTML::from_string(body.to_string(), Some(url.to_owned())) else { + let body = client + .get(url) + .send() + .await? + .limit_read_text( + self.services + .server + .config + .max_request_size + .try_into() + .expect("u64 should fit in usize"), + ) + .await?; + let Ok(html) = HTML::from_string(body.clone(), Some(url.to_owned())) else { return Err!(Request(Unknown("Failed to parse HTML"))); }; diff --git a/src/service/media/remote.rs b/src/service/media/remote.rs index da457a8b..29e354ce 100644 --- a/src/service/media/remote.rs +++ b/src/service/media/remote.rs @@ -2,7 +2,7 @@ use std::{fmt::Debug, time::Duration}; use conduwuit::{ Err, Error, Result, debug_warn, err, implement, - utils::content_disposition::make_content_disposition, + utils::{content_disposition::make_content_disposition, response::LimitReadExt}, }; use http::header::{CONTENT_DISPOSITION, CONTENT_TYPE, HeaderValue}; use ruma::{ @@ -286,10 +286,15 @@ async fn location_request(&self, location: &str) -> Result { .and_then(Result::ok); response - .bytes() + .limit_read( + self.services + .server + .config + .max_request_size + .try_into() + .expect("u64 should fit in usize"), + ) .await - .map(Vec::from) - .map_err(Into::into) .map(|content| FileMeta { content: Some(content), content_type: content_type.clone(), diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index 1af64e60..cd9c8799 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -1,6 +1,7 @@ use std::{fmt::Debug, mem, sync::Arc}; use bytes::BytesMut; +use conduwuit::utils::response::LimitReadExt; use conduwuit_core::{ Err, Event, Result, debug_warn, err, trace, utils::{stream::TryIgnore, string_from_bytes}, @@ -30,7 +31,7 @@ use ruma::{ uint, }; -use crate::{Dep, client, globals, rooms, sending, users}; +use crate::{Dep, client, config, globals, rooms, sending, users}; pub struct Service { db: Data, @@ -39,6 +40,7 @@ pub struct Service { struct Services { globals: Dep, + config: Dep, client: Dep, state_accessor: Dep, state_cache: Dep, @@ -61,6 +63,7 @@ impl crate::Service for Service { services: Services { globals: args.depend::("globals"), client: args.depend::("client"), + config: args.depend::("config"), state_accessor: args .depend::("rooms::state_accessor"), state_cache: args.depend::("rooms::state_cache"), @@ -245,7 +248,15 @@ impl Service { .expect("http::response::Builder is usable"), ); - let body = response.bytes().await?; + let body = response + .limit_read( + self.services + .config + .max_request_size + .try_into() + .expect("usize fits into u64"), + ) + .await?; if !status.is_success() { debug_warn!("Push gateway response body: {:?}", string_from_bytes(&body)); diff --git a/src/service/resolver/well_known.rs b/src/service/resolver/well_known.rs index 68a8e620..c07549ca 100644 --- a/src/service/resolver/well_known.rs +++ b/src/service/resolver/well_known.rs @@ -1,4 +1,6 @@ -use conduwuit::{Result, debug, debug_error, debug_info, debug_warn, implement, trace}; +use conduwuit::{ + Result, debug, debug_error, debug_info, implement, trace, utils::response::LimitReadExt, +}; #[implement(super::Service)] #[tracing::instrument(name = "well-known", level = "debug", skip(self, dest))] @@ -24,12 +26,8 @@ pub(super) async fn request_well_known(&self, dest: &str) -> Result= 12288 { - debug_warn!("response contains junk"); - return Ok(None); - } let body: serde_json::Value = serde_json::from_str(&text).unwrap_or_default(); diff --git a/src/service/sending/antispam.rs b/src/service/sending/antispam.rs index 2e328b1b..c43cad4b 100644 --- a/src/service/sending/antispam.rs +++ b/src/service/sending/antispam.rs @@ -1,7 +1,7 @@ use std::{fmt::Debug, mem}; use bytes::BytesMut; -use conduwuit::{Err, Result, debug_error, err, utils, warn}; +use conduwuit::{Err, Result, debug_error, err, utils, utils::response::LimitReadExt, warn}; use reqwest::Client; use ruma::api::{IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken}; @@ -38,7 +38,7 @@ where .expect("http::response::Builder is usable"), ); - let body = response.bytes().await?; // TODO: handle timeout + let body = response.limit_read(65535).await?; // TODO: handle timeout if !status.is_success() { debug_error!("Antispam response bytes: {:?}", utils::string_from_bytes(&body)); diff --git a/src/service/sending/appservice.rs b/src/service/sending/appservice.rs index 6a92b6be..dc4f2dd1 100644 --- a/src/service/sending/appservice.rs +++ b/src/service/sending/appservice.rs @@ -1,7 +1,9 @@ use std::{fmt::Debug, mem}; use bytes::BytesMut; -use conduwuit::{Err, Result, debug_error, err, implement, trace, utils, warn}; +use conduwuit::{ + Err, Result, debug_error, err, implement, trace, utils, utils::response::LimitReadExt, warn, +}; use ruma::api::{ IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, appservice::Registration, }; @@ -77,7 +79,15 @@ where .expect("http::response::Builder is usable"), ); - let body = response.bytes().await?; + let body = response + .limit_read( + self.server + .config + .max_request_size + .try_into() + .expect("usize fits into u64"), + ) + .await?; if !status.is_success() { debug_error!("Appservice response bytes: {:?}", utils::string_from_bytes(&body));