Skip to content

Commit

Permalink
collab: Attach GeoIP country code to RPC sessions (#15814)
Browse files Browse the repository at this point in the history
This PR updates collab to attach the user's GeoIP country code to their
RPC session.

We source the country code from the
[`CF-IPCountry`](https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry)
header.

Release Notes:

- N/A
  • Loading branch information
maxdeviant authored Aug 5, 2024
1 parent be0ccf4 commit f11f3f2
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 31 deletions.
38 changes: 36 additions & 2 deletions crates/collab/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,53 @@ use anyhow::anyhow;
use axum::{
body::Body,
extract::{Path, Query},
http::{self, Request, StatusCode},
headers::Header,
http::{self, HeaderName, Request, StatusCode},
middleware::{self, Next},
response::IntoResponse,
routing::{get, post},
Extension, Json, Router,
};
use axum_extra::response::ErasedJson;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::sync::{Arc, OnceLock};
use tower::ServiceBuilder;

pub use extensions::fetch_extensions_from_blob_store_periodically;

pub struct CloudflareIpCountryHeader(String);

impl Header for CloudflareIpCountryHeader {
fn name() -> &'static HeaderName {
static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
}

fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
where
Self: Sized,
I: Iterator<Item = &'i axum::http::HeaderValue>,
{
let country_code = values
.next()
.ok_or_else(axum::headers::Error::invalid)?
.to_str()
.map_err(|_| axum::headers::Error::invalid())?;

Ok(Self(country_code.to_string()))
}

fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
unimplemented!()
}
}

impl std::fmt::Display for CloudflareIpCountryHeader {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}

pub fn routes(rpc_server: Option<Arc<rpc::Server>>, state: Arc<AppState>) -> Router<(), Body> {
Router::new()
.route("/user", get(get_authenticated_user))
Expand Down
30 changes: 2 additions & 28 deletions crates/collab/src/api/events.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::ips_file::IpsFile;
use crate::api::CloudflareIpCountryHeader;
use crate::{api::slack, AppState, Error, Result};
use anyhow::{anyhow, Context};
use aws_sdk_s3::primitives::ByteStream;
Expand Down Expand Up @@ -59,33 +60,6 @@ impl Header for ZedChecksumHeader {
}
}

pub struct CloudflareIpCountryHeader(String);

impl Header for CloudflareIpCountryHeader {
fn name() -> &'static HeaderName {
static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
}

fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
where
Self: Sized,
I: Iterator<Item = &'i axum::http::HeaderValue>,
{
let country_code = values
.next()
.ok_or_else(axum::headers::Error::invalid)?
.to_str()
.map_err(|_| axum::headers::Error::invalid())?;

Ok(Self(country_code.to_string()))
}

fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
unimplemented!()
}
}

pub async fn post_crash(
Extension(app): Extension<Arc<AppState>>,
headers: HeaderMap,
Expand Down Expand Up @@ -413,7 +387,7 @@ pub async fn post_events(
let Some(last_event) = request_body.events.last() else {
return Err(Error::Http(StatusCode::BAD_REQUEST, "no events".into()))?;
};
let country_code = country_code_header.map(|h| h.0 .0);
let country_code = country_code_header.map(|h| h.to_string());

let first_event_at = chrono::Utc::now()
- chrono::Duration::milliseconds(last_event.milliseconds_since_first_event);
Expand Down
13 changes: 12 additions & 1 deletion crates/collab/src/rpc.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod connection_pool;

use crate::api::CloudflareIpCountryHeader;
use crate::{
auth,
db::{
Expand Down Expand Up @@ -152,6 +153,9 @@ struct Session {
supermaven_client: Option<Arc<SupermavenAdminApi>>,
http_client: Arc<IsahcHttpClient>,
rate_limiter: Arc<RateLimiter>,
/// The GeoIP country code for the user.
#[allow(unused)]
geoip_country_code: Option<String>,
_executor: Executor,
}

Expand Down Expand Up @@ -984,6 +988,7 @@ impl Server {
address: String,
principal: Principal,
zed_version: ZedVersion,
geoip_country_code: Option<String>,
send_connection_id: Option<oneshot::Sender<ConnectionId>>,
executor: Executor,
) -> impl Future<Output = ()> {
Expand All @@ -1009,7 +1014,10 @@ impl Server {
let executor = executor.clone();
move |duration| executor.sleep(duration)
});
tracing::Span::current().record("connection_id", format!("{}", connection_id));
tracing::Span::current()
.record("connection_id", format!("{}", connection_id))
.record("geoip_country_code", geoip_country_code.clone());

tracing::info!("connection opened");

let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
Expand Down Expand Up @@ -1039,6 +1047,7 @@ impl Server {
live_kit_client: this.app_state.live_kit_client.clone(),
http_client,
rate_limiter: this.app_state.rate_limiter.clone(),
geoip_country_code,
_executor: executor.clone(),
supermaven_client,
};
Expand Down Expand Up @@ -1395,6 +1404,7 @@ pub async fn handle_websocket_request(
ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
Extension(server): Extension<Arc<Server>>,
Extension(principal): Extension<Principal>,
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
ws: WebSocketUpgrade,
) -> axum::response::Response {
if protocol_version != rpc::PROTOCOL_VERSION {
Expand Down Expand Up @@ -1435,6 +1445,7 @@ pub async fn handle_websocket_request(
socket_address,
principal,
version,
country_code_header.map(|header| header.to_string()),
None,
Executor::Production,
)
Expand Down
2 changes: 2 additions & 0 deletions crates/collab/src/tests/test_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ impl TestServer {
client_name,
Principal::User(user),
ZedVersion(SemanticVersion::new(1, 0, 0)),
None,
Some(connection_id_tx),
Executor::Deterministic(cx.background_executor().clone()),
))
Expand Down Expand Up @@ -377,6 +378,7 @@ impl TestServer {
"dev-server".to_string(),
Principal::DevServer(dev_server),
ZedVersion(SemanticVersion::new(1, 0, 0)),
None,
Some(connection_id_tx),
Executor::Deterministic(cx.background_executor().clone()),
))
Expand Down

0 comments on commit f11f3f2

Please sign in to comment.