diff --git a/src/config/check.rs b/src/config/check.rs new file mode 100644 index 000000000..01bc6acd6 --- /dev/null +++ b/src/config/check.rs @@ -0,0 +1,168 @@ +// not unix specific, just only for UNIX sockets stuff and *nix container checks +#[cfg(unix)] +use std::path::Path; // not unix specific, just only for UNIX sockets stuff and *nix container checks + +use tracing::{debug, error, info, warn}; + +use crate::{utils::error::Error, Config}; + +pub fn check(config: &Config) -> Result<(), Error> { + config.warn_deprecated(); + config.warn_unknown_key(); + + if config.unix_socket_path.is_some() && !cfg!(unix) { + return Err(Error::bad_config( + "UNIX socket support is only available on *nix platforms. Please remove \"unix_socket_path\" from your \ + config.", + )); + } + + if config.address.is_loopback() && cfg!(unix) { + debug!( + "Found loopback listening address {}, running checks if we're in a container.", + config.address + ); + + #[cfg(unix)] + if Path::new("/proc/vz").exists() /* Guest */ && !Path::new("/proc/bz").exists() + /* Host */ + { + error!( + "You are detected using OpenVZ with a loopback/localhost listening address of {}. If you are using \ + OpenVZ for containers and you use NAT-based networking to communicate with the host and guest, this \ + will NOT work. Please change this to \"0.0.0.0\". If this is expected, you can ignore.", + config.address + ); + } + + #[cfg(unix)] + if Path::new("/.dockerenv").exists() { + error!( + "You are detected using Docker with a loopback/localhost listening address of {}. If you are using a \ + reverse proxy on the host and require communication to conduwuit in the Docker container via \ + NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". If this is expected, \ + you can ignore.", + config.address + ); + } + + #[cfg(unix)] + if Path::new("/run/.containerenv").exists() { + error!( + "You are detected using Podman with a loopback/localhost listening address of {}. If you are using a \ + reverse proxy on the host and require communication to conduwuit in the Podman container via \ + NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". If this is expected, \ + you can ignore.", + config.address + ); + } + } + + // rocksdb does not allow max_log_files to be 0 + if config.rocksdb_max_log_files == 0 && cfg!(feature = "rocksdb") { + return Err(Error::bad_config( + "When using RocksDB, rocksdb_max_log_files cannot be 0. Please set a value at least 1.", + )); + } + + // yeah, unless the user built a debug build hopefully for local testing only + if config.server_name == "your.server.name" && !cfg!(debug_assertions) { + return Err(Error::bad_config( + "You must specify a valid server name for production usage of conduwuit.", + )); + } + + if cfg!(debug_assertions) { + info!("Note: conduwuit was built without optimisations (i.e. debug build)"); + } + + // check if the user specified a registration token as `""` + if config.registration_token == Some(String::new()) { + return Err(Error::bad_config("Registration token was specified but is empty (\"\")")); + } + + if config.max_request_size < 16384 { + return Err(Error::bad_config("Max request size is less than 16KB. Please increase it.")); + } + + // check if user specified valid IP CIDR ranges on startup + for cidr in &config.ip_range_denylist { + if let Err(e) = ipaddress::IPAddress::parse(cidr) { + error!("Error parsing specified IP CIDR range from string: {e}"); + return Err(Error::bad_config("Error parsing specified IP CIDR ranges from strings")); + } + } + + if config.allow_registration + && !config.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse + && config.registration_token.is_none() + { + return Err(Error::bad_config( + "!! You have `allow_registration` enabled without a token configured in your config which means you are \ + allowing ANYONE to register on your conduwuit instance without any 2nd-step (e.g. registration token).\n +If this is not the intended behaviour, please set a registration token with the `registration_token` config option.\n +For security and safety reasons, conduwuit will shut down. If you are extra sure this is the desired behaviour you \ + want, please set the following config option to true: +`yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse`", + )); + } + + if config.allow_registration + && config.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse + && config.registration_token.is_none() + { + warn!( + "Open registration is enabled via setting \ + `yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse` and `allow_registration` to \ + true without a registration token configured. You are expected to be aware of the risks now.\n + If this is not the desired behaviour, please set a registration token." + ); + } + + if config.allow_outgoing_presence && !config.allow_local_presence { + return Err(Error::bad_config( + "Outgoing presence requires allowing local presence. Please enable \"allow_local_presence\".", + )); + } + + if config.allow_outgoing_presence { + warn!( + "! Outgoing federated presence is not spec compliant due to relying on PDUs and EDUs combined.\nOutgoing \ + presence will not be very reliable due to this and any issues with federated outgoing presence are very \ + likely attributed to this issue.\nIncoming presence and local presence are unaffected." + ); + } + + if config + .url_preview_domain_contains_allowlist + .contains(&"*".to_owned()) + { + warn!( + "All URLs are allowed for URL previews via setting \"url_preview_domain_contains_allowlist\" to \"*\". \ + This opens up significant attack surface to your server. You are expected to be aware of the risks by \ + doing this." + ); + } + if config + .url_preview_domain_explicit_allowlist + .contains(&"*".to_owned()) + { + warn!( + "All URLs are allowed for URL previews via setting \"url_preview_domain_explicit_allowlist\" to \"*\". \ + This opens up significant attack surface to your server. You are expected to be aware of the risks by \ + doing this." + ); + } + if config + .url_preview_url_contains_allowlist + .contains(&"*".to_owned()) + { + warn!( + "All URLs are allowed for URL previews via setting \"url_preview_url_contains_allowlist\" to \"*\". This \ + opens up significant attack surface to your server. You are expected to be aware of the risks by doing \ + this." + ); + } + + Ok(()) +} diff --git a/src/config/mod.rs b/src/config/mod.rs index 730740e85..671c3a9b9 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,12 +1,18 @@ use std::{ collections::BTreeMap, fmt::{self, Write as _}, - net::{IpAddr, Ipv4Addr}, + net::{IpAddr, Ipv4Addr, SocketAddr}, path::PathBuf, }; -use either::Either; -use figment::Figment; +use either::{ + Either, + Either::{Left, Right}, +}; +use figment::{ + providers::{Env, Format, Toml}, + Figment, +}; use itertools::Itertools; use regex::RegexSet; use ruma::{OwnedRoomId, OwnedServerName, RoomVersionId}; @@ -14,7 +20,9 @@ use serde::{de::IgnoredAny, Deserialize}; use tracing::{debug, error, warn}; use self::proxy::ProxyConfig; +use crate::utils::error::Error; +mod check; mod proxy; #[derive(Deserialize, Clone, Debug)] @@ -299,6 +307,35 @@ pub struct TlsConfig { const DEPRECATED_KEYS: &[&str] = &["cache_capacity"]; impl Config { + /// Initialize config + pub fn new(path: Option) -> Result { + let raw_config = if let Some(config_file_env) = Env::var("CONDUIT_CONFIG") { + Figment::new() + .merge(Toml::file(config_file_env).nested()) + .merge(Env::prefixed("CONDUIT_").global()) + } else if let Some(config_file_arg) = path { + Figment::new() + .merge(Toml::file(config_file_arg).nested()) + .merge(Env::prefixed("CONDUIT_").global()) + } else { + Figment::new().merge(Env::prefixed("CONDUIT_").global()) + }; + + let config = match raw_config.extract::() { + Err(e) => return Err(Error::BadConfig(format!("{e}"))), + Ok(config) => config, + }; + + check::check(&config)?; + + // don't start if we're listening on both UNIX sockets and TCP at same time + if config.is_dual_listening(&raw_config) { + return Err(Error::bad_config("dual listening on UNIX and TCP sockets not allowed.")); + }; + + Ok(config) + } + /// Iterates over all the keys in the config file and warns if there is a /// deprecated key specified pub fn warn_deprecated(&self) { @@ -336,7 +373,7 @@ impl Config { /// Checks the presence of the `address` and `unix_socket_path` keys in the /// raw_config, exiting the process if both keys were detected. - pub fn is_dual_listening(&self, raw_config: &Figment) -> bool { + fn is_dual_listening(&self, raw_config: &Figment) -> bool { let check_address = raw_config.find_value("address"); let check_unix_socket = raw_config.find_value("unix_socket_path"); @@ -349,6 +386,27 @@ impl Config { false } + + #[must_use] + pub fn get_bind_addrs(&self) -> Vec { + match &self.port.ports { + Left(port) => { + // Left is only 1 value, so make a vec with 1 value only + let port_vec = [port]; + + port_vec + .iter() + .copied() + .map(|port| SocketAddr::from((self.address, *port))) + .collect::>() + }, + Right(ports) => ports + .iter() + .copied() + .map(|port| SocketAddr::from((self.address, port))) + .collect::>(), + } + } } impl fmt::Display for Config { @@ -628,7 +686,7 @@ fn default_address() -> IpAddr { Ipv4Addr::LOCALHOST.into() } fn default_port() -> ListeningPort { ListeningPort { - ports: Either::Left(8008), + ports: Left(8008), } } diff --git a/src/database/mod.rs b/src/database/mod.rs index 86f499b2e..2a8af795f 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -238,7 +238,7 @@ impl KeyValueDatabase { debug!("Database path does not exist, assuming this is a new setup and creating it"); fs::create_dir_all(&config.database_path).map_err(|e| { error!("Failed to create database path: {e}"); - Error::BadConfig( + Error::bad_config( "Database folder doesn't exists and couldn't be created (e.g. due to missing permissions). Please \ create the database folder yourself or allow conduwuit the permissions to create directories and \ files.", @@ -250,19 +250,19 @@ impl KeyValueDatabase { "sqlite" => { debug!("Got sqlite database backend"); #[cfg(not(feature = "sqlite"))] - return Err(Error::BadConfig("Database backend not found.")); + return Err(Error::bad_config("Database backend not found.")); #[cfg(feature = "sqlite")] Arc::new(Arc::::open(&config)?) }, "rocksdb" => { debug!("Got rocksdb database backend"); #[cfg(not(feature = "rocksdb"))] - return Err(Error::BadConfig("Database backend not found.")); + return Err(Error::bad_config("Database backend not found.")); #[cfg(feature = "rocksdb")] Arc::new(Arc::::open(&config)?) }, _ => { - return Err(Error::BadConfig( + return Err(Error::bad_config( "Database backend not found. sqlite (not recommended) and rocksdb are the only supported backends.", )); }, diff --git a/src/main.rs b/src/main.rs index 63b3b8bf4..80679e259 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,38 +2,27 @@ use std::fs::Permissions; // not unix specific, just only for UNIX sockets stuff and *nix container checks #[cfg(unix)] use std::os::unix::fs::PermissionsExt as _; -#[cfg(unix)] -use std::path::Path; // not unix specific, just only for UNIX sockets stuff and *nix container checks -use std::{future::Future, io, net::SocketAddr, sync::atomic, time::Duration}; +// not unix specific, just only for UNIX sockets stuff and *nix container checks +use std::{io, net::SocketAddr, sync::atomic, time::Duration}; use axum::{ - extract::{DefaultBodyLimit, FromRequestParts, MatchedPath}, + extract::{DefaultBodyLimit, MatchedPath}, response::IntoResponse, - routing::{get, on, post, MethodFilter}, Router, }; use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; #[cfg(feature = "axum_dual_protocol")] use axum_server_dual_protocol::ServerExt; -use conduit::api::{client_server, server_server}; pub use conduit::*; // Re-export everything from the library crate -use either::Either::{Left, Right}; -use figment::{ - providers::{Env, Format, Toml}, - Figment, -}; use http::{ header::{self, HeaderName}, - Method, StatusCode, Uri, + Method, StatusCode, }; #[cfg(unix)] use hyperlocal::SocketIncoming; -use ruma::api::{ - client::{ - error::{Error as RumaError, ErrorBody, ErrorKind}, - uiaa::UiaaResponse, - }, - IncomingRequest, +use ruma::api::client::{ + error::{Error as RumaError, ErrorBody, ErrorKind}, + uiaa::UiaaResponse, }; #[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))] use tikv_jemallocator::Jemalloc; @@ -51,650 +40,183 @@ use tower_http::{ use tracing::{debug, error, info, warn, Level}; use tracing_subscriber::{prelude::*, EnvFilter}; +mod routes; + #[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))] #[global_allocator] static GLOBAL: Jemalloc = Jemalloc; -fn main() { - let args = clap::parse(); - - // Initialize config - let raw_config = if let Some(config_file_env) = Env::var("CONDUIT_CONFIG") { - Figment::new() - .merge(Toml::file(config_file_env).nested()) - .merge(Env::prefixed("CONDUIT_").global()) - } else if let Some(config_file_arg) = args.config { - Figment::new() - .merge(Toml::file(config_file_arg).nested()) - .merge(Env::prefixed("CONDUIT_").global()) - } else { - Figment::new().merge(Env::prefixed("CONDUIT_").global()) - }; - - let config = match raw_config.extract::() { - Ok(s) => s, - Err(e) => { - eprintln!("It looks like your config is invalid. The following error occurred: {e}"); - return; - }, - }; - - // don't start if we're listening on both UNIX sockets and TCP at same time - if config.is_dual_listening(&raw_config) { - return; - }; +struct Server { + config: Config, - #[cfg(feature = "sentry_telemetry")] - let _guard; + runtime: tokio::runtime::Runtime, #[cfg(feature = "sentry_telemetry")] - if config.sentry { - _guard = sentry::init(( - "https://fe2eb4536aa04949e28eff3128d64757@o4506996327251968.ingest.us.sentry.io/4506996334657536", - sentry::ClientOptions { - release: sentry::release_name!(), - traces_sample_rate: config.sentry_traces_sample_rate, - server_name: if config.sentry_send_server_name { - Some(config.server_name.to_string().into()) - } else { - None - }, - ..Default::default() - }, - )); - } - - if config.allow_jaeger { - #[cfg(feature = "perf_measurements")] - { - opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new()); - let tracer = opentelemetry_jaeger::new_agent_pipeline() - .with_auto_split_batch(true) - .with_service_name("conduwuit") - .install_batch(opentelemetry_sdk::runtime::Tokio) - .unwrap(); - let telemetry = tracing_opentelemetry::layer().with_tracer(tracer); - - let filter_layer = match EnvFilter::try_new(&config.log) { - Ok(s) => s, - Err(e) => { - eprintln!("It looks like your log config is invalid. The following error occurred: {e}"); - EnvFilter::try_new("warn").unwrap() - }, - }; - - let subscriber = tracing_subscriber::Registry::default() - .with(filter_layer) - .with(telemetry); - tracing::subscriber::set_global_default(subscriber).unwrap(); - } - } else if config.tracing_flame { - #[cfg(feature = "perf_measurements")] - { - let registry = tracing_subscriber::Registry::default(); - let (flame_layer, _guard) = tracing_flame::FlameLayer::with_file("./tracing.folded").unwrap(); - let flame_layer = flame_layer.with_empty_samples(false); - - let filter_layer = EnvFilter::new("trace,h2=off"); - - let subscriber = registry.with(filter_layer).with(flame_layer); - tracing::subscriber::set_global_default(subscriber).unwrap(); - } - } else { - let registry = tracing_subscriber::Registry::default(); - let fmt_layer = tracing_subscriber::fmt::Layer::new(); - let filter_layer = match EnvFilter::try_new(&config.log) { - Ok(s) => s, - Err(e) => { - eprintln!("It looks like your config is invalid. The following error occured while parsing it: {e}"); - EnvFilter::try_new("warn").unwrap() - }, - }; - - #[cfg(feature = "sentry_telemetry")] - let sentry_layer = sentry_tracing::layer(); - - let subscriber; + _sentry_guard: Option, +} - #[allow(clippy::unnecessary_operation)] // error[E0658]: attributes on expressions are experimental - #[cfg(feature = "sentry_telemetry")] - { - subscriber = registry - .with(filter_layer) - .with(fmt_layer) - .with(sentry_layer); - }; +fn main() -> Result<(), Error> { + let args = clap::parse(); + let conduwuit: Server = init(args)?; - #[allow(clippy::unnecessary_operation)] // error[E0658]: attributes on expressions are experimental - #[cfg(not(feature = "sentry_telemetry"))] - { - subscriber = registry.with(filter_layer).with(fmt_layer); - }; + conduwuit + .runtime + .block_on(async { async_main(&conduwuit).await }) +} - tracing::subscriber::set_global_default(subscriber).unwrap(); +async fn async_main(server: &Server) -> Result<(), Error> { + if let Err(error) = start(server).await { + error!("Critical error starting server: {error}"); + return Err(Error::Error(format!("{error}"))); } - #[cfg(feature = "sentry_telemetry")] - if config.sentry { - // just notifying the user - info!("Sentry.io crash reporting and telemetry is enabled"); - } + if let Err(error) = run(server).await { + error!("Critical error running server: {error}"); + return Err(Error::Error(format!("{error}"))); + }; - if let Err(e) = check_config(&config) { - error!("Config check failed: {e}"); - return; + if let Err(error) = stop(server).await { + error!("Critical error stopping server: {error}"); + return Err(Error::Error(format!("{error}"))); } - // This is needed for opening lots of file descriptors, which tends to - // happen more often when using RocksDB and making lots of federation - // connections at startup. The soft limit is usually 1024, and the hard - // limit is usually 512000; I've personally seen it hit >2000. - // - // * https://www.freedesktop.org/software/systemd/man/systemd.exec.html#id-1.12.2.1.17.6 - // * https://github.com/systemd/systemd/commit/0abf94923b4a95a7d89bc526efc84e7ca2b71741 - #[cfg(unix)] - maximize_fd_limit().expect("Unable to increase maximum soft and hard file descriptor limit"); - - tokio::runtime::Builder::new_multi_thread() - .enable_io() - .enable_time() - .thread_name("conduwuit:worker") - .worker_threads(num_cpus::get_physical()) - .build() - .unwrap() - .block_on(async { - info!("Loading database"); - let db_load_time = std::time::Instant::now(); - if let Err(error) = KeyValueDatabase::load_or_create(config).await { - error!(?error, "The database couldn't be loaded or created"); - return; - }; - info!("Database took {:?} to load, now starting server", db_load_time.elapsed()); - - if let Err(e) = run_server().await { - error!("Critical error starting server: {e}"); - }; - }); + Ok(()) } -async fn run_server() -> io::Result<()> { - let config = &services().globals.config; - - let addrs = match &config.port.ports { - Left(port) => { - // Left is only 1 value, so make a vec with 1 value only - let port_vec = [port]; - - port_vec - .iter() - .copied() - .map(|port| SocketAddr::from((config.address, *port))) - .collect::>() - }, - Right(ports) => ports - .iter() - .copied() - .map(|port| SocketAddr::from((config.address, port))) - .collect::>(), - }; - - let x_requested_with = HeaderName::from_static("x-requested-with"); - let x_forwarded_for = HeaderName::from_static("x-forwarded-for"); - - let base_middlewares = ServiceBuilder::new(); - - #[cfg(feature = "sentry_telemetry")] - let base_middlewares = base_middlewares.layer(sentry_tower::NewSentryLayer::>::new_from_top()); +async fn run(server: &Server) -> io::Result<()> { + let app = build(server).await?; + let (tx, rx) = oneshot::channel::<()>(); + let handle = ServerHandle::new(); + tokio::spawn(shutdown(handle.clone(), tx)); - let middlewares = base_middlewares - .sensitive_headers([header::AUTHORIZATION]) - .sensitive_request_headers([x_forwarded_for].into()) - .layer(axum::middleware::from_fn(spawn_task)) - .layer( - TraceLayer::new_for_http() - .make_span_with(|request: &http::Request<_>| { - let path = if let Some(path) = request.extensions().get::() { - path.as_str() - } else { - request.uri().path() - }; - - tracing::info_span!("http_request", %path) - }) - .on_failure(DefaultOnFailure::new().level(Level::INFO)), - ) - .layer(axum::middleware::from_fn(unrecognized_method)) - .layer( - CorsLayer::new() - .allow_origin(cors::Any) - .allow_methods([ - Method::GET, - Method::HEAD, - Method::POST, - Method::PUT, - Method::DELETE, - Method::OPTIONS, - ]) - .allow_headers([ - header::ORIGIN, - x_requested_with, - header::CONTENT_TYPE, - header::ACCEPT, - header::AUTHORIZATION, - ]) - .max_age(Duration::from_secs(86400)), - ) - .layer(DefaultBodyLimit::max( - config - .max_request_size - .try_into() - .expect("failed to convert max request size"), - )); + #[cfg(unix)] + if server.config.unix_socket_path.is_some() { + return run_unix_socket_server(server, app, rx).await; + } - #[cfg(any(feature = "zstd_compresion", feature = "gzip_compression", feature = "brotli_compression"))] - let mut compression_layer = tower_http::compression::CompressionLayer::new(); + let addrs = server.config.get_bind_addrs(); + if server.config.tls.is_some() { + return run_tls_server(server, app, handle, addrs).await; + } - #[cfg(feature = "zstd_compression")] - { - if config.zstd_compression { - compression_layer = compression_layer.zstd(true); - } else { - compression_layer = compression_layer.no_zstd(); - }; - }; + let mut join_set = JoinSet::new(); + for addr in &addrs { + join_set.spawn(bind(*addr).handle(handle.clone()).serve(app.clone())); + } - #[cfg(feature = "gzip_compression")] - { - if config.gzip_compression { - compression_layer = compression_layer.gzip(true); - } else { - compression_layer = compression_layer.no_gzip(); - }; - }; + #[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental + #[cfg(feature = "systemd")] + let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); - #[cfg(feature = "brotli_compression")] - { - if config.brotli_compression { - compression_layer = compression_layer.br(true); - } else { - compression_layer = compression_layer.no_br(); - }; - }; + info!("Listening on {:?}", addrs); + join_set.join_next().await; - let app; + Ok(()) +} - #[cfg(any(feature = "zstd_compresion", feature = "gzip_compression", feature = "brotli_compression"))] - { - app = routes() - .layer(compression_layer) - .layer(middlewares) - .into_make_service(); - }; +async fn run_tls_server( + server: &Server, app: axum::routing::IntoMakeService, handle: ServerHandle, addrs: Vec, +) -> io::Result<()> { + let tls = server.config.tls.as_ref().unwrap(); - #[cfg(not(any(feature = "zstd_compresion", feature = "gzip_compression", feature = "brotli_compression")))] - { - app = routes().layer(middlewares).into_make_service(); - }; + debug!( + "Using direct TLS. Certificate path {} and certificate private key path {}", + &tls.certs, &tls.key + ); + info!( + "Note: It is strongly recommended that you use a reverse proxy instead of running conduwuit directly with TLS." + ); + let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?; - let handle = ServerHandle::new(); + if cfg!(feature = "axum_dual_protocol") { + info!( + "conduwuit was built with axum_dual_protocol feature to listen on both HTTP and HTTPS. This will only \ + take affect if `dual_protocol` is enabled in `[global.tls]`" + ); + } - #[allow(unused_variables)] // only rx is unused on non-*nix platforms - let (tx, rx) = oneshot::channel::<()>(); + let mut join_set = JoinSet::new(); - tokio::spawn(shutdown_signal(handle.clone(), tx)); - - #[allow(unused_variables)] // path is unused on non-*nix platforms - if let Some(path) = &config.unix_socket_path { - #[cfg(unix)] - { - if path.exists() { - warn!( - "UNIX socket path {:#?} already exists (unclean shutdown?), attempting to remove it.", - path.display() - ); - tokio::fs::remove_file(&path).await?; - } - - tokio::fs::create_dir_all(path.parent().unwrap()).await?; - - let socket_perms = config.unix_socket_perms.to_string(); - let octal_perms = u32::from_str_radix(&socket_perms, 8).unwrap(); - - let listener = tokio::net::UnixListener::bind(path.clone())?; - tokio::fs::set_permissions(path, Permissions::from_mode(octal_perms)) - .await - .unwrap(); - let socket = SocketIncoming::from_listener(listener); - - #[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental - #[cfg(feature = "systemd")] - let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); - - info!("Listening at {:?}", path); - let server = hyper::Server::builder(socket).serve(app); - let graceful = server.with_graceful_shutdown(async { - rx.await.ok(); - }); - - if let Err(e) = graceful.await { - error!("Server error: {:?}", e); - } + if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol { + #[cfg(feature = "axum_dual_protocol")] + for addr in &addrs { + join_set.spawn( + axum_server_dual_protocol::bind_dual_protocol(*addr, conf.clone()) + .set_upgrade(false) + .handle(handle.clone()) + .serve(app.clone()), + ); } } else { - match &config.tls { - Some(tls) => { - debug!( - "Using direct TLS. Certificate path {} and certificate private key path {}", - &tls.certs, &tls.key - ); - info!( - "Note: It is strongly recommended that you use a reverse proxy instead of running conduwuit \ - directly with TLS." - ); - let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?; - - if cfg!(feature = "axum_dual_protocol") { - info!( - "conduwuit was built with axum_dual_protocol feature to listen on both HTTP and HTTPS. This \ - will only take affect if `dual_protocol` is enabled in `[global.tls]`" - ); - } - - let mut join_set = JoinSet::new(); - - if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol { - #[cfg(feature = "axum_dual_protocol")] - for addr in &addrs { - join_set.spawn( - axum_server_dual_protocol::bind_dual_protocol(*addr, conf.clone()) - .set_upgrade(false) - .handle(handle.clone()) - .serve(app.clone()), - ); - } - } else { - for addr in &addrs { - join_set.spawn( - bind_rustls(*addr, conf.clone()) - .handle(handle.clone()) - .serve(app.clone()), - ); - } - } - - #[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental - #[cfg(feature = "systemd")] - let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); - - if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol { - warn!( - "Listening on {:?} with TLS certificate {} and supporting plain text (HTTP) connections too \ - (insecure!)", - addrs, &tls.certs - ); - } else { - info!("Listening on {:?} with TLS certificate {}", addrs, &tls.certs); - } - - join_set.join_next().await; - }, - None => { - let mut join_set = JoinSet::new(); - for addr in &addrs { - join_set.spawn(bind(*addr).handle(handle.clone()).serve(app.clone())); - } - - #[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental - #[cfg(feature = "systemd")] - let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); - - info!("Listening on {:?}", addrs); - join_set.join_next().await; - }, + for addr in &addrs { + join_set.spawn( + bind_rustls(*addr, conf.clone()) + .handle(handle.clone()) + .serve(app.clone()), + ); } } + #[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental + #[cfg(feature = "systemd")] + let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); + + if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol { + warn!( + "Listening on {:?} with TLS certificate {} and supporting plain text (HTTP) connections too (insecure!)", + addrs, &tls.certs + ); + } else { + info!("Listening on {:?} with TLS certificate {}", addrs, &tls.certs); + } + + join_set.join_next().await; + Ok(()) } -async fn spawn_task( - req: http::Request, next: axum::middleware::Next, -) -> std::result::Result { - if services().globals.shutdown.load(atomic::Ordering::Relaxed) { - return Err(StatusCode::SERVICE_UNAVAILABLE); +#[cfg(unix)] +async fn run_unix_socket_server( + server: &Server, app: axum::routing::IntoMakeService, rx: oneshot::Receiver<()>, +) -> io::Result<()> { + let path = server.config.unix_socket_path.as_ref().unwrap(); + + if path.exists() { + warn!( + "UNIX socket path {:#?} already exists (unclean shutdown?), attempting to remove it.", + path.display() + ); + tokio::fs::remove_file(&path).await?; } - tokio::spawn(next.run(req)) + + tokio::fs::create_dir_all(path.parent().unwrap()).await?; + + let socket_perms = server.config.unix_socket_perms.to_string(); + let octal_perms = u32::from_str_radix(&socket_perms, 8).unwrap(); + + let listener = tokio::net::UnixListener::bind(path.clone())?; + tokio::fs::set_permissions(path, Permissions::from_mode(octal_perms)) .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) -} + .unwrap(); + let socket = SocketIncoming::from_listener(listener); -async fn unrecognized_method( - req: http::Request, next: axum::middleware::Next, -) -> std::result::Result { - let method = req.method().clone(); - let uri = req.uri().clone(); - let inner = next.run(req).await; - if inner.status() == StatusCode::METHOD_NOT_ALLOWED { - if uri.path().contains("_matrix/") { - warn!("Method not allowed: {method} {uri}"); - } else { - info!("Method not allowed: {method} {uri}"); - } - return Ok(RumaResponse(UiaaResponse::MatrixError(RumaError { - body: ErrorBody::Standard { - kind: ErrorKind::Unrecognized, - message: "M_UNRECOGNIZED: Method not allowed for endpoint".to_owned(), - }, - status_code: StatusCode::METHOD_NOT_ALLOWED, - })) - .into_response()); + #[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental + #[cfg(feature = "systemd")] + let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); + info!("Listening at {:?}", path); + let server = hyper::Server::builder(socket).serve(app); + let graceful = server.with_graceful_shutdown(async { + rx.await.ok(); + }); + + if let Err(e) = graceful.await { + error!("Server error: {:?}", e); } - Ok(inner) -} -fn routes() -> Router { - Router::new() - .ruma_route(client_server::get_supported_versions_route) - .ruma_route(client_server::get_register_available_route) - .ruma_route(client_server::register_route) - .ruma_route(client_server::get_login_types_route) - .ruma_route(client_server::login_route) - .ruma_route(client_server::whoami_route) - .ruma_route(client_server::logout_route) - .ruma_route(client_server::logout_all_route) - .ruma_route(client_server::change_password_route) - .ruma_route(client_server::deactivate_route) - .ruma_route(client_server::third_party_route) - .ruma_route(client_server::request_3pid_management_token_via_email_route) - .ruma_route(client_server::request_3pid_management_token_via_msisdn_route) - .ruma_route(client_server::get_capabilities_route) - .ruma_route(client_server::get_pushrules_all_route) - .ruma_route(client_server::set_pushrule_route) - .ruma_route(client_server::get_pushrule_route) - .ruma_route(client_server::set_pushrule_enabled_route) - .ruma_route(client_server::get_pushrule_enabled_route) - .ruma_route(client_server::get_pushrule_actions_route) - .ruma_route(client_server::set_pushrule_actions_route) - .ruma_route(client_server::delete_pushrule_route) - .ruma_route(client_server::get_room_event_route) - .ruma_route(client_server::get_room_aliases_route) - .ruma_route(client_server::get_filter_route) - .ruma_route(client_server::create_filter_route) - .ruma_route(client_server::set_global_account_data_route) - .ruma_route(client_server::set_room_account_data_route) - .ruma_route(client_server::get_global_account_data_route) - .ruma_route(client_server::get_room_account_data_route) - .ruma_route(client_server::set_displayname_route) - .ruma_route(client_server::get_displayname_route) - .ruma_route(client_server::set_avatar_url_route) - .ruma_route(client_server::get_avatar_url_route) - .ruma_route(client_server::get_profile_route) - .ruma_route(client_server::set_presence_route) - .ruma_route(client_server::get_presence_route) - .ruma_route(client_server::upload_keys_route) - .ruma_route(client_server::get_keys_route) - .ruma_route(client_server::claim_keys_route) - .ruma_route(client_server::create_backup_version_route) - .ruma_route(client_server::update_backup_version_route) - .ruma_route(client_server::delete_backup_version_route) - .ruma_route(client_server::get_latest_backup_info_route) - .ruma_route(client_server::get_backup_info_route) - .ruma_route(client_server::add_backup_keys_route) - .ruma_route(client_server::add_backup_keys_for_room_route) - .ruma_route(client_server::add_backup_keys_for_session_route) - .ruma_route(client_server::delete_backup_keys_for_room_route) - .ruma_route(client_server::delete_backup_keys_for_session_route) - .ruma_route(client_server::delete_backup_keys_route) - .ruma_route(client_server::get_backup_keys_for_room_route) - .ruma_route(client_server::get_backup_keys_for_session_route) - .ruma_route(client_server::get_backup_keys_route) - .ruma_route(client_server::set_read_marker_route) - .ruma_route(client_server::create_receipt_route) - .ruma_route(client_server::create_typing_event_route) - .ruma_route(client_server::create_room_route) - .ruma_route(client_server::redact_event_route) - .ruma_route(client_server::report_event_route) - .ruma_route(client_server::create_alias_route) - .ruma_route(client_server::delete_alias_route) - .ruma_route(client_server::get_alias_route) - .ruma_route(client_server::join_room_by_id_route) - .ruma_route(client_server::join_room_by_id_or_alias_route) - .ruma_route(client_server::joined_members_route) - .ruma_route(client_server::leave_room_route) - .ruma_route(client_server::forget_room_route) - .ruma_route(client_server::joined_rooms_route) - .ruma_route(client_server::kick_user_route) - .ruma_route(client_server::ban_user_route) - .ruma_route(client_server::unban_user_route) - .ruma_route(client_server::invite_user_route) - .ruma_route(client_server::set_room_visibility_route) - .ruma_route(client_server::get_room_visibility_route) - .ruma_route(client_server::get_public_rooms_route) - .ruma_route(client_server::get_public_rooms_filtered_route) - .ruma_route(client_server::search_users_route) - .ruma_route(client_server::get_member_events_route) - .ruma_route(client_server::get_protocols_route) - .ruma_route(client_server::send_message_event_route) - .ruma_route(client_server::send_state_event_for_key_route) - .ruma_route(client_server::get_state_events_route) - .ruma_route(client_server::get_state_events_for_key_route) - // Ruma doesn't have support for multiple paths for a single endpoint yet, and these routes - // share one Ruma request / response type pair with {get,send}_state_event_for_key_route - .route( - "/_matrix/client/r0/rooms/:room_id/state/:event_type", - get(client_server::get_state_events_for_empty_key_route) - .put(client_server::send_state_event_for_empty_key_route), - ) - .route( - "/_matrix/client/v3/rooms/:room_id/state/:event_type", - get(client_server::get_state_events_for_empty_key_route) - .put(client_server::send_state_event_for_empty_key_route), - ) - // These two endpoints allow trailing slashes - .route( - "/_matrix/client/r0/rooms/:room_id/state/:event_type/", - get(client_server::get_state_events_for_empty_key_route) - .put(client_server::send_state_event_for_empty_key_route), - ) - .route( - "/_matrix/client/v3/rooms/:room_id/state/:event_type/", - get(client_server::get_state_events_for_empty_key_route) - .put(client_server::send_state_event_for_empty_key_route), - ) - .ruma_route(client_server::sync_events_route) - .ruma_route(client_server::sync_events_v4_route) - .ruma_route(client_server::get_context_route) - .ruma_route(client_server::get_message_events_route) - .ruma_route(client_server::search_events_route) - .ruma_route(client_server::turn_server_route) - .ruma_route(client_server::send_event_to_device_route) - .ruma_route(client_server::get_media_config_route) - .ruma_route(client_server::get_media_preview_route) - .ruma_route(client_server::create_content_route) - // legacy v1 media routes - .route( - "/_matrix/media/v1/preview_url", - get(client_server::get_media_preview_v1_route) - ) - .route( - "/_matrix/media/v1/config", - get(client_server::get_media_config_v1_route) - ) - .route( - "/_matrix/media/v1/upload", - post(client_server::create_content_v1_route) - ) - .route( - "/_matrix/media/v1/download/:server_name/:media_id", - get(client_server::get_content_v1_route) - ) - .route( - "/_matrix/media/v1/download/:server_name/:media_id/:file_name", - get(client_server::get_content_as_filename_v1_route) - ) - .route( - "/_matrix/media/v1/thumbnail/:server_name/:media_id", - get(client_server::get_content_thumbnail_v1_route) - ) - .ruma_route(client_server::get_content_route) - .ruma_route(client_server::get_content_as_filename_route) - .ruma_route(client_server::get_content_thumbnail_route) - .ruma_route(client_server::get_devices_route) - .ruma_route(client_server::get_device_route) - .ruma_route(client_server::update_device_route) - .ruma_route(client_server::delete_device_route) - .ruma_route(client_server::delete_devices_route) - .ruma_route(client_server::get_tags_route) - .ruma_route(client_server::update_tag_route) - .ruma_route(client_server::delete_tag_route) - .ruma_route(client_server::upload_signing_keys_route) - .ruma_route(client_server::upload_signatures_route) - .ruma_route(client_server::get_key_changes_route) - .ruma_route(client_server::get_pushers_route) - .ruma_route(client_server::set_pushers_route) - // .ruma_route(client_server::third_party_route) - .ruma_route(client_server::upgrade_room_route) - .ruma_route(client_server::get_threads_route) - .ruma_route(client_server::get_relating_events_with_rel_type_and_event_type_route) - .ruma_route(client_server::get_relating_events_with_rel_type_route) - .ruma_route(client_server::get_relating_events_route) - .ruma_route(client_server::get_hierarchy_route) - .ruma_route(server_server::get_server_version_route) - .route("/_matrix/key/v2/server", get(server_server::get_server_keys_route)) - .route( - "/_matrix/key/v2/server/:key_id", - get(server_server::get_server_keys_deprecated_route), - ) - .ruma_route(server_server::get_public_rooms_route) - .ruma_route(server_server::get_public_rooms_filtered_route) - .ruma_route(server_server::send_transaction_message_route) - .ruma_route(server_server::get_event_route) - .ruma_route(server_server::get_backfill_route) - .ruma_route(server_server::get_missing_events_route) - .ruma_route(server_server::get_event_authorization_route) - .ruma_route(server_server::get_room_state_route) - .ruma_route(server_server::get_room_state_ids_route) - .ruma_route(server_server::create_join_event_template_route) - .ruma_route(server_server::create_join_event_v1_route) - .ruma_route(server_server::create_join_event_v2_route) - .ruma_route(server_server::create_invite_route) - .ruma_route(server_server::get_devices_route) - .ruma_route(server_server::get_room_information_route) - .ruma_route(server_server::get_profile_information_route) - .ruma_route(server_server::get_keys_route) - .ruma_route(server_server::claim_keys_route) - .ruma_route(server_server::get_hierarchy_route) - .route("/_conduwuit/server_version", get(client_server::conduwuit_server_version)) - .route("/_matrix/client/r0/rooms/:room_id/initialSync", get(initial_sync)) - .route("/_matrix/client/v3/rooms/:room_id/initialSync", get(initial_sync)) - .route("/client/server.json", get(client_server::syncv3_client_server_json)) - .route("/.well-known/matrix/client", get(client_server::well_known_client_route)) - .route("/.well-known/matrix/server", get(server_server::well_known_server_route)) - .route("/", get(it_works)) - .fallback(not_found) + Ok(()) } -async fn shutdown_signal(handle: ServerHandle, tx: Sender<()>) -> Result<()> { +async fn shutdown(handle: ServerHandle, tx: Sender<()>) -> Result<()> { let ctrl_c = async { signal::ctrl_c() .await @@ -710,22 +232,18 @@ async fn shutdown_signal(handle: ServerHandle, tx: Sender<()>) -> Result<()> { }; let sig: &str; - #[cfg(unix)] tokio::select! { () = ctrl_c => { sig = "Ctrl+C"; }, () = terminate => { sig = "SIGTERM"; }, } - #[cfg(not(unix))] tokio::select! { _ = ctrl_c => { sig = "Ctrl+C"; }, } warn!("Received {}, shutting down...", sig); - let shutdown_time_elapsed = tokio::time::Instant::now(); handle.graceful_shutdown(Some(Duration::from_secs(180))); - services().globals.shutdown(); #[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental @@ -737,316 +255,319 @@ async fn shutdown_signal(handle: ServerHandle, tx: Sender<()>) -> Result<()> { system may not be in an okay/ideal state.)", ); - if shutdown_time_elapsed.elapsed() >= Duration::from_secs(60) && cfg!(feature = "systemd") { - warn!( - "Still shutting down after 60 seconds since receiving shutdown signal, asking systemd for more time (+120 \ - seconds). Remaining connections: {}", - handle.connection_count() - ); - - #[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental - #[cfg(feature = "systemd")] - let _ = sd_notify::notify(true, &[sd_notify::NotifyState::ExtendTimeoutUsec(120)]); - } - - warn!("Time took to shutdown: {:?} seconds", shutdown_time_elapsed.elapsed()); - Ok(()) } -async fn not_found(uri: Uri) -> impl IntoResponse { - if uri.path().contains("_matrix/") { - warn!("Not found: {uri}"); - } else { - info!("Not found: {uri}"); - } +async fn stop(_server: &Server) -> io::Result<()> { + info!("Shutdown complete."); - Error::BadRequest(ErrorKind::Unrecognized, "Unrecognized request") + Ok(()) } -async fn initial_sync(_uri: Uri) -> impl IntoResponse { - Error::BadRequest(ErrorKind::GuestAccessForbidden, "Guest access not implemented") +/// Async initializations +async fn start(server: &Server) -> Result<(), Error> { + let db_load_time = std::time::Instant::now(); + KeyValueDatabase::load_or_create(server.config.clone()).await?; + info!("Database took {:?} to load", db_load_time.elapsed()); + + Ok(()) } -async fn it_works() -> &'static str { "hewwo from conduwuit woof!" } +async fn build(server: &Server) -> io::Result> { + let base_middlewares = ServiceBuilder::new(); + #[cfg(feature = "sentry_telemetry")] + let base_middlewares = base_middlewares.layer(sentry_tower::NewSentryLayer::>::new_from_top()); -trait RouterExt { - fn ruma_route(self, handler: H) -> Self - where - H: RumaHandler, - T: 'static; -} + let x_forwarded_for = HeaderName::from_static("x-forwarded-for"); + let middlewares = base_middlewares + .sensitive_headers([header::AUTHORIZATION]) + .sensitive_request_headers([x_forwarded_for].into()) + .layer(axum::middleware::from_fn(request_spawn)) + .layer( + TraceLayer::new_for_http() + .make_span_with(tracing_span::<_>) + .on_failure(DefaultOnFailure::new().level(Level::INFO)), + ) + .layer(axum::middleware::from_fn(request_handler)) + .layer(cors_layer(server)) + .layer(DefaultBodyLimit::max( + server + .config + .max_request_size + .try_into() + .expect("failed to convert max request size"), + )); -impl RouterExt for Router { - fn ruma_route(self, handler: H) -> Self - where - H: RumaHandler, - T: 'static, + #[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))] + { + Ok(routes::routes() + .layer(compression_layer(server)) + .layer(middlewares) + .into_make_service()) + } + #[cfg(not(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression")))] { - handler.add_to_router(self) + Ok(routes::routes().layer(middlewares).into_make_service()) } } -pub trait RumaHandler { - // Can't transform to a handler without boxing or relying on the nightly-only - // impl-trait-in-traits feature. Moving a small amount of extra logic into the - // trait allows bypassing both. - fn add_to_router(self, router: Router) -> Router; -} - -macro_rules! impl_ruma_handler { - ( $($ty:ident),* $(,)? ) => { - #[axum::async_trait] - #[allow(non_snake_case)] - impl RumaHandler<($($ty,)* Ruma,)> for F - where - Req: IncomingRequest + Send + 'static, - F: FnOnce($($ty,)* Ruma) -> Fut + Clone + Send + 'static, - Fut: Future> - + Send, - E: IntoResponse, - $( $ty: FromRequestParts<()> + Send + 'static, )* - { - fn add_to_router(self, mut router: Router) -> Router { - let meta = Req::METADATA; - let method_filter = method_to_filter(meta.method); - - for path in meta.history.all_paths() { - let handler = self.clone(); - - router = router.route(path, on(method_filter, |$( $ty: $ty, )* req| async move { - handler($($ty,)* req).await.map(RumaResponse) - })) - } - - router - } - } - }; +async fn request_spawn( + req: http::Request, next: axum::middleware::Next, +) -> std::result::Result { + if services().globals.shutdown.load(atomic::Ordering::Relaxed) { + return Err(StatusCode::SERVICE_UNAVAILABLE); + } + tokio::spawn(next.run(req)) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } -impl_ruma_handler!(); -impl_ruma_handler!(T1); -impl_ruma_handler!(T1, T2); -impl_ruma_handler!(T1, T2, T3); -impl_ruma_handler!(T1, T2, T3, T4); -impl_ruma_handler!(T1, T2, T3, T4, T5); -impl_ruma_handler!(T1, T2, T3, T4, T5, T6); -impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7); -impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7, T8); - -fn method_to_filter(method: Method) -> MethodFilter { - match method { - Method::DELETE => MethodFilter::DELETE, - Method::GET => MethodFilter::GET, - Method::HEAD => MethodFilter::HEAD, - Method::OPTIONS => MethodFilter::OPTIONS, - Method::PATCH => MethodFilter::PATCH, - Method::POST => MethodFilter::POST, - Method::PUT => MethodFilter::PUT, - Method::TRACE => MethodFilter::TRACE, - m => panic!("Unsupported HTTP method: {m:?}"), +async fn request_handler( + req: http::Request, next: axum::middleware::Next, +) -> std::result::Result { + let method = req.method().clone(); + let uri = req.uri().clone(); + let inner = next.run(req).await; + if inner.status() == StatusCode::METHOD_NOT_ALLOWED { + if uri.path().contains("_matrix/") { + warn!("Method not allowed: {method} {uri}"); + } else { + info!("Method not allowed: {method} {uri}"); + } + return Ok(RumaResponse(UiaaResponse::MatrixError(RumaError { + body: ErrorBody::Standard { + kind: ErrorKind::Unrecognized, + message: "M_UNRECOGNIZED: Method not allowed for endpoint".to_owned(), + }, + status_code: StatusCode::METHOD_NOT_ALLOWED, + })) + .into_response()); } -} -#[cfg(unix)] -#[tracing::instrument(err)] -fn maximize_fd_limit() -> Result<(), nix::errno::Errno> { - use nix::sys::resource::{getrlimit, setrlimit, Resource}; + Ok(inner) +} - let res = Resource::RLIMIT_NOFILE; +fn cors_layer(_server: &Server) -> CorsLayer { + let methods = [ + Method::GET, + Method::HEAD, + Method::POST, + Method::PUT, + Method::DELETE, + Method::OPTIONS, + ]; + + let headers = [ + header::ORIGIN, + HeaderName::from_static("x-requested-with"), + header::CONTENT_TYPE, + header::ACCEPT, + header::AUTHORIZATION, + ]; + + CorsLayer::new() + .allow_origin(cors::Any) + .allow_methods(methods) + .allow_headers(headers) + .max_age(Duration::from_secs(86400)) +} - let (soft_limit, hard_limit) = getrlimit(res)?; +#[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))] +fn compression_layer(server: &Server) -> tower_http::compression::CompressionLayer { + let mut compression_layer = tower_http::compression::CompressionLayer::new(); - debug!("Current nofile soft limit: {soft_limit}"); + #[cfg(feature = "zstd_compression")] + { + if server.config.zstd_compression { + compression_layer = compression_layer.zstd(true); + } else { + compression_layer = compression_layer.no_zstd(); + }; + }; - setrlimit(res, hard_limit, hard_limit)?; + #[cfg(feature = "gzip_compression")] + { + if server.config.gzip_compression { + compression_layer = compression_layer.gzip(true); + } else { + compression_layer = compression_layer.no_gzip(); + }; + }; - debug!("Increased nofile soft limit to {hard_limit}"); + #[cfg(feature = "brotli_compression")] + { + if server.config.brotli_compression { + compression_layer = compression_layer.br(true); + } else { + compression_layer = compression_layer.no_br(); + }; + }; - Ok(()) + compression_layer } -fn check_config(config: &Config) -> Result<()> { - config.warn_deprecated(); - config.warn_unknown_key(); - - if config.unix_socket_path.is_some() && !cfg!(unix) { - return Err(Error::bad_config( - "UNIX socket support is only available on *nix platforms. Please remove \"unix_socket_path\" from your \ - config.", - )); - } - - if config.address.is_loopback() && cfg!(unix) { - debug!( - "Found loopback listening address {}, running checks if we're in a container.", - config.address - ); +fn tracing_span(request: &http::Request) -> tracing::Span { + let path = if let Some(path) = request.extensions().get::() { + path.as_str() + } else { + request.uri().path() + }; - #[cfg(unix)] - if Path::new("/proc/vz").exists() /* Guest */ && !Path::new("/proc/bz").exists() - /* Host */ - { - error!( - "You are detected using OpenVZ with a loopback/localhost listening address of {}. If you are using \ - OpenVZ for containers and you use NAT-based networking to communicate with the host and guest, this \ - will NOT work. Please change this to \"0.0.0.0\". If this is expected, you can ignore.", - config.address - ); - } + tracing::info_span!("handle", %path) +} - #[cfg(unix)] - if Path::new("/.dockerenv").exists() { - error!( - "You are detected using Docker with a loopback/localhost listening address of {}. If you are using a \ - reverse proxy on the host and require communication to conduwuit in the Docker container via \ - NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". If this is expected, \ - you can ignore.", - config.address - ); - } +/// Non-async initializations +fn init(args: clap::Args) -> Result { + let config = Config::new(args.config)?; - #[cfg(unix)] - if Path::new("/run/.containerenv").exists() { - error!( - "You are detected using Podman with a loopback/localhost listening address of {}. If you are using a \ - reverse proxy on the host and require communication to conduwuit in the Podman container via \ - NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". If this is expected, \ - you can ignore.", - config.address - ); - } - } + #[cfg(feature = "sentry_telemetry")] + let sentry_guard = if config.sentry { + Some(init_sentry(&config)) + } else { + None + }; - // rocksdb does not allow max_log_files to be 0 - if config.rocksdb_max_log_files == 0 && cfg!(feature = "rocksdb") { - return Err(Error::bad_config( - "When using RocksDB, rocksdb_max_log_files cannot be 0. Please set a value at least 1.", - )); + if config.allow_jaeger { + #[cfg(feature = "perf_measurements")] + init_tracing_jaeger(&config); + } else if config.tracing_flame { + #[cfg(feature = "perf_measurements")] + init_tracing_flame(&config); + } else { + init_tracing_sub(&config); } - // yeah, unless the user built a debug build hopefully for local testing only - if config.server_name == "your.server.name" && !cfg!(debug_assertions) { - return Err(Error::bad_config( - "You must specify a valid server name for production usage of conduwuit.", - )); - } + info!( + server_name = ?config.server_name, + database_path = ?config.database_path, + log_levels = ?config.log, + "{}", + env!("CARGO_PKG_VERSION"), + ); - if cfg!(debug_assertions) { - info!("Note: conduwuit was built without optimisations (i.e. debug build)"); - } + #[cfg(unix)] + maximize_fd_limit().expect("Unable to increase maximum soft and hard file descriptor limit"); - // check if the user specified a registration token as `""` - if config.registration_token == Some(String::new()) { - return Err(Error::bad_config("Registration token was specified but is empty (\"\")")); - } + Ok(Server { + config, - if config.max_request_size < 16384 { - return Err(Error::bad_config("Max request size is less than 16KB. Please increase it.")); - } + runtime: tokio::runtime::Builder::new_multi_thread() + .enable_io() + .enable_time() + .thread_name("conduwuit:worker") + .worker_threads(num_cpus::get_physical()) + .build() + .unwrap(), - // check if user specified valid IP CIDR ranges on startup - for cidr in &config.ip_range_denylist { - if let Err(e) = ipaddress::IPAddress::parse(cidr) { - error!("Error parsing specified IP CIDR range from string: {e}"); - return Err(Error::bad_config("Error parsing specified IP CIDR ranges from strings")); - } - } + #[cfg(feature = "sentry_telemetry")] + _sentry_guard: sentry_guard, + }) +} - if config.allow_registration - && !config.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse - && config.registration_token.is_none() - { - return Err(Error::bad_config( - "!! You have `allow_registration` enabled without a token configured in your config which means you are \ - allowing ANYONE to register on your conduwuit instance without any 2nd-step (e.g. registration token).\n -If this is not the intended behaviour, please set a registration token with the `registration_token` config option.\n -For security and safety reasons, conduwuit will shut down. If you are extra sure this is the desired behaviour you \ - want, please set the following config option to true: -`yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse`", - )); - } +#[cfg(feature = "sentry_telemetry")] +fn init_sentry(config: &Config) -> sentry::ClientInitGuard { + sentry::init(( + "https://fe2eb4536aa04949e28eff3128d64757@o4506996327251968.ingest.us.sentry.io/4506996334657536", + sentry::ClientOptions { + release: sentry::release_name!(), + traces_sample_rate: config.sentry_traces_sample_rate, + server_name: if config.sentry_send_server_name { + Some(config.server_name.to_string().into()) + } else { + None + }, + ..Default::default() + }, + )) +} - if config.allow_registration - && config.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse - && config.registration_token.is_none() - { - warn!( - "Open registration is enabled via setting \ - `yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse` and `allow_registration` to \ - true without a registration token configured. You are expected to be aware of the risks now.\n - If this is not the desired behaviour, please set a registration token." - ); - } +fn init_tracing_sub(config: &Config) { + let registry = tracing_subscriber::Registry::default(); + let fmt_layer = tracing_subscriber::fmt::Layer::new(); + let filter_layer = match EnvFilter::try_new(&config.log) { + Ok(s) => s, + Err(e) => { + eprintln!("It looks like your config is invalid. The following error occured while parsing it: {e}"); + EnvFilter::try_new("warn").unwrap() + }, + }; - if config.allow_outgoing_presence && !config.allow_local_presence { - return Err(Error::bad_config( - "Outgoing presence requires allowing local presence. Please enable \"allow_local_presence\".", - )); - } + #[cfg(feature = "sentry_telemetry")] + let sentry_layer = sentry_tracing::layer(); - if config.allow_outgoing_presence { - warn!( - "! Outgoing federated presence is not spec compliant due to relying on PDUs and EDUs combined.\nOutgoing \ - presence will not be very reliable due to this and any issues with federated outgoing presence are very \ - likely attributed to this issue.\nIncoming presence and local presence are unaffected." - ); - } + let subscriber; - if config - .url_preview_domain_contains_allowlist - .contains(&"*".to_owned()) - { - warn!( - "All URLs are allowed for URL previews via setting \"url_preview_domain_contains_allowlist\" to \"*\". \ - This opens up significant attack surface to your server. You are expected to be aware of the risks by \ - doing this." - ); - } - if config - .url_preview_domain_explicit_allowlist - .contains(&"*".to_owned()) + #[allow(clippy::unnecessary_operation)] // error[E0658]: attributes on expressions are experimental + #[cfg(feature = "sentry_telemetry")] { - warn!( - "All URLs are allowed for URL previews via setting \"url_preview_domain_explicit_allowlist\" to \"*\". \ - This opens up significant attack surface to your server. You are expected to be aware of the risks by \ - doing this." - ); - } - if config - .url_preview_url_contains_allowlist - .contains(&"*".to_owned()) + subscriber = registry + .with(filter_layer) + .with(fmt_layer) + .with(sentry_layer); + }; + + #[allow(clippy::unnecessary_operation)] // error[E0658]: attributes on expressions are experimental + #[cfg(not(feature = "sentry_telemetry"))] { - warn!( - "All URLs are allowed for URL previews via setting \"url_preview_url_contains_allowlist\" to \"*\". This \ - opens up significant attack surface to your server. You are expected to be aware of the risks by doing \ - this." - ); - } + subscriber = registry.with(filter_layer).with(fmt_layer); + }; - Ok(()) + tracing::subscriber::set_global_default(subscriber).unwrap(); } -#[cfg(test)] -mod test { - use super::*; +#[cfg(feature = "perf_measurements")] +fn init_tracing_jaeger(config: &Config) { + opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new()); + let tracer = opentelemetry_jaeger::new_agent_pipeline() + .with_auto_split_batch(true) + .with_service_name("conduwuit") + .install_batch(opentelemetry_sdk::runtime::Tokio) + .unwrap(); + let telemetry = tracing_opentelemetry::layer().with_tracer(tracer); + + let filter_layer = match EnvFilter::try_new(&config.log) { + Ok(s) => s, + Err(e) => { + eprintln!("It looks like your log config is invalid. The following error occurred: {e}"); + EnvFilter::try_new("warn").unwrap() + }, + }; - #[cfg(unix)] - #[test] - /// Tests if `maximize_fd_limit()` actually raised the soft limit to the - /// hard limit - fn maximize_fd_limit_raises_limit() { - use nix::sys::resource::{getrlimit, Resource}; + let subscriber = tracing_subscriber::Registry::default() + .with(filter_layer) + .with(telemetry); + tracing::subscriber::set_global_default(subscriber).unwrap(); +} - let res = Resource::RLIMIT_NOFILE; +#[cfg(feature = "perf_measurements")] +fn init_tracing_flame(_config: &Config) { + let registry = tracing_subscriber::Registry::default(); + let (flame_layer, _guard) = tracing_flame::FlameLayer::with_file("./tracing.folded").unwrap(); + let flame_layer = flame_layer.with_empty_samples(false); - let (_, hard_limit) = getrlimit(res).unwrap(); + let filter_layer = EnvFilter::new("trace,h2=off"); - maximize_fd_limit().unwrap(); + let subscriber = registry.with(filter_layer).with(flame_layer); + tracing::subscriber::set_global_default(subscriber).unwrap(); +} - let (soft_limit, _) = getrlimit(res).unwrap(); +// This is needed for opening lots of file descriptors, which tends to +// happen more often when using RocksDB and making lots of federation +// connections at startup. The soft limit is usually 1024, and the hard +// limit is usually 512000; I've personally seen it hit >2000. +// +// * https://www.freedesktop.org/software/systemd/man/systemd.exec.html#id-1.12.2.1.17.6 +// * https://github.com/systemd/systemd/commit/0abf94923b4a95a7d89bc526efc84e7ca2b71741 +#[cfg(unix)] +fn maximize_fd_limit() -> Result<(), nix::errno::Errno> { + use nix::sys::resource::{getrlimit, setrlimit, Resource::RLIMIT_NOFILE as NOFILE}; - assert_eq!(soft_limit, hard_limit); + let (soft_limit, hard_limit) = getrlimit(NOFILE)?; + if soft_limit < hard_limit { + setrlimit(NOFILE, hard_limit, hard_limit)?; + assert_eq!((hard_limit, hard_limit), getrlimit(NOFILE)?, "getrlimit != setrlimit"); + debug!(to = hard_limit, from = soft_limit, "Raised RLIMIT_NOFILE",); } + + Ok(()) } diff --git a/src/routes.rs b/src/routes.rs new file mode 100644 index 000000000..7f325c852 --- /dev/null +++ b/src/routes.rs @@ -0,0 +1,312 @@ +use std::future::Future; + +use axum::{ + extract::FromRequestParts, + response::IntoResponse, + routing::{get, on, post, MethodFilter}, + Router, +}; +use conduit::{ + api::{client_server, server_server}, + *, +}; +use http::{Method, Uri}; +use ruma::api::{client::error::ErrorKind, IncomingRequest}; +use tracing::{info, warn}; + +pub fn routes() -> Router { + Router::new() + .ruma_route(client_server::get_supported_versions_route) + .ruma_route(client_server::get_register_available_route) + .ruma_route(client_server::register_route) + .ruma_route(client_server::get_login_types_route) + .ruma_route(client_server::login_route) + .ruma_route(client_server::whoami_route) + .ruma_route(client_server::logout_route) + .ruma_route(client_server::logout_all_route) + .ruma_route(client_server::change_password_route) + .ruma_route(client_server::deactivate_route) + .ruma_route(client_server::third_party_route) + .ruma_route(client_server::request_3pid_management_token_via_email_route) + .ruma_route(client_server::request_3pid_management_token_via_msisdn_route) + .ruma_route(client_server::get_capabilities_route) + .ruma_route(client_server::get_pushrules_all_route) + .ruma_route(client_server::set_pushrule_route) + .ruma_route(client_server::get_pushrule_route) + .ruma_route(client_server::set_pushrule_enabled_route) + .ruma_route(client_server::get_pushrule_enabled_route) + .ruma_route(client_server::get_pushrule_actions_route) + .ruma_route(client_server::set_pushrule_actions_route) + .ruma_route(client_server::delete_pushrule_route) + .ruma_route(client_server::get_room_event_route) + .ruma_route(client_server::get_room_aliases_route) + .ruma_route(client_server::get_filter_route) + .ruma_route(client_server::create_filter_route) + .ruma_route(client_server::set_global_account_data_route) + .ruma_route(client_server::set_room_account_data_route) + .ruma_route(client_server::get_global_account_data_route) + .ruma_route(client_server::get_room_account_data_route) + .ruma_route(client_server::set_displayname_route) + .ruma_route(client_server::get_displayname_route) + .ruma_route(client_server::set_avatar_url_route) + .ruma_route(client_server::get_avatar_url_route) + .ruma_route(client_server::get_profile_route) + .ruma_route(client_server::set_presence_route) + .ruma_route(client_server::get_presence_route) + .ruma_route(client_server::upload_keys_route) + .ruma_route(client_server::get_keys_route) + .ruma_route(client_server::claim_keys_route) + .ruma_route(client_server::create_backup_version_route) + .ruma_route(client_server::update_backup_version_route) + .ruma_route(client_server::delete_backup_version_route) + .ruma_route(client_server::get_latest_backup_info_route) + .ruma_route(client_server::get_backup_info_route) + .ruma_route(client_server::add_backup_keys_route) + .ruma_route(client_server::add_backup_keys_for_room_route) + .ruma_route(client_server::add_backup_keys_for_session_route) + .ruma_route(client_server::delete_backup_keys_for_room_route) + .ruma_route(client_server::delete_backup_keys_for_session_route) + .ruma_route(client_server::delete_backup_keys_route) + .ruma_route(client_server::get_backup_keys_for_room_route) + .ruma_route(client_server::get_backup_keys_for_session_route) + .ruma_route(client_server::get_backup_keys_route) + .ruma_route(client_server::set_read_marker_route) + .ruma_route(client_server::create_receipt_route) + .ruma_route(client_server::create_typing_event_route) + .ruma_route(client_server::create_room_route) + .ruma_route(client_server::redact_event_route) + .ruma_route(client_server::report_event_route) + .ruma_route(client_server::create_alias_route) + .ruma_route(client_server::delete_alias_route) + .ruma_route(client_server::get_alias_route) + .ruma_route(client_server::join_room_by_id_route) + .ruma_route(client_server::join_room_by_id_or_alias_route) + .ruma_route(client_server::joined_members_route) + .ruma_route(client_server::leave_room_route) + .ruma_route(client_server::forget_room_route) + .ruma_route(client_server::joined_rooms_route) + .ruma_route(client_server::kick_user_route) + .ruma_route(client_server::ban_user_route) + .ruma_route(client_server::unban_user_route) + .ruma_route(client_server::invite_user_route) + .ruma_route(client_server::set_room_visibility_route) + .ruma_route(client_server::get_room_visibility_route) + .ruma_route(client_server::get_public_rooms_route) + .ruma_route(client_server::get_public_rooms_filtered_route) + .ruma_route(client_server::search_users_route) + .ruma_route(client_server::get_member_events_route) + .ruma_route(client_server::get_protocols_route) + .ruma_route(client_server::send_message_event_route) + .ruma_route(client_server::send_state_event_for_key_route) + .ruma_route(client_server::get_state_events_route) + .ruma_route(client_server::get_state_events_for_key_route) + // Ruma doesn't have support for multiple paths for a single endpoint yet, and these routes + // share one Ruma request / response type pair with {get,send}_state_event_for_key_route + .route( + "/_matrix/client/r0/rooms/:room_id/state/:event_type", + get(client_server::get_state_events_for_empty_key_route) + .put(client_server::send_state_event_for_empty_key_route), + ) + .route( + "/_matrix/client/v3/rooms/:room_id/state/:event_type", + get(client_server::get_state_events_for_empty_key_route) + .put(client_server::send_state_event_for_empty_key_route), + ) + // These two endpoints allow trailing slashes + .route( + "/_matrix/client/r0/rooms/:room_id/state/:event_type/", + get(client_server::get_state_events_for_empty_key_route) + .put(client_server::send_state_event_for_empty_key_route), + ) + .route( + "/_matrix/client/v3/rooms/:room_id/state/:event_type/", + get(client_server::get_state_events_for_empty_key_route) + .put(client_server::send_state_event_for_empty_key_route), + ) + .ruma_route(client_server::sync_events_route) + .ruma_route(client_server::sync_events_v4_route) + .ruma_route(client_server::get_context_route) + .ruma_route(client_server::get_message_events_route) + .ruma_route(client_server::search_events_route) + .ruma_route(client_server::turn_server_route) + .ruma_route(client_server::send_event_to_device_route) + .ruma_route(client_server::get_media_config_route) + .ruma_route(client_server::get_media_preview_route) + .ruma_route(client_server::create_content_route) + // legacy v1 media routes + .route( + "/_matrix/media/v1/preview_url", + get(client_server::get_media_preview_v1_route) + ) + .route( + "/_matrix/media/v1/config", + get(client_server::get_media_config_v1_route) + ) + .route( + "/_matrix/media/v1/upload", + post(client_server::create_content_v1_route) + ) + .route( + "/_matrix/media/v1/download/:server_name/:media_id", + get(client_server::get_content_v1_route) + ) + .route( + "/_matrix/media/v1/download/:server_name/:media_id/:file_name", + get(client_server::get_content_as_filename_v1_route) + ) + .route( + "/_matrix/media/v1/thumbnail/:server_name/:media_id", + get(client_server::get_content_thumbnail_v1_route) + ) + .ruma_route(client_server::get_content_route) + .ruma_route(client_server::get_content_as_filename_route) + .ruma_route(client_server::get_content_thumbnail_route) + .ruma_route(client_server::get_devices_route) + .ruma_route(client_server::get_device_route) + .ruma_route(client_server::update_device_route) + .ruma_route(client_server::delete_device_route) + .ruma_route(client_server::delete_devices_route) + .ruma_route(client_server::get_tags_route) + .ruma_route(client_server::update_tag_route) + .ruma_route(client_server::delete_tag_route) + .ruma_route(client_server::upload_signing_keys_route) + .ruma_route(client_server::upload_signatures_route) + .ruma_route(client_server::get_key_changes_route) + .ruma_route(client_server::get_pushers_route) + .ruma_route(client_server::set_pushers_route) + // .ruma_route(client_server::third_party_route) + .ruma_route(client_server::upgrade_room_route) + .ruma_route(client_server::get_threads_route) + .ruma_route(client_server::get_relating_events_with_rel_type_and_event_type_route) + .ruma_route(client_server::get_relating_events_with_rel_type_route) + .ruma_route(client_server::get_relating_events_route) + .ruma_route(client_server::get_hierarchy_route) + .ruma_route(server_server::get_server_version_route) + .route("/_matrix/key/v2/server", get(server_server::get_server_keys_route)) + .route( + "/_matrix/key/v2/server/:key_id", + get(server_server::get_server_keys_deprecated_route), + ) + .ruma_route(server_server::get_public_rooms_route) + .ruma_route(server_server::get_public_rooms_filtered_route) + .ruma_route(server_server::send_transaction_message_route) + .ruma_route(server_server::get_event_route) + .ruma_route(server_server::get_backfill_route) + .ruma_route(server_server::get_missing_events_route) + .ruma_route(server_server::get_event_authorization_route) + .ruma_route(server_server::get_room_state_route) + .ruma_route(server_server::get_room_state_ids_route) + .ruma_route(server_server::create_join_event_template_route) + .ruma_route(server_server::create_join_event_v1_route) + .ruma_route(server_server::create_join_event_v2_route) + .ruma_route(server_server::create_invite_route) + .ruma_route(server_server::get_devices_route) + .ruma_route(server_server::get_room_information_route) + .ruma_route(server_server::get_profile_information_route) + .ruma_route(server_server::get_keys_route) + .ruma_route(server_server::claim_keys_route) + .ruma_route(server_server::get_hierarchy_route) + .route("/_conduwuit/server_version", get(client_server::conduwuit_server_version)) + .route("/_matrix/client/r0/rooms/:room_id/initialSync", get(initial_sync)) + .route("/_matrix/client/v3/rooms/:room_id/initialSync", get(initial_sync)) + .route("/client/server.json", get(client_server::syncv3_client_server_json)) + .route("/.well-known/matrix/client", get(client_server::well_known_client_route)) + .route("/.well-known/matrix/server", get(server_server::well_known_server_route)) + .route("/", get(it_works)) + .fallback(not_found) +} + +async fn not_found(uri: Uri) -> impl IntoResponse { + if uri.path().contains("_matrix/") { + warn!("Not found: {uri}"); + } else { + info!("Not found: {uri}"); + } + + Error::BadRequest(ErrorKind::Unrecognized, "Unrecognized request") +} + +async fn initial_sync(_uri: Uri) -> impl IntoResponse { + Error::BadRequest(ErrorKind::GuestAccessForbidden, "Guest access not implemented") +} + +async fn it_works() -> &'static str { "hewwo from conduwuit woof!" } + +trait RouterExt { + fn ruma_route(self, handler: H) -> Self + where + H: RumaHandler, + T: 'static; +} + +impl RouterExt for Router { + fn ruma_route(self, handler: H) -> Self + where + H: RumaHandler, + T: 'static, + { + handler.add_to_router(self) + } +} + +pub trait RumaHandler { + // Can't transform to a handler without boxing or relying on the nightly-only + // impl-trait-in-traits feature. Moving a small amount of extra logic into the + // trait allows bypassing both. + fn add_to_router(self, router: Router) -> Router; +} + +macro_rules! impl_ruma_handler { + ( $($ty:ident),* $(,)? ) => { + #[axum::async_trait] + #[allow(non_snake_case)] + impl RumaHandler<($($ty,)* Ruma,)> for F + where + Req: IncomingRequest + Send + 'static, + F: FnOnce($($ty,)* Ruma) -> Fut + Clone + Send + 'static, + Fut: Future> + + Send, + E: IntoResponse, + $( $ty: FromRequestParts<()> + Send + 'static, )* + { + fn add_to_router(self, mut router: Router) -> Router { + let meta = Req::METADATA; + let method_filter = method_to_filter(meta.method); + + for path in meta.history.all_paths() { + let handler = self.clone(); + + router = router.route(path, on(method_filter, |$( $ty: $ty, )* req| async move { + handler($($ty,)* req).await.map(RumaResponse) + })) + } + + router + } + } + }; +} + +impl_ruma_handler!(); +impl_ruma_handler!(T1); +impl_ruma_handler!(T1, T2); +impl_ruma_handler!(T1, T2, T3); +impl_ruma_handler!(T1, T2, T3, T4); +impl_ruma_handler!(T1, T2, T3, T4, T5); +impl_ruma_handler!(T1, T2, T3, T4, T5, T6); +impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7); +impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7, T8); + +fn method_to_filter(method: Method) -> MethodFilter { + match method { + Method::DELETE => MethodFilter::DELETE, + Method::GET => MethodFilter::GET, + Method::HEAD => MethodFilter::HEAD, + Method::OPTIONS => MethodFilter::OPTIONS, + Method::PATCH => MethodFilter::PATCH, + Method::POST => MethodFilter::POST, + Method::PUT => MethodFilter::PUT, + Method::TRACE => MethodFilter::TRACE, + m => panic!("Unsupported HTTP method: {m:?}"), + } +} diff --git a/src/service/globals/client.rs b/src/service/globals/client.rs index 2d92435df..3335cd16d 100644 --- a/src/service/globals/client.rs +++ b/src/service/globals/client.rs @@ -118,12 +118,12 @@ impl Client { #[cfg(not(feature = "gzip_compression"))] { builder = builder.no_gzip(); - } + }; #[cfg(not(feature = "brotli_compression"))] { builder = builder.no_brotli(); - } + }; if let Some(proxy) = config.proxy.to_proxy()? { Ok(builder.proxy(proxy)) diff --git a/src/utils/error.rs b/src/utils/error.rs index 5820d0856..949bdff40 100644 --- a/src/utils/error.rs +++ b/src/utils/error.rs @@ -19,7 +19,7 @@ use crate::RumaResponse; pub type Result = std::result::Result; -#[derive(Error, Debug)] +#[derive(Error)] pub enum Error { #[cfg(feature = "sqlite")] #[error("There was a problem with the connection to the sqlite database: {source}")] @@ -55,11 +55,11 @@ pub enum Error { #[from] source: std::io::Error, }, + #[error("There was a problem with your configuration file: {0}")] + BadConfig(String), #[error("{0}")] BadServerResponse(&'static str), #[error("{0}")] - BadConfig(&'static str), - #[error("{0}")] /// Don't create this directly. Use Error::bad_database instead. BadDatabase(&'static str), #[error("uiaa")] @@ -78,6 +78,8 @@ pub enum Error { InconsistentRoomState(&'static str, ruma::OwnedRoomId), #[error("{0}")] AdminCommand(&'static str), + #[error("{0}")] + Error(String), } impl Error { @@ -86,9 +88,9 @@ impl Error { Self::BadDatabase(message) } - pub fn bad_config(message: &'static str) -> Self { + pub fn bad_config(message: &str) -> Self { error!("BadConfig: {}", message); - Self::BadConfig(message) + Self::BadConfig(message.to_owned()) } } @@ -197,3 +199,7 @@ impl From for Error { impl axum::response::IntoResponse for Error { fn into_response(self) -> axum::response::Response { self.to_response().into_response() } } + +impl std::fmt::Debug for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self) } +}