Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add tls and auth check #49

Merged
merged 9 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 45 additions & 24 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 2 additions & 10 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@ use jito_geyser_protos::solana::geyser::{
};
use prost_types::Timestamp;
use solana_sdk::pubkey::Pubkey;
use tonic::{
transport::{ClientTlsConfig, Endpoint},
Streaming,
};
use tonic::{transport::channel::Endpoint, Streaming};
use uuid::Uuid;

#[derive(Parser, Debug)]
Expand Down Expand Up @@ -72,12 +69,7 @@ async fn main() {
let args: Args = Args::parse();
println!("args: {args:?}");

let mut endpoint = Endpoint::from_str(&args.url).unwrap();
if args.url.starts_with("https://") {
endpoint = endpoint
.tls_config(ClientTlsConfig::new())
.expect("create tls config");
}
let endpoint = Endpoint::from_str(&args.url).unwrap();

let channel = endpoint.connect().await.expect("connects");

Expand Down
71 changes: 56 additions & 15 deletions server/src/geyser_grpc_plugin.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
//! Implements the geyser plugin interface.

use std::{
fs,
fs::File,
io::Read,
sync::{
atomic::{AtomicU64, Ordering},
atomic::{AtomicBool, AtomicU64, Ordering},
Arc,
},
time::SystemTime,
};
use std::sync::atomic::AtomicBool;

use bs58;
use crossbeam_channel::{bounded, Sender, TrySendError};
Expand All @@ -29,7 +29,11 @@ use solana_geyser_plugin_interface::geyser_plugin_interface::{
ReplicaTransactionInfoVersions, Result as PluginResult, SlotStatus,
};
use tokio::{runtime::Runtime, sync::oneshot};
use tonic::transport::Server;
use tonic::{
service::{interceptor::InterceptedService, Interceptor},
transport::{Identity, Server, ServerTlsConfig},
Request, Status,
};

use crate::server::{GeyserService, GeyserServiceConfig};

Expand All @@ -47,7 +51,7 @@ pub struct PluginData {
highest_write_slot: Arc<AtomicU64>,

is_startup_completed: AtomicBool,
ignore_startup_updates: bool
ignore_startup_updates: bool,
}

#[derive(Default)]
Expand All @@ -70,7 +74,7 @@ pub struct PluginConfig {
pub slot_update_buffer_size: usize,
pub block_update_buffer_size: usize,
pub transaction_update_buffer_size: usize,
pub skip_startup_stream: Option<bool>
pub skip_startup_stream: Option<bool>,
}

impl GeyserPlugin for GeyserGrpcPlugin {
Expand Down Expand Up @@ -114,7 +118,7 @@ impl GeyserPlugin for GeyserGrpcPlugin {
bounded(config.transaction_update_buffer_size);

let svc = GeyserService::new(
config.geyser_service_config,
config.geyser_service_config.clone(),
account_update_rx,
slot_update_rx,
block_update_receiver,
Expand All @@ -125,13 +129,26 @@ impl GeyserPlugin for GeyserGrpcPlugin {

let runtime = Runtime::new().unwrap();
let (server_exit_tx, server_exit_rx) = oneshot::channel();
runtime.spawn(
Server::builder()
.add_service(svc)
.serve_with_shutdown(addr, async move {
let _ = server_exit_rx.await;
}),
);
let mut server_builder = Server::builder();
let tls_config = config.geyser_service_config.tls_config.clone();
let access_token = config.geyser_service_config.access_token.clone();
if let Some(tls_config) = tls_config {
let cert = fs::read(&tls_config.cert_path)?;
let key = fs::read(&tls_config.key_path)?;
server_builder = server_builder
.tls_config(ServerTlsConfig::new().identity(Identity::from_pem(cert, key)))
.map_err(|e| GeyserPluginError::Custom(e.into()))?;
}
let s;
if let Some(access_token) = access_token {
let svc = InterceptedService::new(svc, AccessTokenChecker::new(access_token));
s = server_builder.add_service(svc);
} else {
s = server_builder.add_service(svc);
}
runtime.spawn(s.serve_with_shutdown(addr, async move {
let _ = server_exit_rx.await;
}));

self.data = Some(PluginData {
runtime,
Expand All @@ -143,7 +160,7 @@ impl GeyserPlugin for GeyserGrpcPlugin {
highest_write_slot,
is_startup_completed: AtomicBool::new(false),
// don't skip startup to keep backwards compatability
ignore_startup_updates: config.skip_startup_stream.unwrap_or(false)
ignore_startup_updates: config.skip_startup_stream.unwrap_or(false),
});
info!("plugin data initialized");

Expand All @@ -161,7 +178,11 @@ impl GeyserPlugin for GeyserGrpcPlugin {
}

fn notify_end_of_startup(&self) -> PluginResult<()> {
self.data.as_ref().unwrap().is_startup_completed.store(true, Ordering::Relaxed);
self.data
.as_ref()
.unwrap()
.is_startup_completed
.store(true, Ordering::Relaxed);
Ok(())
}

Expand Down Expand Up @@ -459,3 +480,23 @@ pub unsafe extern "C" fn _create_plugin() -> *mut dyn GeyserPlugin {
let plugin: Box<dyn GeyserPlugin> = Box::new(plugin);
Box::into_raw(plugin)
}

#[derive(Clone)]
struct AccessTokenChecker {
access_token: String,
}

impl AccessTokenChecker {
fn new(access_token: String) -> Self {
Self { access_token }
}
}

impl Interceptor for AccessTokenChecker {
fn call(&mut self, req: Request<()>) -> Result<Request<()>, Status> {
match req.metadata().get("access-token") {
Some(t) if &self.access_token == t => Ok(req),
_ => Err(Status::unauthenticated("Access token is incorrect")),
}
}
}
9 changes: 9 additions & 0 deletions server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,13 +282,22 @@ pub enum GeyserServiceError {

type GeyserServiceResult<T> = Result<T, GeyserServiceError>;

#[derive(Debug, Clone, Deserialize)]
pub struct ServerTlsConfig {
pub cert_path: String,
pub key_path: String,
}

#[derive(Clone, Debug, Deserialize)]
pub struct GeyserServiceConfig {
/// Cadence of heartbeats.
heartbeat_interval_ms: u64,

/// Individual subscriber buffer size.
subscriber_buffer_size: usize,

pub tls_config: Option<ServerTlsConfig>,
pub access_token: Option<String>,
}

pub struct GeyserService {
Expand Down