From b2a87e2fb97657615a03568943157a325da1a1f4 Mon Sep 17 00:00:00 2001 From: Ginger Date: Thu, 12 Feb 2026 10:16:03 -0500 Subject: [PATCH] refactor: Add support for multiple static tokens to registration token service --- src/api/client/account.rs | 13 +++------- src/service/config/mod.rs | 13 ++++++++++ src/service/registration_tokens/mod.rs | 33 ++++++++++---------------- 3 files changed, 28 insertions(+), 31 deletions(-) diff --git a/src/api/client/account.rs b/src/api/client/account.rs index ca078f69..8bb2ff70 100644 --- a/src/api/client/account.rs +++ b/src/api/client/account.rs @@ -185,17 +185,10 @@ pub(crate) async fn register_route( ))); } - if is_guest - && (!services.config.allow_guest_registration - || (services.config.allow_registration - && services - .registration_tokens - .get_config_file_token() - .is_some())) - { + if is_guest && !services.config.allow_guest_registration { info!( - "Guest registration disabled / registration enabled with token configured, \ - rejecting guest registration attempt, initial device name: \"{}\"", + "Guest registration disabled, rejecting guest registration attempt, initial device \ + name: \"{}\"", body.initial_device_display_name.as_deref().unwrap_or("") ); return Err!(Request(GuestAccessForbidden("Guest registration is disabled."))); diff --git a/src/service/config/mod.rs b/src/service/config/mod.rs index efeed743..20cb431d 100644 --- a/src/service/config/mod.rs +++ b/src/service/config/mod.rs @@ -7,12 +7,25 @@ use conduwuit::{ error, implement, }; +use crate::registration_tokens::{ValidToken, ValidTokenSource}; + pub struct Service { server: Arc, } const SIGNAL: &str = "SIGUSR1"; +impl Service { + /// Get the registration token set in the config file, if it exists. + #[must_use] + pub fn get_config_file_token(&self) -> Option { + self.registration_token.clone().map(|token| ValidToken { + token, + source: ValidTokenSource::ConfigFile, + }) + } +} + #[async_trait] impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { diff --git a/src/service/registration_tokens/mod.rs b/src/service/registration_tokens/mod.rs index 2f75aa3c..b8d4b020 100644 --- a/src/service/registration_tokens/mod.rs +++ b/src/service/registration_tokens/mod.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use conduwuit::{Err, Result, utils}; use data::Data; pub use data::{DatabaseTokenInfo, TokenExpires}; -use futures::{Stream, StreamExt, stream}; +use futures::{Stream, StreamExt}; use ruma::OwnedUserId; use crate::{Dep, config}; @@ -84,29 +84,20 @@ impl Service { (token, info) } - /// Get the registration token set in the config file, if it exists. - pub fn get_config_file_token(&self) -> Option { - self.services - .config - .registration_token - .clone() - .map(|token| ValidToken { - token, - source: ValidTokenSource::ConfigFile, - }) + /// Get all the "special" registration tokens that aren't defined in the + /// database. + fn iterate_static_tokens(&self) -> impl Iterator { + // right now this is just the config file token + self.services.config.get_config_file_token().into_iter() } /// Validate a registration token. pub async fn validate_token(&self, token: String) -> Option { - // Check the registration token in the config first - if self - .get_config_file_token() - .is_some_and(|valid_token| valid_token == *token) - { - return Some(ValidToken { - token, - source: ValidTokenSource::ConfigFile, - }); + // Check static registration tokens first + for static_token in self.iterate_static_tokens() { + if static_token == *token { + return Some(static_token); + } } // Now check the database @@ -167,6 +158,6 @@ impl Service { source: ValidTokenSource::Database(info), }); - stream::iter(self.get_config_file_token()).chain(db_tokens) + futures::stream::iter(self.iterate_static_tokens()).chain(db_tokens) } }