diff --git a/Cargo.lock b/Cargo.lock
index d8c93a9c0..ff5840129 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -566,6 +566,8 @@ dependencies = [
  "async-nats",
  "async-trait",
  "aws-config",
+ "aws-msk-iam-sasl-signer",
+ "aws-sdk-kafka",
  "aws-sdk-kinesis",
  "axum",
  "base64 0.13.1",
@@ -1536,6 +1538,24 @@ dependencies = [
  "zeroize",
 ]

+[[package]]
+name = "aws-msk-iam-sasl-signer"
+version = "1.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7036b8409ffe698dfdc5ae08722999d960092aeb738026ea99c3071c94831668"
+dependencies = [
+ "aws-config",
+ "aws-credential-types",
+ "aws-sdk-sts",
+ "aws-sigv4",
+ "aws-types",
+ "base64 0.22.1",
+ "chrono",
+ "futures",
+ "thiserror",
+ "url",
+]
+
 [[package]]
 name = "aws-runtime"
 version = "1.4.3"
@@ -1584,6 +1604,28 @@ dependencies = [
  "tracing",
 ]

+[[package]]
+name = "aws-sdk-kafka"
+version = "1.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2e034aa9ec5d7a865e554e8a7744b8d4655db954aa265049e8e6361ba3f5c0d2"
+dependencies = [
+ "aws-credential-types",
+ "aws-runtime",
+ "aws-smithy-async",
+ "aws-smithy-http",
+ "aws-smithy-json",
+ "aws-smithy-runtime",
+ "aws-smithy-runtime-api",
+ "aws-smithy-types",
+ "aws-types",
+ "bytes",
+ "http 0.2.12",
+ "once_cell",
+ "regex-lite",
+ "tracing",
+]
+
 [[package]]
 name = "aws-sdk-kinesis"
 version = "1.48.0"
diff --git a/crates/arroyo-connectors/Cargo.toml b/crates/arroyo-connectors/Cargo.toml
index 5ad4e606d..f1c6f70f9 100644
--- a/crates/arroyo-connectors/Cargo.toml
+++ b/crates/arroyo-connectors/Cargo.toml
@@ -50,8 +50,10 @@ regex = "1"
 ##########################

 # Kafka
+aws-sdk-kafka = { version = "1.44" }
+aws-msk-iam-sasl-signer = "1.0.0"
 rdkafka = { version = "0.36", features = ["cmake-build", "tracing", "sasl", "ssl-vendored"] }
-rdkafka-sys = "4.5.0"
+rdkafka-sys = "4.7.0"
 sasl2-sys = { version = "0.1.6", features = ["vendored"] }

 # SSE
diff --git a/crates/arroyo-connectors/src/kafka/mod.rs b/crates/arroyo-connectors/src/kafka/mod.rs
index f03c97c10..4db46c48a 100644
--- a/crates/arroyo-connectors/src/kafka/mod.rs
+++ b/crates/arroyo-connectors/src/kafka/mod.rs
@@ -11,10 +11,14 @@ use arroyo_rpc::schema_resolver::{
 };
 use arroyo_rpc::{schema_resolver, var_str::VarStr, OperatorConfig};
 use arroyo_types::string_to_map;
+use aws_config::Region;
+use aws_msk_iam_sasl_signer::generate_auth_token;
 use futures::TryFutureExt;
 use rdkafka::{
-    consumer::{BaseConsumer, Consumer},
-    ClientConfig, Message, Offset, TopicPartitionList,
+    client::OAuthToken,
+    consumer::{Consumer, ConsumerContext},
+    producer::ProducerContext,
+    ClientConfig, ClientContext, Message, Offset, TopicPartitionList,
 };
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
@@ -22,10 +26,13 @@ use std::borrow::Cow;
 use std::collections::HashMap;
 use std::num::NonZeroU32;
 use std::sync::Arc;
+use std::thread;
 use std::time::{Duration, Instant, SystemTime};
+use tokio::runtime::Handle;
 use tokio::sync::mpsc::Sender;
 use tokio::sync::oneshot;
 use tokio::sync::oneshot::Receiver;
+use tokio::time::timeout;
 use tonic::Status;
 use tracing::{error, info, warn};
 use typify::import_types;
@@ -77,6 +84,9 @@ impl KafkaConnector {
                 username: VarStr::new(pull_opt("auth.username", options)?),
                 password: VarStr::new(pull_opt("auth.password", options)?),
             },
+            Some("aws_msk_iam") => KafkaConfigAuthentication::AwsMskIam {
+                region: pull_opt("auth.region", options)?,
+            },
             Some(other) => bail!("unknown auth type '{}'", other),
         };

@@ -362,7 +372,7 @@ impl Connector for KafkaConnector {
                 read_mode,
                 group_id_prefix,
             } => {
-                let mut client_configs = client_configs(&profile, &table);
+                let mut client_configs = client_configs(&profile, Some(table.clone()))?;
                 if let Some(ReadMode::ReadCommitted) = read_mode {
                     client_configs
                         .insert("isolation.level".to_string(), "read_committed".to_string());
@@ -399,6 +409,7 @@ impl Connector for KafkaConnector {
                     schema_resolver,
                     bad_data: config.bad_data,
                     client_configs,
+                    context: Context::new(Some(profile.clone())),
                     messages_per_second: NonZeroU32::new(
                         config
                             .rate_limit
@@ -422,7 +433,8 @@ impl Connector for KafkaConnector {
                 key_field: key_field.clone(),
                 key_col: None,
                 write_futures: vec![],
-                client_config: client_configs(&profile, &table),
+                client_config: client_configs(&profile, Some(table.clone()))?,
+                context: Context::new(Some(profile.clone())),
                 topic: table.topic,
                 serializer: ArrowSerializer::new(
                     config.format.expect("Format must be defined for KafkaSink"),
@@ -467,38 +479,24 @@ impl KafkaTester {
             },
         );

-        // TODO: merge this with client_configs()
-        match &self.connection.authentication {
-            KafkaConfigAuthentication::None {} => {}
-            KafkaConfigAuthentication::Sasl {
-                mechanism,
-                password,
-                protocol,
-                username,
-            } => {
-                client_config.set("sasl.mechanism", mechanism);
-                client_config.set("security.protocol", protocol);
-                client_config.set(
-                    "sasl.username",
-                    username.sub_env_vars().map_err(|e| e.to_string())?,
-                );
-                client_config.set(
-                    "sasl.password",
-                    password.sub_env_vars().map_err(|e| e.to_string())?,
-                );
-            }
-        };
-
-        if let Some(table) = table {
-            for (k, v) in table.client_configs {
-                client_config.set(k, v);
-            }
+        for (k, v) in client_configs(&self.connection, table)
+            .map_err(|e| e.to_string())?
+            .into_iter()
+        {
+            client_config.set(k, v);
         }

+        let context = Context::new(Some(self.connection.clone()));
         let client: BaseConsumer = client_config
-            .create()
+            .create_with_context(context)
             .map_err(|e| format!("invalid kafka config: {:?}", e))?;

+        // NOTE: this is required to trigger an oauth token refresh (when using
+        // OAUTHBEARER auth).
+        if client.poll(Duration::from_secs(0)).is_some() {
+            return Err("unexpected poll event from new consumer".to_string());
+        }
+
         tokio::task::spawn_blocking(move || {
             client
                 .fetch_metadata(None, Duration::from_secs(10))
@@ -903,7 +901,10 @@ impl SourceOffset {
     }
 }

-pub fn client_configs(connection: &KafkaConfig, table: &KafkaTable) -> HashMap<String, String> {
+pub fn client_configs(
+    connection: &KafkaConfig,
+    table: Option<KafkaTable>,
+) -> anyhow::Result<HashMap<String, String>> {
     let mut client_configs: HashMap<String, String> = HashMap::new();

     match &connection.authentication {
@@ -916,27 +917,83 @@ pub fn client_configs(connection: &KafkaConfig, table: &KafkaTable) -> HashMap<String, String> {
             client_configs.insert("sasl.mechanism".to_string(), mechanism.to_string());
             client_configs.insert("security.protocol".to_string(), protocol.to_string());
-            client_configs.insert(
-                "sasl.username".to_string(),
-                username
-                    .sub_env_vars()
-                    .expect("Missing env-vars for Kafka username"),
-            );
-            client_configs.insert(
-                "sasl.password".to_string(),
-                password
-                    .sub_env_vars()
-                    .expect("Missing env-vars for Kafka password"),
-            );
+            client_configs.insert("sasl.username".to_string(), username.sub_env_vars()?);
+            client_configs.insert("sasl.password".to_string(), password.sub_env_vars()?);
+        }
+        KafkaConfigAuthentication::AwsMskIam { region: _ } => {
+            client_configs.insert("sasl.mechanism".to_string(), "OAUTHBEARER".to_string());
+            client_configs.insert("security.protocol".to_string(), "SASL_SSL".to_string());
         }
     };

-    client_configs.extend(
-        table
-            .client_configs
-            .iter()
-            .map(|(k, v)| (k.to_string(), v.to_string())),
-    );
+    if let Some(table) = table {
+        client_configs.extend(
+            table
+                .client_configs
+                .iter()
+                .map(|(k, v)| (k.to_string(), v.to_string())),
+        );
+    }
+
+    Ok(client_configs)
+}
+
+type BaseConsumer = rdkafka::consumer::BaseConsumer<Context>;
+type FutureProducer = rdkafka::producer::FutureProducer<Context>;
+type StreamConsumer = rdkafka::consumer::StreamConsumer<Context>;
+
+#[derive(Clone)]
+pub struct Context {
+    config: Option<KafkaConfig>,
+}
+
+impl Context {
+    pub fn new(config: Option<KafkaConfig>) -> Self {
+        Self { config }
+    }
+}
+
+impl ConsumerContext for Context {}
+
+impl ProducerContext for Context {
+    type DeliveryOpaque = ();
+    fn delivery(
+        &self,
+        _delivery_result: &rdkafka::message::DeliveryResult<'_>,
+        _delivery_opaque: Self::DeliveryOpaque,
+    ) {
+    }
+}
+
+impl ClientContext for Context {
+    const ENABLE_REFRESH_OAUTH_TOKEN: bool = true;

-    client_configs
+    fn generate_oauth_token(
+        &self,
+        _oauthbearer_config: Option<&str>,
+    ) -> Result<OAuthToken, Box<dyn std::error::Error>> {
+        if let Some(KafkaConfigAuthentication::AwsMskIam { region }) =
+            self.config.as_ref().map(|c| &c.authentication)
+        {
+            let region = Region::new(region.clone());
+            let rt = Handle::current();
+
+            let (token, expiration_time_ms) = {
+                let handle = thread::spawn(move || {
+                    rt.block_on(async {
+                        timeout(Duration::from_secs(10), generate_auth_token(region.clone())).await
+                    })
+                });
+                handle.join().unwrap()??
+            };
+
+            Ok(OAuthToken {
+                token,
+                principal_name: "".to_string(),
+                lifetime_ms: expiration_time_ms,
+            })
+        } else {
+            Err(anyhow!("only AWS_MSK_IAM is supported for sasl oauth").into())
+        }
+    }
 }
diff --git a/crates/arroyo-connectors/src/kafka/profile.json b/crates/arroyo-connectors/src/kafka/profile.json
index be7292aba..6fe762dee 100644
--- a/crates/arroyo-connectors/src/kafka/profile.json
+++ b/crates/arroyo-connectors/src/kafka/profile.json
@@ -6,9 +6,7 @@
       "type": "string",
       "title": "Bootstrap Servers",
       "description": "Comma-separated list of Kafka servers to connect to",
-      "examples": [
-        "broker-1:9092,broker-2:9092"
-      ],
+      "examples": ["broker-1:9092,broker-2:9092"],
       "pattern": "^(([\\w\\.\\-]+:\\d+),)*([\\w\\.\\-]+:\\d+)$"
     },
     "authentication": {
@@ -17,8 +15,7 @@
         {
           "type": "object",
           "title": "None",
-          "properties": {
-          },
+          "properties": {},
           "additionalProperties": false
         },
         {
@@ -30,9 +27,7 @@
             "username",
             "password"
           ],
-          "sensitive": [
-            "password"
-          ],
+          "sensitive": ["password"],
           "properties": {
             "protocol": {
               "type": "string",
@@ -54,6 +49,18 @@
             }
           },
           "additionalProperties": false
+        },
+        {
+          "type": "object",
+          "title": "AWS_MSK_IAM",
+          "required": ["region"],
+          "properties": {
+            "region": {
+              "type": "string",
+              "description": "The AWS region to connect to"
+            }
+          },
+          "additionalProperties": false
         }
       ]
     },
@@ -64,8 +71,7 @@
         {
           "type": "object",
           "title": "None",
-          "properties": {
-          },
+          "properties": {},
           "additionalProperties": false
         },
         {
@@ -76,9 +82,7 @@
           "title": "Endpoint",
           "type": "string",
           "description": "The endpoint for your Confluent Schema Registry",
-          "examples": [
-            "http://localhost:8081"
-          ],
+          "examples": ["http://localhost:8081"],
           "format": "uri"
         },
         "apiKey": {
@@ -86,9 +90,7 @@
           "title": "API Key",
           "type": "string",
           "description": "The API key for your Confluent Schema Registry",
           "format": "var-str",
-          "examples": [
-            "ABCDEFGHIJK01234"
-          ]
+          "examples": ["ABCDEFGHIJK01234"]
         },
         "apiSecret": {
           "title": "API Secret",
@@ -100,19 +102,11 @@
           "type": "string",
           "description": "The API secret for your Confluent Schema Registry",
           "format": "var-str"
         }
       },
-      "required": [
-        "endpoint"
-      ],
-      "sensitive": [
-        "apiSecret"
-      ]
+      "required": ["endpoint"],
+      "sensitive": ["apiSecret"]
     }
   ]
 }
 },
-  "required": [
-    "bootstrapServers",
-    "authentication"
-  ]
+  "required": ["bootstrapServers", "authentication"]
 }
-
diff --git a/crates/arroyo-connectors/src/kafka/sink/mod.rs b/crates/arroyo-connectors/src/kafka/sink/mod.rs
index 5d756d3c2..4b90d4080 100644
--- a/crates/arroyo-connectors/src/kafka/sink/mod.rs
+++ b/crates/arroyo-connectors/src/kafka/sink/mod.rs
@@ -8,12 +8,11 @@ use std::collections::HashMap;
 use std::fmt::{Display, Formatter};
 use tracing::{error, warn};

-use rdkafka::producer::{DeliveryFuture, FutureProducer, FutureRecord, Producer};
+use rdkafka::producer::{DeliveryFuture, FutureRecord, Producer};
 use rdkafka::util::Timeout;
 use rdkafka::ClientConfig;

-use super::SinkCommitMode;
 use arrow::array::{Array, AsArray, RecordBatch};
 use arrow::datatypes::{DataType, TimeUnit};
 use arroyo_formats::ser::ArrowSerializer;
@@ -26,6 +25,8 @@ use prost::Message;
 use rdkafka::error::{KafkaError, RDKafkaErrorCode};
 use std::time::{Duration, SystemTime};

+use super::{Context, FutureProducer, SinkCommitMode};
+
 #[cfg(test)]
 mod test;

@@ -40,6 +41,7 @@ pub struct KafkaSinkFunc {
     pub producer: Option<FutureProducer>,
     pub write_futures: Vec<DeliveryFuture>,
     pub client_config: HashMap<String, String>,
+    pub context: Context,
     pub serializer: ArrowSerializer,
 }

@@ -134,7 +136,7 @@ impl KafkaSinkFunc {
         match &mut self.consistency_mode {
             ConsistencyMode::AtLeastOnce => {
-                self.producer = Some(client_config.create()?);
+                self.producer = Some(client_config.create_with_context(self.context.clone())?);
             }
             ConsistencyMode::ExactlyOnce {
                 next_transaction_index,
@@ -150,7 +152,8 @@ impl KafkaSinkFunc {
                     next_transaction_index
                 );
                 client_config.set("transactional.id", transactional_id);
-                let producer: FutureProducer = client_config.create()?;
+                let producer: FutureProducer =
+                    client_config.create_with_context(self.context.clone())?;
                 producer.init_transactions(Timeout::After(Duration::from_secs(30)))?;
                 producer.begin_transaction()?;
                 *next_transaction_index += 1;
diff --git a/crates/arroyo-connectors/src/kafka/sink/test.rs b/crates/arroyo-connectors/src/kafka/sink/test.rs
index 6b7b1d5f4..6885adeba 100644
--- a/crates/arroyo-connectors/src/kafka/sink/test.rs
+++ b/crates/arroyo-connectors/src/kafka/sink/test.rs
@@ -23,6 +23,7 @@ use serde::Deserialize;
 use tokio::sync::mpsc::channel;

 use super::{ConsistencyMode, KafkaSinkFunc};
+use crate::kafka::Context;

 pub struct KafkaTopicTester {
     topic: String,
@@ -80,6 +81,7 @@ impl KafkaTopicTester {
             key_field: None,
             write_futures: vec![],
             client_config: HashMap::new(),
+            context: Context::new(None),
             serializer: ArrowSerializer::new(Format::Json(JsonFormat::default())),
             key_col: None,
         };
diff --git a/crates/arroyo-connectors/src/kafka/source/mod.rs b/crates/arroyo-connectors/src/kafka/source/mod.rs
index 89fa297d0..299bf40b6 100644
--- a/crates/arroyo-connectors/src/kafka/source/mod.rs
+++ b/crates/arroyo-connectors/src/kafka/source/mod.rs
@@ -1,17 +1,9 @@
-use arroyo_formats::de::FieldValueType;
-use arroyo_rpc::formats::{BadData, Format, Framing};
-use arroyo_rpc::grpc::rpc::TableConfig;
-use arroyo_rpc::schema_resolver::SchemaResolver;
-use arroyo_rpc::{grpc::rpc::StopMode, ControlMessage, ControlResp, MetadataField};
-
-use arroyo_operator::context::ArrowContext;
-use arroyo_operator::operator::SourceOperator;
-use arroyo_operator::SourceFinishType;
-use arroyo_types::*;
+use anyhow::bail;
 use async_trait::async_trait;
 use bincode::{Decode, Encode};
+use futures::FutureExt;
 use governor::{Quota, RateLimiter as GovernorRateLimiter};
-use rdkafka::consumer::{CommitMode, Consumer, StreamConsumer};
+use rdkafka::consumer::{CommitMode, Consumer};
 use rdkafka::{ClientConfig, Message as KMessage, Offset, TopicPartitionList};
 use std::collections::HashMap;
 use std::num::NonZeroU32;
@@ -21,6 +13,18 @@ use tokio::select;
 use tokio::time::MissedTickBehavior;
 use tracing::{debug, error, info, warn};

+use arroyo_formats::de::FieldValueType;
+use arroyo_operator::context::ArrowContext;
+use arroyo_operator::operator::SourceOperator;
+use arroyo_operator::SourceFinishType;
+use arroyo_rpc::formats::{BadData, Format, Framing};
+use arroyo_rpc::grpc::rpc::TableConfig;
+use arroyo_rpc::schema_resolver::SchemaResolver;
+use arroyo_rpc::{grpc::rpc::StopMode, ControlMessage, ControlResp, MetadataField};
+use arroyo_types::*;
+
+use super::{Context, SourceOffset, StreamConsumer};
+
 #[cfg(test)]
 mod test;

@@ -29,12 +33,13 @@ pub struct KafkaSourceFunc {
     pub bootstrap_servers: String,
     pub group_id: Option<String>,
     pub group_id_prefix: Option<String>,
-    pub offset_mode: super::SourceOffset,
+    pub offset_mode: SourceOffset,
     pub format: Format,
     pub framing: Option<Framing>,
     pub bad_data: Option<BadData>,
     pub schema_resolver: Option<Arc<dyn SchemaResolver>>,
     pub client_configs: HashMap<String, String>,
+    pub context: Context,
     pub messages_per_second: NonZeroU32,
     pub metadata_fields: Vec<MetadataField>,
 }

@@ -79,7 +84,13 @@ impl KafkaSourceFunc {
             .set("enable.partition.eof", "false")
             .set("enable.auto.commit", "false")
             .set("group.id", group_id)
-            .create()?;
+            .create_with_context(self.context.clone())?;
+
+        // NOTE: this is required to trigger an oauth token refresh (when using
+        // OAUTHBEARER auth).
+        if consumer.recv().now_or_never().is_some() {
+            bail!("unexpected recv before assignments");
+        }

         let state: Vec<_> = ctx
             .table_manager
diff --git a/crates/arroyo-connectors/src/kafka/source/test.rs b/crates/arroyo-connectors/src/kafka/source/test.rs
index a6c5f5b51..1898f1115 100644
--- a/crates/arroyo-connectors/src/kafka/source/test.rs
+++ b/crates/arroyo-connectors/src/kafka/source/test.rs
@@ -29,6 +29,7 @@ use serde::{Deserialize, Serialize};
 use tokio::sync::mpsc::{channel, Receiver, Sender};

 use super::KafkaSourceFunc;
+use crate::kafka::Context;

 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 struct TestData {
@@ -86,6 +87,7 @@ impl KafkaTopicTester {
             bad_data: None,
             schema_resolver: None,
             client_configs: HashMap::new(),
+            context: Context::new(None),
             messages_per_second: NonZeroU32::new(100).unwrap(),
             metadata_fields: vec![],
         });
@@ -375,6 +377,7 @@ async fn test_kafka_with_metadata_fields() {
             bad_data: None,
             schema_resolver: None,
             client_configs: HashMap::new(),
+            context: Context::new(None),
             messages_per_second: NoneZeroU32::new(100).unwrap(),
             metadata_fields,
         };
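Not part of the patch — a minimal sketch of how the pieces introduced above are meant to fit together: `client_configs` seeds the `sasl.mechanism=OAUTHBEARER` / `security.protocol=SASL_SSL` properties for an `AwsMskIam` profile, and the custom `Context` supplies tokens through `generate_oauth_token`. It assumes it lives alongside the items in `crates/arroyo-connectors/src/kafka/mod.rs` (so `KafkaConfig`, `client_configs`, `Context`, and the `BaseConsumer` alias are in scope); the broker address is passed in as a plain placeholder string.

```rust
// Illustrative sketch only (not part of the diff): building an MSK IAM-authenticated
// consumer with the pieces added in kafka/mod.rs. `msk_iam_consumer` and the `servers`
// parameter are hypothetical names introduced for this example.
use std::time::Duration;

use rdkafka::ClientConfig;

pub fn msk_iam_consumer(profile: &KafkaConfig, servers: &str) -> anyhow::Result<BaseConsumer> {
    let mut config = ClientConfig::new();
    config.set("bootstrap.servers", servers);

    // For AwsMskIam profiles this injects sasl.mechanism=OAUTHBEARER and
    // security.protocol=SASL_SSL; for Sasl profiles it substitutes the env vars.
    for (k, v) in client_configs(profile, None)? {
        config.set(k, v);
    }

    // The Context carries the connection profile; its generate_oauth_token()
    // implementation signs a token with aws-msk-iam-sasl-signer on demand.
    let consumer: BaseConsumer =
        config.create_with_context(Context::new(Some(profile.clone())))?;

    // As in the patch, one empty poll kicks off the initial OAUTHBEARER token
    // refresh before any metadata or fetch call is made.
    let _ = consumer.poll(Duration::from_secs(0));
    Ok(consumer)
}
```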