continuwuity/src/service/media/preview.rs

451 lines
13 KiB
Rust

//! URL Previews
//!
//! This functionality is gated by 'url_preview', but not at the unit level for
//! historical and simplicity reasons. Instead the feature gates the inclusion
//! of dependencies and nulls out results through the existing interface when
//! not featured.
use std::time::SystemTime;
use conduwuit::{Err, Result, debug, err, utils::response::LimitReadExt};
use conduwuit_core::implement;
use ipaddress::IPAddress;
#[cfg(feature = "url_preview")]
use ruma::OwnedMxcUri;
use serde::Serialize;
use url::Url;
use super::Service;
#[derive(Serialize, Default)]
pub struct UrlPreviewData {
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:title"))]
pub title: Option<String>,
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:description"))]
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image"))]
pub image: Option<String>,
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "matrix:image:size"))]
pub image_size: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image:width"))]
pub image_width: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image:height"))]
pub image_height: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:video"))]
pub video: Option<String>,
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "matrix:video:size"))]
pub video_size: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:video:width"))]
pub video_width: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:video:height"))]
pub video_height: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:audio"))]
pub audio: Option<String>,
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "matrix:audio:size"))]
pub audio_size: Option<usize>,
}
#[implement(Service)]
pub async fn remove_url_preview(&self, url: &str) -> Result<()> {
// TODO: also remove the downloaded image
self.db.remove_url_preview(url)
}
#[implement(Service)]
pub async fn clear_url_previews(&self) { self.db.clear_url_previews().await; }
#[implement(Service)]
pub async fn set_url_preview(&self, url: &str, data: &UrlPreviewData) -> Result<()> {
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.expect("valid system time");
self.db.set_url_preview(url, data, now)
}
#[implement(Service)]
pub async fn get_url_preview(&self, url: &Url) -> Result<UrlPreviewData> {
if let Ok(preview) = self.db.get_url_preview(url.as_str()).await {
return Ok(preview);
}
// ensure that only one request is made per URL
let _request_lock = self.url_preview_mutex.lock(url.as_str()).await;
match self.db.get_url_preview(url.as_str()).await {
| Ok(preview) => Ok(preview),
| Err(_) => self.request_url_preview(url).await,
}
}
#[implement(Service)]
async fn request_url_preview(&self, url: &Url) -> Result<UrlPreviewData> {
if let Ok(ip) = IPAddress::parse(url.host_str().expect("URL previously validated")) {
if !self.services.client.valid_cidr_range(&ip) {
return Err!(Request(Forbidden("Requesting from this address is forbidden")));
}
}
let client = &self.services.client.url_preview;
let response = client.head(url.as_str()).send().await?;
debug!(%url, "URL preview response headers: {:?}", response.headers());
if let Some(remote_addr) = response.remote_addr() {
debug!(%url, "URL preview response remote address: {:?}", remote_addr);
if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) {
if !self.services.client.valid_cidr_range(&ip) {
return Err!(Request(Forbidden("Requesting from this address is forbidden")));
}
}
}
let Some(content_type) = response.headers().get(reqwest::header::CONTENT_TYPE) else {
return Err!(Request(Unknown("Unknown or invalid Content-Type header")));
};
let content_type = content_type
.to_str()
.map_err(|e| err!(Request(Unknown("Unknown or invalid Content-Type header: {e}"))))?;
let data = match content_type {
| html if html.starts_with("text/html") => self.download_html(url.as_str()).await?,
| img if img.starts_with("image/") => self.download_image(url.as_str(), None).await?,
| video if video.starts_with("video/") => self.download_video(url.as_str(), None).await?,
| audio if audio.starts_with("audio/") => self.download_audio(url.as_str(), None).await?,
| _ => return Err!(Request(Unknown("Unsupported Content-Type"))),
};
self.set_url_preview(url.as_str(), &data).await?;
Ok(data)
}
#[cfg(feature = "url_preview")]
#[implement(Service)]
pub async fn download_image(
&self,
url: &str,
preview_data: Option<UrlPreviewData>,
) -> Result<UrlPreviewData> {
use conduwuit::utils::random_string;
use image::ImageReader;
use ruma::Mxc;
let mut preview_data = preview_data.unwrap_or_default();
let image = self
.services
.client
.url_preview
.get(url)
.send()
.await?
.limit_read(
self.services
.server
.config
.max_request_size
.try_into()
.expect("u64 should fit in usize"),
)
.await?;
let mxc = Mxc {
server_name: self.services.globals.server_name(),
media_id: &random_string(super::MXC_LENGTH),
};
self.create(&mxc, None, None, None, &image).await?;
preview_data.image = Some(mxc.to_string());
if preview_data.image_height.is_none() || preview_data.image_width.is_none() {
let cursor = std::io::Cursor::new(&image);
let (width, height) = match ImageReader::new(cursor).with_guessed_format() {
| Err(_) => (None, None),
| Ok(reader) => match reader.into_dimensions() {
| Err(_) => (None, None),
| Ok((width, height)) => (Some(width), Some(height)),
},
};
preview_data.image_width = width;
preview_data.image_height = height;
}
Ok(preview_data)
}
#[cfg(feature = "url_preview")]
#[implement(Service)]
pub async fn download_video(
&self,
url: &str,
preview_data: Option<UrlPreviewData>,
) -> Result<UrlPreviewData> {
let mut preview_data = preview_data.unwrap_or_default();
if self.services.globals.url_preview_allow_audio_video() {
let (url, size) = self.download_media(url).await?;
preview_data.video = Some(url.to_string());
preview_data.video_size = Some(size);
}
Ok(preview_data)
}
#[cfg(feature = "url_preview")]
#[implement(Service)]
pub async fn download_audio(
&self,
url: &str,
preview_data: Option<UrlPreviewData>,
) -> Result<UrlPreviewData> {
let mut preview_data = preview_data.unwrap_or_default();
if self.services.globals.url_preview_allow_audio_video() {
let (url, size) = self.download_media(url).await?;
preview_data.audio = Some(url.to_string());
preview_data.audio_size = Some(size);
}
Ok(preview_data)
}
#[cfg(feature = "url_preview")]
#[implement(Service)]
pub async fn download_media(&self, url: &str) -> Result<(OwnedMxcUri, usize)> {
use conduwuit::utils::random_string;
use http::header::CONTENT_TYPE;
use ruma::Mxc;
let response = self.services.client.url_preview.get(url).send().await?;
let content_type = response.headers().get(CONTENT_TYPE).cloned();
let media = response
.limit_read(
self.services
.server
.config
.max_request_size
.try_into()
.expect("u64 should fit in usize"),
)
.await?;
let mxc = Mxc {
server_name: self.services.globals.server_name(),
media_id: &random_string(super::MXC_LENGTH),
};
let content_type = content_type.and_then(|v| v.to_str().map(ToOwned::to_owned).ok());
self.create(&mxc, None, None, content_type.as_deref(), &media)
.await?;
Ok((OwnedMxcUri::from(mxc.to_string()), media.len()))
}
#[cfg(not(feature = "url_preview"))]
#[implement(Service)]
pub async fn download_image(
&self,
_url: &str,
_preview_data: Option<UrlPreviewData>,
) -> Result<UrlPreviewData> {
Err!(FeatureDisabled("url_preview"))
}
#[cfg(not(feature = "url_preview"))]
#[implement(Service)]
pub async fn download_video(
&self,
_url: &str,
_preview_data: Option<UrlPreviewData>,
) -> Result<UrlPreviewData> {
Err!(FeatureDisabled("url_preview"))
}
#[cfg(not(feature = "url_preview"))]
#[implement(Service)]
pub async fn download_audio(
&self,
_url: &str,
_preview_data: Option<UrlPreviewData>,
) -> Result<UrlPreviewData> {
Err!(FeatureDisabled("url_preview"))
}
#[cfg(not(feature = "url_preview"))]
#[implement(Service)]
pub async fn download_media(&self, _url: &str) -> Result<UrlPreviewData> {
Err!(FeatureDisabled("url_preview"))
}
#[cfg(feature = "url_preview")]
#[implement(Service)]
async fn download_html(&self, url: &str) -> Result<UrlPreviewData> {
use webpage::HTML;
let client = &self.services.client.url_preview;
let body = client
.get(url)
.send()
.await?
.limit_read_text(
self.services
.server
.config
.max_request_size
.try_into()
.expect("u64 should fit in usize"),
)
.await?;
let Ok(html) = HTML::from_string(body.clone(), Some(url.to_owned())) else {
return Err!(Request(Unknown("Failed to parse HTML")));
};
let mut preview_data = UrlPreviewData::default();
if let Some(obj) = html.opengraph.images.first() {
preview_data = self.download_image(&obj.url, Some(preview_data)).await?;
}
if let Some(obj) = html.opengraph.videos.first() {
preview_data = self.download_video(&obj.url, Some(preview_data)).await?;
preview_data.video_width = obj.properties.get("width").and_then(|v| v.parse().ok());
preview_data.video_height = obj.properties.get("height").and_then(|v| v.parse().ok());
}
if let Some(obj) = html.opengraph.audios.first() {
preview_data = self.download_audio(&obj.url, Some(preview_data)).await?;
}
let props = html.opengraph.properties;
/* use OpenGraph title/description, but fall back to HTML if not available */
preview_data.title = props.get("title").cloned().or(html.title);
preview_data.description = props.get("description").cloned().or(html.description);
Ok(preview_data)
}
#[cfg(not(feature = "url_preview"))]
#[implement(Service)]
async fn download_html(&self, _url: &str) -> Result<UrlPreviewData> {
Err!(FeatureDisabled("url_preview"))
}
#[implement(Service)]
pub fn url_preview_allowed(&self, url: &Url) -> bool {
if ["http", "https"]
.iter()
.all(|&scheme| scheme != url.scheme().to_lowercase())
{
debug!("Ignoring non-HTTP/HTTPS URL to preview: {}", url);
return false;
}
let host = match url.host_str() {
| None => {
debug!("Ignoring URL preview for a URL that does not have a host (?): {}", url);
return false;
},
| Some(h) => h.to_owned(),
};
let allowlist_domain_contains = self
.services
.globals
.url_preview_domain_contains_allowlist();
let allowlist_domain_explicit = self
.services
.globals
.url_preview_domain_explicit_allowlist();
let denylist_domain_explicit = self.services.globals.url_preview_domain_explicit_denylist();
let allowlist_url_contains = self.services.globals.url_preview_url_contains_allowlist();
if allowlist_domain_contains.contains(&"*".to_owned())
|| allowlist_domain_explicit.contains(&"*".to_owned())
|| allowlist_url_contains.contains(&"*".to_owned())
{
debug!("Config key contains * which is allowing all URL previews. Allowing URL {}", url);
return true;
}
if !host.is_empty() {
if denylist_domain_explicit.contains(&host) {
debug!(
"Host {} is not allowed by url_preview_domain_explicit_denylist (check 1/4)",
&host
);
return false;
}
if allowlist_domain_explicit.contains(&host) {
debug!(
"Host {} is allowed by url_preview_domain_explicit_allowlist (check 2/4)",
&host
);
return true;
}
if allowlist_domain_contains
.iter()
.any(|domain_s| domain_s.contains(&host.clone()))
{
debug!(
"Host {} is allowed by url_preview_domain_contains_allowlist (check 3/4)",
&host
);
return true;
}
if allowlist_url_contains
.iter()
.any(|url_s| url.to_string().contains(url_s))
{
debug!("URL {} is allowed by url_preview_url_contains_allowlist (check 4/4)", &host);
return true;
}
// check root domain if available and if user has root domain checks
if self.services.globals.url_preview_check_root_domain() {
debug!("Checking root domain");
match host.split_once('.') {
| None => return false,
| Some((_, root_domain)) => {
if denylist_domain_explicit.contains(&root_domain.to_owned()) {
debug!(
"Root domain {} is not allowed by \
url_preview_domain_explicit_denylist (check 1/3)",
&root_domain
);
return true;
}
if allowlist_domain_explicit.contains(&root_domain.to_owned()) {
debug!(
"Root domain {} is allowed by url_preview_domain_explicit_allowlist \
(check 2/3)",
&root_domain
);
return true;
}
if allowlist_domain_contains
.iter()
.any(|domain_s| domain_s.contains(&root_domain.to_owned()))
{
debug!(
"Root domain {} is allowed by url_preview_domain_contains_allowlist \
(check 3/3)",
&root_domain
);
return true;
}
},
}
}
}
false
}