refactor: Add support for multiple static tokens to registration token service
This commit is contained in:
parent
7d0686f33c
commit
b2a87e2fb9
3 changed files with 28 additions and 31 deletions
|
|
@ -185,17 +185,10 @@ pub(crate) async fn register_route(
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
if is_guest
|
if is_guest && !services.config.allow_guest_registration {
|
||||||
&& (!services.config.allow_guest_registration
|
|
||||||
|| (services.config.allow_registration
|
|
||||||
&& services
|
|
||||||
.registration_tokens
|
|
||||||
.get_config_file_token()
|
|
||||||
.is_some()))
|
|
||||||
{
|
|
||||||
info!(
|
info!(
|
||||||
"Guest registration disabled / registration enabled with token configured, \
|
"Guest registration disabled, rejecting guest registration attempt, initial device \
|
||||||
rejecting guest registration attempt, initial device name: \"{}\"",
|
name: \"{}\"",
|
||||||
body.initial_device_display_name.as_deref().unwrap_or("")
|
body.initial_device_display_name.as_deref().unwrap_or("")
|
||||||
);
|
);
|
||||||
return Err!(Request(GuestAccessForbidden("Guest registration is disabled.")));
|
return Err!(Request(GuestAccessForbidden("Guest registration is disabled.")));
|
||||||
|
|
|
||||||
|
|
@ -7,12 +7,25 @@ use conduwuit::{
|
||||||
error, implement,
|
error, implement,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use crate::registration_tokens::{ValidToken, ValidTokenSource};
|
||||||
|
|
||||||
pub struct Service {
|
pub struct Service {
|
||||||
server: Arc<Server>,
|
server: Arc<Server>,
|
||||||
}
|
}
|
||||||
|
|
||||||
const SIGNAL: &str = "SIGUSR1";
|
const SIGNAL: &str = "SIGUSR1";
|
||||||
|
|
||||||
|
impl Service {
|
||||||
|
/// Get the registration token set in the config file, if it exists.
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_config_file_token(&self) -> Option<ValidToken> {
|
||||||
|
self.registration_token.clone().map(|token| ValidToken {
|
||||||
|
token,
|
||||||
|
source: ValidTokenSource::ConfigFile,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl crate::Service for Service {
|
impl crate::Service for Service {
|
||||||
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
|
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ use std::sync::Arc;
|
||||||
use conduwuit::{Err, Result, utils};
|
use conduwuit::{Err, Result, utils};
|
||||||
use data::Data;
|
use data::Data;
|
||||||
pub use data::{DatabaseTokenInfo, TokenExpires};
|
pub use data::{DatabaseTokenInfo, TokenExpires};
|
||||||
use futures::{Stream, StreamExt, stream};
|
use futures::{Stream, StreamExt};
|
||||||
use ruma::OwnedUserId;
|
use ruma::OwnedUserId;
|
||||||
|
|
||||||
use crate::{Dep, config};
|
use crate::{Dep, config};
|
||||||
|
|
@ -84,29 +84,20 @@ impl Service {
|
||||||
(token, info)
|
(token, info)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the registration token set in the config file, if it exists.
|
/// Get all the "special" registration tokens that aren't defined in the
|
||||||
pub fn get_config_file_token(&self) -> Option<ValidToken> {
|
/// database.
|
||||||
self.services
|
fn iterate_static_tokens(&self) -> impl Iterator<Item = ValidToken> {
|
||||||
.config
|
// right now this is just the config file token
|
||||||
.registration_token
|
self.services.config.get_config_file_token().into_iter()
|
||||||
.clone()
|
|
||||||
.map(|token| ValidToken {
|
|
||||||
token,
|
|
||||||
source: ValidTokenSource::ConfigFile,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Validate a registration token.
|
/// Validate a registration token.
|
||||||
pub async fn validate_token(&self, token: String) -> Option<ValidToken> {
|
pub async fn validate_token(&self, token: String) -> Option<ValidToken> {
|
||||||
// Check the registration token in the config first
|
// Check static registration tokens first
|
||||||
if self
|
for static_token in self.iterate_static_tokens() {
|
||||||
.get_config_file_token()
|
if static_token == *token {
|
||||||
.is_some_and(|valid_token| valid_token == *token)
|
return Some(static_token);
|
||||||
{
|
}
|
||||||
return Some(ValidToken {
|
|
||||||
token,
|
|
||||||
source: ValidTokenSource::ConfigFile,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now check the database
|
// Now check the database
|
||||||
|
|
@ -167,6 +158,6 @@ impl Service {
|
||||||
source: ValidTokenSource::Database(info),
|
source: ValidTokenSource::Database(info),
|
||||||
});
|
});
|
||||||
|
|
||||||
stream::iter(self.get_config_file_token()).chain(db_tokens)
|
futures::stream::iter(self.iterate_static_tokens()).chain(db_tokens)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue