diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index 35db23d58de2a5..fd2973c7705aa7 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -14,7 +14,8 @@ use anyhow::anyhow; use axum::{ body::Body, extract::{Path, Query}, - http::{self, Request, StatusCode}, + headers::Header, + http::{self, HeaderName, Request, StatusCode}, middleware::{self, Next}, response::IntoResponse, routing::{get, post}, @@ -22,11 +23,44 @@ use axum::{ }; use axum_extra::response::ErasedJson; use serde::{Deserialize, Serialize}; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use tower::ServiceBuilder; pub use extensions::fetch_extensions_from_blob_store_periodically; +pub struct CloudflareIpCountryHeader(String); + +impl Header for CloudflareIpCountryHeader { + fn name() -> &'static HeaderName { + static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock = OnceLock::new(); + CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry")) + } + + fn decode<'i, I>(values: &mut I) -> Result + where + Self: Sized, + I: Iterator, + { + let country_code = values + .next() + .ok_or_else(axum::headers::Error::invalid)? + .to_str() + .map_err(|_| axum::headers::Error::invalid())?; + + Ok(Self(country_code.to_string())) + } + + fn encode>(&self, _values: &mut E) { + unimplemented!() + } +} + +impl std::fmt::Display for CloudflareIpCountryHeader { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + pub fn routes(rpc_server: Option>, state: Arc) -> Router<(), Body> { Router::new() .route("/user", get(get_authenticated_user)) diff --git a/crates/collab/src/api/events.rs b/crates/collab/src/api/events.rs index fa569b4ba148ba..e0cf79bb887538 100644 --- a/crates/collab/src/api/events.rs +++ b/crates/collab/src/api/events.rs @@ -1,4 +1,5 @@ use super::ips_file::IpsFile; +use crate::api::CloudflareIpCountryHeader; use crate::{api::slack, AppState, Error, Result}; use anyhow::{anyhow, Context}; use aws_sdk_s3::primitives::ByteStream; @@ -59,33 +60,6 @@ impl Header for ZedChecksumHeader { } } -pub struct CloudflareIpCountryHeader(String); - -impl Header for CloudflareIpCountryHeader { - fn name() -> &'static HeaderName { - static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock = OnceLock::new(); - CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry")) - } - - fn decode<'i, I>(values: &mut I) -> Result - where - Self: Sized, - I: Iterator, - { - let country_code = values - .next() - .ok_or_else(axum::headers::Error::invalid)? - .to_str() - .map_err(|_| axum::headers::Error::invalid())?; - - Ok(Self(country_code.to_string())) - } - - fn encode>(&self, _values: &mut E) { - unimplemented!() - } -} - pub async fn post_crash( Extension(app): Extension>, headers: HeaderMap, @@ -413,7 +387,7 @@ pub async fn post_events( let Some(last_event) = request_body.events.last() else { return Err(Error::Http(StatusCode::BAD_REQUEST, "no events".into()))?; }; - let country_code = country_code_header.map(|h| h.0 .0); + let country_code = country_code_header.map(|h| h.to_string()); let first_event_at = chrono::Utc::now() - chrono::Duration::milliseconds(last_event.milliseconds_since_first_event); diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 939ab551102759..ed836571de978a 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1,5 +1,6 @@ mod connection_pool; +use crate::api::CloudflareIpCountryHeader; use crate::{ auth, db::{ @@ -152,6 +153,9 @@ struct Session { supermaven_client: Option>, http_client: Arc, rate_limiter: Arc, + /// The GeoIP country code for the user. + #[allow(unused)] + geoip_country_code: Option, _executor: Executor, } @@ -984,6 +988,7 @@ impl Server { address: String, principal: Principal, zed_version: ZedVersion, + geoip_country_code: Option, send_connection_id: Option>, executor: Executor, ) -> impl Future { @@ -1009,7 +1014,10 @@ impl Server { let executor = executor.clone(); move |duration| executor.sleep(duration) }); - tracing::Span::current().record("connection_id", format!("{}", connection_id)); + tracing::Span::current() + .record("connection_id", format!("{}", connection_id)) + .record("geoip_country_code", geoip_country_code.clone()); + tracing::info!("connection opened"); let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION")); @@ -1039,6 +1047,7 @@ impl Server { live_kit_client: this.app_state.live_kit_client.clone(), http_client, rate_limiter: this.app_state.rate_limiter.clone(), + geoip_country_code, _executor: executor.clone(), supermaven_client, }; @@ -1395,6 +1404,7 @@ pub async fn handle_websocket_request( ConnectInfo(socket_address): ConnectInfo, Extension(server): Extension>, Extension(principal): Extension, + country_code_header: Option>, ws: WebSocketUpgrade, ) -> axum::response::Response { if protocol_version != rpc::PROTOCOL_VERSION { @@ -1435,6 +1445,7 @@ pub async fn handle_websocket_request( socket_address, principal, version, + country_code_header.map(|header| header.to_string()), None, Executor::Production, ) diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index d1aa42f28b25cd..fde3082102e479 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -244,6 +244,7 @@ impl TestServer { client_name, Principal::User(user), ZedVersion(SemanticVersion::new(1, 0, 0)), + None, Some(connection_id_tx), Executor::Deterministic(cx.background_executor().clone()), )) @@ -377,6 +378,7 @@ impl TestServer { "dev-server".to_string(), Principal::DevServer(dev_server), ZedVersion(SemanticVersion::new(1, 0, 0)), + None, Some(connection_id_tx), Executor::Deterministic(cx.background_executor().clone()), ))