fix: Limit body read size of remote requests (CWE-409)

Reviewed-By: Jade Ellis <jade@ellis.link>
This commit is contained in:
timedout 2026-03-03 19:54:34 +00:00
parent 7207398a9e
commit 37888fb670
No known key found for this signature in database
GPG key ID: 0FA334385D0B689F
14 changed files with 192 additions and 54 deletions

15
Cargo.lock generated
View file

@ -4055,12 +4055,14 @@ dependencies = [
"sync_wrapper", "sync_wrapper",
"tokio", "tokio",
"tokio-rustls", "tokio-rustls",
"tokio-util",
"tower", "tower",
"tower-http", "tower-http",
"tower-service", "tower-service",
"url", "url",
"wasm-bindgen", "wasm-bindgen",
"wasm-bindgen-futures", "wasm-bindgen-futures",
"wasm-streams",
"web-sys", "web-sys",
"webpki-roots", "webpki-roots",
] ]
@ -5868,6 +5870,19 @@ dependencies = [
"wasmparser", "wasmparser",
] ]
[[package]]
name = "wasm-streams"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]] [[package]]
name = "wasmparser" name = "wasmparser"
version = "0.244.0" version = "0.244.0"

View file

@ -144,6 +144,7 @@ features = [
"socks", "socks",
"hickory-dns", "hickory-dns",
"http2", "http2",
"stream",
] ]
[workspace.dependencies.serde] [workspace.dependencies.serde]

View file

@ -15,6 +15,18 @@ disallowed-macros = [
{ path = "log::trace", reason = "use conduwuit_core::trace" }, { path = "log::trace", reason = "use conduwuit_core::trace" },
] ]
disallowed-methods = [ [[disallowed-methods]]
{ path = "tokio::spawn", reason = "use and pass conduuwit_core::server::Server::runtime() to spawn from" }, path = "tokio::spawn"
] reason = "use and pass conduwuit_core::server::Server::runtime() to spawn from"
[[disallowed-methods]]
path = "reqwest::Response::bytes"
reason = "bytes is unsafe, use limit_read via the conduwuit_core::utils::LimitReadExt trait instead"
[[disallowed-methods]]
path = "reqwest::Response::text"
reason = "text is unsafe, use limit_read_text via the conduwuit_core::utils::LimitReadExt trait instead"
[[disallowed-methods]]
path = "reqwest::Response::json"
reason = "json is unsafe, use limit_read_text via the conduwuit_core::utils::LimitReadExt trait instead"

View file

@ -1,6 +1,6 @@
use std::fmt::Write; use std::fmt::Write;
use conduwuit::{Err, Result}; use conduwuit::{Err, Result, utils::response::LimitReadExt};
use futures::StreamExt; use futures::StreamExt;
use ruma::{OwnedRoomId, OwnedServerName, OwnedUserId}; use ruma::{OwnedRoomId, OwnedServerName, OwnedUserId};
@ -55,7 +55,15 @@ pub(super) async fn fetch_support_well_known(&self, server_name: OwnedServerName
.send() .send()
.await?; .await?;
let text = response.text().await?; let text = response
.limit_read_text(
self.services
.config
.max_request_size
.try_into()
.expect("u64 fits into usize"),
)
.await?;
if text.is_empty() { if text.is_empty() {
return Err!("Response text/body is empty."); return Err!("Response text/body is empty.");

View file

@ -11,6 +11,7 @@ pub mod json;
pub mod math; pub mod math;
pub mod mutex_map; pub mod mutex_map;
pub mod rand; pub mod rand;
pub mod response;
pub mod result; pub mod result;
pub mod set; pub mod set;
pub mod stream; pub mod stream;

View file

@ -0,0 +1,51 @@
use futures::StreamExt;
use num_traits::ToPrimitive;
use crate::Err;
/// Reads the response body while enforcing a maximum size limit to prevent
/// memory exhaustion.
pub async fn limit_read(response: reqwest::Response, max_size: u64) -> crate::Result<Vec<u8>> {
if response.content_length().is_some_and(|len| len > max_size) {
return Err!(BadServerResponse("Response too large"));
}
let mut data = Vec::new();
let mut reader = response.bytes_stream();
while let Some(chunk) = reader.next().await {
let chunk = chunk?;
data.extend_from_slice(&chunk);
if data.len() > max_size.to_usize().expect("max_size must fit in usize") {
return Err!(BadServerResponse("Response too large"));
}
}
Ok(data)
}
/// Reads the response body as text while enforcing a maximum size limit to
/// prevent memory exhaustion.
pub async fn limit_read_text(
response: reqwest::Response,
max_size: u64,
) -> crate::Result<String> {
let text = String::from_utf8(limit_read(response, max_size).await?)?;
Ok(text)
}
#[allow(async_fn_in_trait)]
pub trait LimitReadExt {
async fn limit_read(self, max_size: u64) -> crate::Result<Vec<u8>>;
async fn limit_read_text(self, max_size: u64) -> crate::Result<String>;
}
impl LimitReadExt for reqwest::Response {
async fn limit_read(self, max_size: u64) -> crate::Result<Vec<u8>> {
limit_read(self, max_size).await
}
async fn limit_read_text(self, max_size: u64) -> crate::Result<String> {
limit_read_text(self, max_size).await
}
}

View file

@ -18,7 +18,7 @@
use std::{sync::Arc, time::Duration}; use std::{sync::Arc, time::Duration};
use async_trait::async_trait; use async_trait::async_trait;
use conduwuit::{Result, Server, debug, error, warn}; use conduwuit::{Result, Server, debug, error, utils::response::LimitReadExt, warn};
use database::{Deserialized, Map}; use database::{Deserialized, Map};
use ruma::events::{Mentions, room::message::RoomMessageEventContent}; use ruma::events::{Mentions, room::message::RoomMessageEventContent};
use serde::Deserialize; use serde::Deserialize;
@ -137,7 +137,7 @@ impl Service {
.get(CHECK_FOR_ANNOUNCEMENTS_URL) .get(CHECK_FOR_ANNOUNCEMENTS_URL)
.send() .send()
.await? .await?
.text() .limit_read_text(1024 * 1024)
.await?; .await?;
let response = serde_json::from_str::<CheckForAnnouncementsResponse>(&response)?; let response = serde_json::from_str::<CheckForAnnouncementsResponse>(&response)?;

View file

@ -2,8 +2,8 @@ use std::{fmt::Debug, mem};
use bytes::Bytes; use bytes::Bytes;
use conduwuit::{ use conduwuit::{
Err, Error, Result, debug, debug::INFO_SPAN_LEVEL, debug_error, debug_warn, err, Err, Error, Result, debug, debug::INFO_SPAN_LEVEL, debug_error, debug_warn, err, implement,
error::inspect_debug_log, implement, trace, trace, utils::response::LimitReadExt,
}; };
use http::{HeaderValue, header::AUTHORIZATION}; use http::{HeaderValue, header::AUTHORIZATION};
use ipaddress::IPAddress; use ipaddress::IPAddress;
@ -133,7 +133,22 @@ async fn handle_response<T>(
where where
T: OutgoingRequest + Send, T: OutgoingRequest + Send,
{ {
let response = into_http_response(dest, actual, method, url, response).await?; const HUGE_ENDPOINTS: [&str; 2] =
["/_matrix/federation/v2/send_join/", "/_matrix/federation/v2/state/"];
let size_limit: u64 = if HUGE_ENDPOINTS.iter().any(|e| url.path().starts_with(e)) {
// Some federation endpoints can return huge response bodies, so we'll bump the
// limit for those endpoints specifically.
self.services
.server
.config
.max_request_size
.saturating_mul(10)
} else {
self.services.server.config.max_request_size
}
.try_into()
.expect("size_limit (usize) should fit within a u64");
let response = into_http_response(dest, actual, method, url, response, size_limit).await?;
T::IncomingResponse::try_from_http_response(response) T::IncomingResponse::try_from_http_response(response)
.map_err(|e| err!(BadServerResponse("Server returned bad 200 response: {e:?}"))) .map_err(|e| err!(BadServerResponse("Server returned bad 200 response: {e:?}")))
@ -145,6 +160,7 @@ async fn into_http_response(
method: &Method, method: &Method,
url: &Url, url: &Url,
mut response: Response, mut response: Response,
max_size: u64,
) -> Result<http::Response<Bytes>> { ) -> Result<http::Response<Bytes>> {
let status = response.status(); let status = response.status();
trace!( trace!(
@ -167,14 +183,14 @@ async fn into_http_response(
); );
trace!("Waiting for response body..."); trace!("Waiting for response body...");
let body = response
.bytes()
.await
.inspect_err(inspect_debug_log)
.unwrap_or_else(|_| Vec::new().into());
let http_response = http_response_builder let http_response = http_response_builder
.body(body) .body(
response
.limit_read(max_size)
.await
.unwrap_or_default()
.into(),
)
.expect("reqwest body is valid http body"); .expect("reqwest body is valid http body");
debug!("Got {status:?} for {method} {url}"); debug!("Got {status:?} for {method} {url}");

View file

@ -7,7 +7,7 @@
use std::time::SystemTime; use std::time::SystemTime;
use conduwuit::{Err, Result, debug, err}; use conduwuit::{Err, Result, debug, err, utils::response::LimitReadExt};
use conduwuit_core::implement; use conduwuit_core::implement;
use ipaddress::IPAddress; use ipaddress::IPAddress;
use serde::Serialize; use serde::Serialize;
@ -112,8 +112,22 @@ pub async fn download_image(&self, url: &str) -> Result<UrlPreviewData> {
use image::ImageReader; use image::ImageReader;
use ruma::Mxc; use ruma::Mxc;
let image = self.services.client.url_preview.get(url).send().await?; let image = self
let image = image.bytes().await?; .services
.client
.url_preview
.get(url)
.send()
.await?
.limit_read(
self.services
.server
.config
.max_request_size
.try_into()
.expect("u64 should fit in usize"),
)
.await?;
let mxc = Mxc { let mxc = Mxc {
server_name: self.services.globals.server_name(), server_name: self.services.globals.server_name(),
media_id: &random_string(super::MXC_LENGTH), media_id: &random_string(super::MXC_LENGTH),
@ -151,24 +165,20 @@ async fn download_html(&self, url: &str) -> Result<UrlPreviewData> {
use webpage::HTML; use webpage::HTML;
let client = &self.services.client.url_preview; let client = &self.services.client.url_preview;
let mut response = client.get(url).send().await?; let body = client
.get(url)
let mut bytes: Vec<u8> = Vec::new(); .send()
while let Some(chunk) = response.chunk().await? { .await?
bytes.extend_from_slice(&chunk); .limit_read_text(
if bytes.len() > self.services.globals.url_preview_max_spider_size() { self.services
debug!( .server
"Response body from URL {} exceeds url_preview_max_spider_size ({}), not \ .config
processing the rest of the response body and assuming our necessary data is in \ .max_request_size
this range.", .try_into()
url, .expect("u64 should fit in usize"),
self.services.globals.url_preview_max_spider_size() )
); .await?;
break; let Ok(html) = HTML::from_string(body.clone(), Some(url.to_owned())) else {
}
}
let body = String::from_utf8_lossy(&bytes);
let Ok(html) = HTML::from_string(body.to_string(), Some(url.to_owned())) else {
return Err!(Request(Unknown("Failed to parse HTML"))); return Err!(Request(Unknown("Failed to parse HTML")));
}; };

View file

@ -2,7 +2,7 @@ use std::{fmt::Debug, time::Duration};
use conduwuit::{ use conduwuit::{
Err, Error, Result, debug_warn, err, implement, Err, Error, Result, debug_warn, err, implement,
utils::content_disposition::make_content_disposition, utils::{content_disposition::make_content_disposition, response::LimitReadExt},
}; };
use http::header::{CONTENT_DISPOSITION, CONTENT_TYPE, HeaderValue}; use http::header::{CONTENT_DISPOSITION, CONTENT_TYPE, HeaderValue};
use ruma::{ use ruma::{
@ -286,10 +286,15 @@ async fn location_request(&self, location: &str) -> Result<FileMeta> {
.and_then(Result::ok); .and_then(Result::ok);
response response
.bytes() .limit_read(
self.services
.server
.config
.max_request_size
.try_into()
.expect("u64 should fit in usize"),
)
.await .await
.map(Vec::from)
.map_err(Into::into)
.map(|content| FileMeta { .map(|content| FileMeta {
content: Some(content), content: Some(content),
content_type: content_type.clone(), content_type: content_type.clone(),

View file

@ -1,6 +1,7 @@
use std::{fmt::Debug, mem, sync::Arc}; use std::{fmt::Debug, mem, sync::Arc};
use bytes::BytesMut; use bytes::BytesMut;
use conduwuit::utils::response::LimitReadExt;
use conduwuit_core::{ use conduwuit_core::{
Err, Event, Result, debug_warn, err, trace, Err, Event, Result, debug_warn, err, trace,
utils::{stream::TryIgnore, string_from_bytes}, utils::{stream::TryIgnore, string_from_bytes},
@ -30,7 +31,7 @@ use ruma::{
uint, uint,
}; };
use crate::{Dep, client, globals, rooms, sending, users}; use crate::{Dep, client, config, globals, rooms, sending, users};
pub struct Service { pub struct Service {
db: Data, db: Data,
@ -39,6 +40,7 @@ pub struct Service {
struct Services { struct Services {
globals: Dep<globals::Service>, globals: Dep<globals::Service>,
config: Dep<config::Service>,
client: Dep<client::Service>, client: Dep<client::Service>,
state_accessor: Dep<rooms::state_accessor::Service>, state_accessor: Dep<rooms::state_accessor::Service>,
state_cache: Dep<rooms::state_cache::Service>, state_cache: Dep<rooms::state_cache::Service>,
@ -61,6 +63,7 @@ impl crate::Service for Service {
services: Services { services: Services {
globals: args.depend::<globals::Service>("globals"), globals: args.depend::<globals::Service>("globals"),
client: args.depend::<client::Service>("client"), client: args.depend::<client::Service>("client"),
config: args.depend::<config::Service>("config"),
state_accessor: args state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"), .depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"), state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
@ -245,7 +248,15 @@ impl Service {
.expect("http::response::Builder is usable"), .expect("http::response::Builder is usable"),
); );
let body = response.bytes().await?; let body = response
.limit_read(
self.services
.config
.max_request_size
.try_into()
.expect("usize fits into u64"),
)
.await?;
if !status.is_success() { if !status.is_success() {
debug_warn!("Push gateway response body: {:?}", string_from_bytes(&body)); debug_warn!("Push gateway response body: {:?}", string_from_bytes(&body));

View file

@ -1,4 +1,6 @@
use conduwuit::{Result, debug, debug_error, debug_info, debug_warn, implement, trace}; use conduwuit::{
Result, debug, debug_error, debug_info, implement, trace, utils::response::LimitReadExt,
};
#[implement(super::Service)] #[implement(super::Service)]
#[tracing::instrument(name = "well-known", level = "debug", skip(self, dest))] #[tracing::instrument(name = "well-known", level = "debug", skip(self, dest))]
@ -24,12 +26,8 @@ pub(super) async fn request_well_known(&self, dest: &str) -> Result<Option<Strin
return Ok(None); return Ok(None);
} }
let text = response.text().await?; let text = response.limit_read_text(8192).await?;
trace!("response text: {text:?}"); trace!("response text: {text:?}");
if text.len() >= 12288 {
debug_warn!("response contains junk");
return Ok(None);
}
let body: serde_json::Value = serde_json::from_str(&text).unwrap_or_default(); let body: serde_json::Value = serde_json::from_str(&text).unwrap_or_default();

View file

@ -1,7 +1,7 @@
use std::{fmt::Debug, mem}; use std::{fmt::Debug, mem};
use bytes::BytesMut; use bytes::BytesMut;
use conduwuit::{Err, Result, debug_error, err, utils, warn}; use conduwuit::{Err, Result, debug_error, err, utils, utils::response::LimitReadExt, warn};
use reqwest::Client; use reqwest::Client;
use ruma::api::{IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken}; use ruma::api::{IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken};
@ -38,7 +38,7 @@ where
.expect("http::response::Builder is usable"), .expect("http::response::Builder is usable"),
); );
let body = response.bytes().await?; // TODO: handle timeout let body = response.limit_read(65535).await?; // TODO: handle timeout
if !status.is_success() { if !status.is_success() {
debug_error!("Antispam response bytes: {:?}", utils::string_from_bytes(&body)); debug_error!("Antispam response bytes: {:?}", utils::string_from_bytes(&body));

View file

@ -1,7 +1,9 @@
use std::{fmt::Debug, mem}; use std::{fmt::Debug, mem};
use bytes::BytesMut; use bytes::BytesMut;
use conduwuit::{Err, Result, debug_error, err, implement, trace, utils, warn}; use conduwuit::{
Err, Result, debug_error, err, implement, trace, utils, utils::response::LimitReadExt, warn,
};
use ruma::api::{ use ruma::api::{
IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, appservice::Registration,
}; };
@ -77,7 +79,15 @@ where
.expect("http::response::Builder is usable"), .expect("http::response::Builder is usable"),
); );
let body = response.bytes().await?; let body = response
.limit_read(
self.server
.config
.max_request_size
.try_into()
.expect("usize fits into u64"),
)
.await?;
if !status.is_success() { if !status.is_success() {
debug_error!("Appservice response bytes: {:?}", utils::string_from_bytes(&body)); debug_error!("Appservice response bytes: {:?}", utils::string_from_bytes(&body));