From 56f8d38827cf0e58590779d1e03d093b1fc6059f Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Wed, 31 Jan 2024 17:23:07 -0600 Subject: [PATCH] feat: add openai_embed function (#2538) todo: - [ ] integration tests & slts --- Cargo.lock | 152 ++++++++++ .../src/planner/expr/function.rs | 6 +- crates/datafusion_ext/src/planner/mod.rs | 2 +- crates/datafusion_ext/src/vars.rs | 55 ++++ crates/protogen/proto/metastore/options.proto | 7 + .../protogen/src/metastore/types/options.rs | 63 +++- crates/sqlbuiltins/Cargo.toml | 4 + crates/sqlbuiltins/src/functions/mod.rs | 11 +- .../src/functions/scalars/hashing.rs | 24 +- .../sqlbuiltins/src/functions/scalars/kdl.rs | 17 +- .../sqlbuiltins/src/functions/scalars/mod.rs | 17 +- .../src/functions/scalars/openai.rs | 287 ++++++++++++++++++ .../src/functions/scalars/postgres.rs | 87 ++++-- crates/sqlexec/src/planner/context_builder.rs | 17 +- crates/sqlexec/src/planner/session_planner.rs | 15 +- justfile | 2 +- 16 files changed, 707 insertions(+), 59 deletions(-) create mode 100644 crates/sqlbuiltins/src/functions/scalars/openai.rs diff --git a/Cargo.lock b/Cargo.lock index d330c4660..40c738631 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -521,6 +521,15 @@ dependencies = [ "zstd-safe 7.0.0", ] +[[package]] +name = "async-convert" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d416feee97712e43152cd42874de162b8f9b77295b1c85e5d92725cc8310bae" +dependencies = [ + "async-trait", +] + [[package]] name = "async-lock" version = "2.8.0" @@ -530,6 +539,31 @@ dependencies = [ "event-listener 2.5.3", ] +[[package]] +name = "async-openai" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b34bfaa81ed6e96ae205528fa7dfb59f5c25f55b6a3e5cdbcb223df168774572" +dependencies = [ + "async-convert", + "backoff", + "base64 0.21.7", + "bytes", + "derive_builder", + "futures", + "rand", + "reqwest", + "reqwest-eventsource", + "secrecy", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-stream", + "tokio-util", + "tracing", +] + [[package]] name = "async-recursion" version = "1.0.5" @@ -2033,6 +2067,16 @@ dependencies = [ "darling_macro 0.13.4", ] +[[package]] +name = "darling" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850" +dependencies = [ + "darling_core 0.14.4", + "darling_macro 0.14.4", +] + [[package]] name = "darling" version = "0.20.3" @@ -2057,6 +2101,20 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "darling_core" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 1.0.109", +] + [[package]] name = "darling_core" version = "0.20.3" @@ -2082,6 +2140,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "darling_macro" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" +dependencies = [ + "darling_core 0.14.4", + "quote", + "syn 1.0.109", +] + [[package]] name = "darling_macro" version = "0.20.3" @@ -2611,6 +2680,37 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "derive_builder" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f" +dependencies = [ + "darling 0.14.4", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "derive_builder_macro" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e" +dependencies = [ + "derive_builder_core", + "syn 1.0.109", +] + [[package]] name = "derive_more" version = "0.99.17" @@ -2990,6 +3090,17 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "eventsource-stream" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" +dependencies = [ + "futures-core", + "nom", + "pin-project-lite", +] + [[package]] name = "fallible-iterator" version = "0.2.0" @@ -4594,6 +4705,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -6334,6 +6455,7 @@ dependencies = [ "js-sys", "log", "mime", + "mime_guess", "once_cell", "percent-encoding", "pin-project-lite", @@ -6357,6 +6479,22 @@ dependencies = [ "winreg", ] +[[package]] +name = "reqwest-eventsource" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f03f570355882dd8d15acc3a313841e6e90eddbc76a93c748fd82cc13ba9f51" +dependencies = [ + "eventsource-stream", + "futures-core", + "futures-timer", + "mime", + "nom", + "pin-project-lite", + "reqwest", + "thiserror", +] + [[package]] name = "resolv-conf" version = "0.7.0" @@ -6977,6 +7115,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "secrecy" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9bd1c54ea06cfd2f6b63219704de0b9b4f72dcc2b8fdef820be6cd799780e91e" +dependencies = [ + "serde", + "zeroize", +] + [[package]] name = "security-framework" version = "2.9.2" @@ -7463,7 +7611,9 @@ dependencies = [ name = "sqlbuiltins" version = "0.8.3" dependencies = [ + "async-openai", "async-trait", + "catalog", "datafusion", "datafusion_ext", "datasources", @@ -7479,10 +7629,12 @@ dependencies = [ "once_cell", "pgrepr", "protogen", + "reqwest", "siphasher 1.0.0", "strum 0.25.0", "telemetry", "thiserror", + "tokio", "tracing", "uuid", ] diff --git a/crates/datafusion_ext/src/planner/expr/function.rs b/crates/datafusion_ext/src/planner/expr/function.rs index 8b731c63a..5ca07d7d9 100644 --- a/crates/datafusion_ext/src/planner/expr/function.rs +++ b/crates/datafusion_ext/src/planner/expr/function.rs @@ -86,7 +86,11 @@ impl<'a, S: AsyncContextProvider> SqlQueryPlanner<'a, S> { .await?; // user-defined function (UDF) should have precedence in case it has the same name as a scalar built-in function - if let Some(expr) = self.context_provider.get_function_meta(&name, &args).await { + if let Some(expr) = self + .context_provider + .get_function_meta(&name, &args) + .await? + { return Ok(expr); } diff --git a/crates/datafusion_ext/src/planner/mod.rs b/crates/datafusion_ext/src/planner/mod.rs index e50005e59..a90ce596c 100644 --- a/crates/datafusion_ext/src/planner/mod.rs +++ b/crates/datafusion_ext/src/planner/mod.rs @@ -83,7 +83,7 @@ pub trait AsyncContextProvider: Send + Sync { /// NOTE: This is a modified version of `get_function_meta` that takes /// arguments and reutrns an `Expr` instead of `ScalarUDF`. This is so that /// we can return any kind of Expr from scalar UDFs. - async fn get_function_meta(&mut self, name: &str, args: &[Expr]) -> Option; + async fn get_function_meta(&mut self, name: &str, args: &[Expr]) -> Result>; /// Getter for a UDAF description async fn get_aggregate_meta(&mut self, name: &str) -> Option>; /// Getter for a UDWF diff --git a/crates/datafusion_ext/src/vars.rs b/crates/datafusion_ext/src/vars.rs index 8297bb55f..ea4fe3700 100644 --- a/crates/datafusion_ext/src/vars.rs +++ b/crates/datafusion_ext/src/vars.rs @@ -9,6 +9,7 @@ use std::fmt::Display; use std::str::FromStr; use std::sync::Arc; +use catalog::session_catalog::SessionCatalog; use constants::IMPLICIT_SCHEMAS; use datafusion::arrow::array::{ListBuilder, StringBuilder}; use datafusion::arrow::datatypes::{DataType, Field}; @@ -20,6 +21,7 @@ pub use inner::{Dialect, SessionVarsInner}; use once_cell::sync::Lazy; use parking_lot::{RwLock, RwLockReadGuard}; use pgrepr::notice::NoticeSeverity; +use protogen::metastore::types::options::{CredentialsOptions, CredentialsOptionsOpenAI}; use utils::split_comma_delimited; use uuid::Uuid; @@ -288,3 +290,56 @@ impl VarProvider for SessionVars { } } } + +#[derive(Debug, Clone)] +pub struct CredentialsVarProvider<'a> { + pub catalog: &'a SessionCatalog, +} + +impl<'a> CredentialsVarProvider<'a> { + const CREDS_PREFIX: &'static str = "@creds"; + const CREDS_OPENAI_PREFIX: &'static str = "openai"; + + pub fn new(catalog: &'a SessionCatalog) -> Self { + Self { catalog } + } +} + +// Currently only supports OpenAI credentials +// We can add more providers in the future if needed +impl VarProvider for CredentialsVarProvider<'_> { + fn get_value(&self, var_names: Vec) -> datafusion::error::Result { + let var_names: Vec<&str> = var_names.iter().map(|s| s.as_str()).collect(); + match var_names.as_slice() { + [Self::CREDS_PREFIX, Self::CREDS_OPENAI_PREFIX, value] => { + let openai_cred = self.catalog.resolve_credentials(value).ok_or_else(|| { + datafusion::error::DataFusionError::Internal( + "No openai credentials found".to_string(), + ) + })?; + if let CredentialsOptions::OpenAI(opts) = openai_cred.options.clone() { + Ok(opts.into()) + } else { + Err(datafusion::error::DataFusionError::Internal( + "Something went wrong. Expected openai credential, found other".to_string(), + )) + } + } + _ => Err(datafusion::error::DataFusionError::Internal( + "unsupported variable".to_string(), + )), + } + } + + fn get_type(&self, var_names: &[String]) -> Option { + let first = var_names.first().map(|s| s.as_str()); + let second = var_names.get(1).map(|s| s.as_str()); + + match (first, second) { + (Some(Self::CREDS_PREFIX), Some(Self::CREDS_OPENAI_PREFIX)) => { + Some(CredentialsOptionsOpenAI::data_type()) + } + _ => None, + } + } +} diff --git a/crates/protogen/proto/metastore/options.proto b/crates/protogen/proto/metastore/options.proto index b2222842b..e69d6cb17 100644 --- a/crates/protogen/proto/metastore/options.proto +++ b/crates/protogen/proto/metastore/options.proto @@ -273,6 +273,7 @@ message CredentialsOptions { CredentialsOptionsGcp gcp = 2; CredentialsOptionsAws aws = 3; CredentialsOptionsAzure azure = 4; + CredentialsOptionsOpenAI openai = 5; } } @@ -295,3 +296,9 @@ message CredentialsOptionsAzure { // TODO: We may want to allow the user to give us just the "connection string" // which contains the account and access key. } + +message CredentialsOptionsOpenAI { + string api_key = 1; + optional string api_base = 2; + optional string org_id = 3; +} \ No newline at end of file diff --git a/crates/protogen/src/metastore/types/options.rs b/crates/protogen/src/metastore/types/options.rs index 47503426f..0a7bc46ed 100644 --- a/crates/protogen/src/metastore/types/options.rs +++ b/crates/protogen/src/metastore/types/options.rs @@ -1,7 +1,7 @@ use std::collections::BTreeMap; use std::fmt; -use datafusion::arrow::datatypes::{DataType, Field, SchemaRef}; +use datafusion::arrow::datatypes::{DataType, Field, Fields, SchemaRef}; use datafusion::common::DFSchemaRef; use proptest_derive::Arbitrary; @@ -1331,6 +1331,7 @@ pub enum CredentialsOptions { Gcp(CredentialsOptionsGcp), Aws(CredentialsOptionsAws), Azure(CredentialsOptionsAzure), + OpenAI(CredentialsOptionsOpenAI), } impl CredentialsOptions { @@ -1338,6 +1339,7 @@ impl CredentialsOptions { pub const GCP: &'static str = "gcp"; pub const AWS: &'static str = "aws"; pub const AZURE: &'static str = "azure"; + pub const OPENAI: &'static str = "openai"; pub fn as_str(&self) -> &'static str { match self { @@ -1345,6 +1347,7 @@ impl CredentialsOptions { Self::Gcp(_) => Self::GCP, Self::Aws(_) => Self::AWS, Self::Azure(_) => Self::AZURE, + Self::OpenAI(_) => Self::OPENAI, } } } @@ -1363,6 +1366,7 @@ impl TryFrom for CredentialsOptions { options::credentials_options::Options::Gcp(v) => Self::Gcp(v.try_into()?), options::credentials_options::Options::Aws(v) => Self::Aws(v.try_into()?), options::credentials_options::Options::Azure(v) => Self::Azure(v.try_into()?), + options::credentials_options::Options::Openai(v) => Self::OpenAI(v.try_into()?), }) } } @@ -1381,6 +1385,9 @@ impl From for options::credentials_options::Options { CredentialsOptions::Gcp(v) => options::credentials_options::Options::Gcp(v.into()), CredentialsOptions::Aws(v) => options::credentials_options::Options::Aws(v.into()), CredentialsOptions::Azure(v) => options::credentials_options::Options::Azure(v.into()), + CredentialsOptions::OpenAI(v) => { + options::credentials_options::Options::Openai(v.into()) + } } } } @@ -1487,6 +1494,60 @@ impl From for options::CredentialsOptionsAzure { } } +#[derive(Debug, Clone, Arbitrary, PartialEq, Eq, Hash)] +pub struct CredentialsOptionsOpenAI { + pub api_key: String, + pub api_base: Option, + pub org_id: Option, +} + +impl CredentialsOptionsOpenAI { + pub fn fields() -> Fields { + vec![ + Field::new("api_key", DataType::Utf8, false), + Field::new("api_base", DataType::Utf8, true), + Field::new("org_id", DataType::Utf8, true), + ] + .into() + } + pub fn data_type() -> DataType { + DataType::Struct(Self::fields()) + } +} +impl From for datafusion::scalar::ScalarValue { + fn from(value: CredentialsOptionsOpenAI) -> Self { + datafusion::scalar::ScalarValue::Struct( + Some(vec![ + datafusion::scalar::ScalarValue::Utf8(Some(value.api_key)), + datafusion::scalar::ScalarValue::Utf8(value.api_base), + datafusion::scalar::ScalarValue::Utf8(value.org_id), + ]), + CredentialsOptionsOpenAI::fields(), + ) + } +} + +impl TryFrom for CredentialsOptionsOpenAI { + type Error = ProtoConvError; + fn try_from(value: options::CredentialsOptionsOpenAi) -> Result { + Ok(CredentialsOptionsOpenAI { + api_key: value.api_key, + api_base: value.api_base, + org_id: value.org_id, + }) + } +} + +impl From for options::CredentialsOptionsOpenAi { + fn from(value: CredentialsOptionsOpenAI) -> Self { + options::CredentialsOptionsOpenAi { + api_key: value.api_key, + api_base: value.api_base, + org_id: value.org_id, + } + } +} + #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub enum CopyToDestinationOptions { Local(CopyToDestinationOptionsLocal), diff --git a/crates/sqlbuiltins/Cargo.toml b/crates/sqlbuiltins/Cargo.toml index 9f77a4c18..17fe5dde7 100644 --- a/crates/sqlbuiltins/Cargo.toml +++ b/crates/sqlbuiltins/Cargo.toml @@ -12,6 +12,7 @@ logutil = { path = "../logutil" } pgrepr = { path = "../pgrepr" } protogen = { path = "../protogen" } datafusion_ext = { path = "../datafusion_ext" } +catalog = { path = "../catalog" } telemetry = { path = "../telemetry" } datasources = { path = "../datasources" } decimal = { path = "../decimal" } @@ -29,3 +30,6 @@ kdl = "5.0.0-alpha.1" siphasher = "1.0.0" fnv = "1.0.7" memoize = { version = "0.4.2", features = ["full"] } +async-openai = "0.18.2" +tokio.workspace = true +reqwest.workspace = true diff --git a/crates/sqlbuiltins/src/functions/mod.rs b/crates/sqlbuiltins/src/functions/mod.rs index 432f3ffbe..8a5f0076a 100644 --- a/crates/sqlbuiltins/src/functions/mod.rs +++ b/crates/sqlbuiltins/src/functions/mod.rs @@ -33,6 +33,7 @@ use scalars::{ConnectionId, Version}; use table::{BuiltinTableFuncs, TableFunc}; use self::alias_map::AliasMap; +use crate::functions::scalars::openai::OpenAIEmbed; /// All builtin functions available for all sessions. pub static FUNCTION_REGISTRY: Lazy = Lazy::new(FunctionRegistry::new); @@ -115,7 +116,13 @@ pub enum FunctionNamespace { /// likely be used instead. pub trait BuiltinScalarUDF: BuiltinFunction { /// Builds an expression for the function using the provided arguments. - fn as_expr(&self, args: Vec) -> Expr; + /// Some functions may require additional information from the catalog to build the expression. + /// Examples of such functions are ones that require credentials to access external services such as `openai_embed`. + fn try_as_expr( + &self, + catalog: &catalog::session_catalog::SessionCatalog, + args: Vec, + ) -> datafusion::error::Result; /// The namespace of the function. /// Defaults to global (None) @@ -208,6 +215,8 @@ impl FunctionRegistry { Arc::new(SipHash), Arc::new(FnvHash), Arc::new(PartitionResults), + // OpenAI + Arc::new(OpenAIEmbed), ]; let udfs = udfs .into_iter() diff --git a/crates/sqlbuiltins/src/functions/scalars/hashing.rs b/crates/sqlbuiltins/src/functions/scalars/hashing.rs index 04fd2475c..64721045f 100644 --- a/crates/sqlbuiltins/src/functions/scalars/hashing.rs +++ b/crates/sqlbuiltins/src/functions/scalars/hashing.rs @@ -1,8 +1,9 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; +use catalog::session_catalog::SessionCatalog; use datafusion::arrow::datatypes::DataType; -use datafusion::error::DataFusionError; +use datafusion::error::{DataFusionError, Result as DataFusionResult}; use datafusion::logical_expr::expr::ScalarFunction; use datafusion::logical_expr::{ ReturnTypeFunction, @@ -40,7 +41,7 @@ impl ConstBuiltinFunction for SipHash { } } impl BuiltinScalarUDF for SipHash { - fn as_expr(&self, args: Vec) -> Expr { + fn try_as_expr(&self, _: &SessionCatalog, args: Vec) -> DataFusionResult { let return_type_fn: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::UInt64))); let scalar_fn_impl: ScalarFunctionImplementation = Arc::new(move |input| { Ok(get_nth_scalar_value(input, 0, &|value| -> Result< @@ -58,7 +59,10 @@ impl BuiltinScalarUDF for SipHash { &return_type_fn, &scalar_fn_impl, ); - Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf), args)) + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::new(udf), + args, + ))) } } @@ -81,7 +85,7 @@ impl ConstBuiltinFunction for FnvHash { } impl BuiltinScalarUDF for FnvHash { - fn as_expr(&self, args: Vec) -> Expr { + fn try_as_expr(&self, _: &SessionCatalog, args: Vec) -> DataFusionResult { let return_type_fn: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::UInt64))); let scalar_fn_impl: ScalarFunctionImplementation = Arc::new(move |input| { Ok(get_nth_scalar_value(input, 0, &|value| -> Result< @@ -99,7 +103,10 @@ impl BuiltinScalarUDF for FnvHash { &return_type_fn, &scalar_fn_impl, ); - Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf), args)) + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::new(udf), + args, + ))) } } @@ -122,7 +129,7 @@ impl ConstBuiltinFunction for PartitionResults { } impl BuiltinScalarUDF for PartitionResults { - fn as_expr(&self, args: Vec) -> Expr { + fn try_as_expr(&self, _: &SessionCatalog, args: Vec) -> DataFusionResult { let return_type_fn: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Boolean))); let scalar_fn_impl: ScalarFunctionImplementation = Arc::new(move |input| { if input.len() != 3 { @@ -164,6 +171,9 @@ impl BuiltinScalarUDF for PartitionResults { &return_type_fn, &scalar_fn_impl, ); - Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf), args)) + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::new(udf), + args, + ))) } } diff --git a/crates/sqlbuiltins/src/functions/scalars/kdl.rs b/crates/sqlbuiltins/src/functions/scalars/kdl.rs index 85295c83e..046fed5d1 100644 --- a/crates/sqlbuiltins/src/functions/scalars/kdl.rs +++ b/crates/sqlbuiltins/src/functions/scalars/kdl.rs @@ -1,8 +1,9 @@ use std::sync::Arc; use ::kdl::{KdlNode, KdlQuery}; +use catalog::session_catalog::SessionCatalog; use datafusion::arrow::datatypes::DataType; -use datafusion::error::DataFusionError; +use datafusion::error::{DataFusionError, Result as DataFusionResult}; use datafusion::logical_expr::expr::ScalarFunction; use datafusion::logical_expr::{ ReturnTypeFunction, @@ -44,7 +45,7 @@ impl ConstBuiltinFunction for KDLSelect { } impl BuiltinScalarUDF for KDLSelect { - fn as_expr(&self, args: Vec) -> Expr { + fn try_as_expr(&self, _: &SessionCatalog, args: Vec) -> DataFusionResult { let return_type_fn: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Utf8))); let scalar_fn_impl: ScalarFunctionImplementation = Arc::new(move |input| { let filter = get_nth_string_fn_arg(input, 1)?; @@ -82,7 +83,10 @@ impl BuiltinScalarUDF for KDLSelect { &return_type_fn, &scalar_fn_impl, ); - Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf), args)) + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::new(udf), + args, + ))) } } @@ -110,7 +114,7 @@ impl ConstBuiltinFunction for KDLMatches { } impl BuiltinScalarUDF for KDLMatches { - fn as_expr(&self, args: Vec) -> Expr { + fn try_as_expr(&self, _: &SessionCatalog, args: Vec) -> DataFusionResult { let return_type_fn: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Boolean))); let scalar_fn_impl: ScalarFunctionImplementation = Arc::new(move |input| { let filter = get_nth_string_fn_arg(input, 1)?; @@ -136,7 +140,10 @@ impl BuiltinScalarUDF for KDLMatches { &return_type_fn, &scalar_fn_impl, ); - Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf), args)) + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::new(udf), + args, + ))) } } diff --git a/crates/sqlbuiltins/src/functions/scalars/mod.rs b/crates/sqlbuiltins/src/functions/scalars/mod.rs index 6f2cddddb..bb4023ecc 100644 --- a/crates/sqlbuiltins/src/functions/scalars/mod.rs +++ b/crates/sqlbuiltins/src/functions/scalars/mod.rs @@ -1,6 +1,7 @@ pub mod df_scalars; pub mod hashing; pub mod kdl; +pub mod openai; pub mod postgres; use std::sync::Arc; @@ -30,8 +31,12 @@ impl ConstBuiltinFunction for ConnectionId { } impl BuiltinScalarUDF for ConnectionId { - fn as_expr(&self, _: Vec) -> Expr { - session_var("connection_id") + fn try_as_expr( + &self, + _: &catalog::session_catalog::SessionCatalog, + _: Vec, + ) -> datafusion::error::Result { + Ok(session_var("connection_id")) } } @@ -48,8 +53,12 @@ impl ConstBuiltinFunction for Version { } impl BuiltinScalarUDF for Version { - fn as_expr(&self, _: Vec) -> Expr { - session_var("version") + fn try_as_expr( + &self, + _: &catalog::session_catalog::SessionCatalog, + _: Vec, + ) -> datafusion::error::Result { + Ok(session_var("version")) } } diff --git a/crates/sqlbuiltins/src/functions/scalars/openai.rs b/crates/sqlbuiltins/src/functions/scalars/openai.rs new file mode 100644 index 000000000..3ead517a8 --- /dev/null +++ b/crates/sqlbuiltins/src/functions/scalars/openai.rs @@ -0,0 +1,287 @@ +use std::str::FromStr; +use std::sync::Arc; + +use async_openai::config::OpenAIConfig; +use async_openai::types::{CreateEmbeddingRequest, Embedding, EmbeddingInput, EncodingFormat}; +use async_openai::Client; +use datafusion::arrow::array::{ + ArrayRef, + AsArray, + FixedSizeListArray, + FixedSizeListBuilder, + Float32Builder, +}; +use datafusion::arrow::datatypes::{DataType, Field}; +use datafusion::error::{DataFusionError, Result}; +use datafusion::logical_expr::expr::ScalarFunction; +use datafusion::logical_expr::{ + Expr, + ReturnTypeFunction, + ScalarFunctionImplementation, + ScalarUDF, + Signature, + TypeSignature, + Volatility, +}; +use datafusion::physical_plan::ColumnarValue; +use datafusion::scalar::ScalarValue; +use datafusion::variable::VarProvider; +use datafusion_ext::vars::CredentialsVarProvider; +use once_cell::sync::Lazy; +use protogen::metastore::types::catalog::FunctionType; +use tokio::runtime::Handle; +use tokio::task; + +use crate::functions::{BuiltinScalarUDF, ConstBuiltinFunction}; +static DEFAULT_CREDENTIAL_LOCATION: Lazy<&[&str]> = + Lazy::new(|| &["@creds", "openai", "openai_default_credential", "api_key"]); +/// This is a placeholder for empty values in the input array +/// The openai API does not accept empty strings, so we use this to represent NULL/"" values +const EMPTY_PLACEHOLDER: &str = "NULL"; + +pub struct OpenAIEmbed; + +impl ConstBuiltinFunction for OpenAIEmbed { + const NAME: &'static str = "openai_embed"; + const DESCRIPTION: &'static str = "Embeds text using OpenAI's API. + WARNING: This function makes an external API call and may be slow. It is recommended to use it with small datasets. + Available models: 'text-embedding-3-small', 'text-embedding-ada-002', 'text-embedding-3-large' default: 'text-embedding-3-small' + Note: This function requires an API key. You can pass it as the first argument or set it as a stored credential. + If no API key is provided, the function will attempt to use the stored credential 'openai_default_credential'"; + const EXAMPLE: &'static str = + "openai_embed(@creds.openai.my_openai, 'text-embedding-3-small', 'Hello, world!')"; + const FUNCTION_TYPE: FunctionType = FunctionType::Scalar; + fn signature(&self) -> Option { + Some(Signature::new( + // args: , , + TypeSignature::OneOf(vec![ + // openai_embed('') -- uses default model and credential + TypeSignature::Exact(vec![DataType::Utf8]), + // openai_embed('', '') -- uses default credential + TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]), + // openai_embed('api_key', '', '') --uses provided api key and model + TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Utf8]), + ]), + // This is a volatile function because it makes an external API call. + // Additionally, the openai API key may change or expire at any time. + Volatility::Volatile, + )) + } +} +#[derive(Debug)] +enum EmbeddingModel { + TextEmbedding3Small, + TextEmbeddingAda002, + TextEmbedding3Large, +} +impl EmbeddingModel { + fn len(&self) -> i32 { + match self { + EmbeddingModel::TextEmbedding3Small => 1536, + EmbeddingModel::TextEmbeddingAda002 => 1536, + EmbeddingModel::TextEmbedding3Large => 3072, + } + } +} +impl FromStr for EmbeddingModel { + type Err = datafusion::error::DataFusionError; + + fn from_str(s: &str) -> Result { + match s { + "text-embedding-3-small" => Ok(EmbeddingModel::TextEmbedding3Small), + "text-embedding-ada-002" => Ok(EmbeddingModel::TextEmbeddingAda002), + "text-embedding-3-large" => Ok(EmbeddingModel::TextEmbedding3Large), + _ => Err(DataFusionError::Plan("Invalid argument. Available models are: 'text-embedding-3-small', 'text-embedding-ada-002', 'text-embedding-3-large' ".to_string())), + } + } +} + +impl ToString for EmbeddingModel { + fn to_string(&self) -> String { + match self { + EmbeddingModel::TextEmbedding3Small => "text-embedding-3-small".to_string(), + EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002".to_string(), + EmbeddingModel::TextEmbedding3Large => "text-embedding-3-large".to_string(), + } + } +} + +fn model_from_arg(arg: &Expr) -> datafusion::error::Result { + match arg { + Expr::Literal(ScalarValue::Utf8(Some(v))) => v.parse(), + other => Err(DataFusionError::Plan(format!( + "Invalid argument, expected a string, instead received: '{other}'", + ))), + } +} + +impl BuiltinScalarUDF for OpenAIEmbed { + fn namespace(&self) -> crate::functions::FunctionNamespace { + crate::functions::FunctionNamespace::Optional("openai") + } + fn try_as_expr( + &self, + catalog: &catalog::session_catalog::SessionCatalog, + mut args: Vec, + ) -> datafusion::error::Result { + let creds_from_arg = |values: Vec| -> Option { + let prov = CredentialsVarProvider::new(catalog); + let scalar = prov.get_value(values); + + match scalar.ok()? { + ScalarValue::Utf8(v) => Some(OpenAIConfig::new().with_api_key(v.unwrap())), + ScalarValue::Struct(Some(values), _) => { + let api_key = values.first()?; + let api_base = values.get(1); + let org_id = values.get(2); + let api_key = match api_key { + ScalarValue::Utf8(v) => v.clone().unwrap(), + _ => return None, + }; + let mut config = OpenAIConfig::new().with_api_key(api_key); + config = match api_base { + Some(ScalarValue::Utf8(Some(v))) => config.with_api_base(v.clone()), + _ => config, + }; + config = match org_id { + Some(ScalarValue::Utf8(Some(v))) => config.with_org_id(v.clone()), + _ => config, + }; + Some(config) + } + _ => None, + } + }; + + let (creds, model, idx) = match args.len() { + // openai_embed() + 1 => { + let model = EmbeddingModel::TextEmbedding3Small; + let scalar = creds_from_arg( + DEFAULT_CREDENTIAL_LOCATION + .iter() + .map(|s| s.to_string()) + .collect(), + ); + (scalar, model, 0) + } + // openai_embed(, ) + 2 => { + let model = model_from_arg(&args[0])?; + let scalar = creds_from_arg( + DEFAULT_CREDENTIAL_LOCATION + .iter() + .map(|s| s.to_string()) + .collect(), + ); + (scalar, model, 1) + } + // openai_embed('api_key', '', '') + 3 => { + let creds = match args.first() { + Some(Expr::Literal(ScalarValue::Utf8(v))) => { + Some(OpenAIConfig::new().with_api_key(v.clone().unwrap())) + } + Some(Expr::ScalarVariable(_, values)) => creds_from_arg(values.clone()), + _ => return Err(DataFusionError::Plan("Invalid argument".to_string())), + }; + + let model = model_from_arg(&args[1])?; + (creds, model, 2) + } + _ => return Err(DataFusionError::Plan("Invalid argument count".to_string())), + }; + let Some(creds) = creds else { + return Err(DataFusionError::Plan( + "No API key or credential provided".to_string(), + )); + }; + + let model_len = model.len(); + let return_type_fn: ReturnTypeFunction = Arc::new(move |_| { + let f = Field::new("item", DataType::Float32, true); + let dtype = DataType::FixedSizeList(Arc::new(f), model_len); + Ok(Arc::new(dtype)) + }); + + let scalar_fn_impl: ScalarFunctionImplementation = Arc::new(move |args| { + let input_chunks = match &args[0] { + ColumnarValue::Array(arr) => match arr.data_type() { + DataType::Utf8 => arr + .as_string::() + .into_iter() + .map(|s| s.unwrap_or(EMPTY_PLACEHOLDER).to_string()) + .collect::>() + .chunks(2000) + .map(|chunk| EmbeddingInput::StringArray(chunk.to_vec())) + .collect::>(), + DataType::LargeUtf8 => arr + .as_string::() + .into_iter() + .map(|s| s.unwrap_or(EMPTY_PLACEHOLDER).to_string()) + .collect::>() + .chunks(2000) + .map(|chunk| EmbeddingInput::StringArray(chunk.to_vec())) + .collect::>(), + _ => return Err(DataFusionError::Plan("Invalid argument".to_string())), + }, + ColumnarValue::Scalar(ScalarValue::Utf8(v) | ScalarValue::LargeUtf8(v)) => { + vec![EmbeddingInput::StringArray(vec![v + .clone() + .unwrap_or_default()])] + } + _ => return Err(DataFusionError::Plan("Invalid argument".to_string())), + }; + + let reqwest_client = reqwest::ClientBuilder::new() + // Set a hard timeout of 10 seconds + .timeout(std::time::Duration::from_secs(10)) + .build() + .unwrap(); + + let client = Client::with_config(creds.clone()).with_http_client(reqwest_client); + let embed = client.embeddings(); + // We chunk the input into 2000 items per request to avoid hitting token limits + let reqs = input_chunks + .into_iter() + .map(|input| CreateEmbeddingRequest { + model: model.to_string(), + input, + encoding_format: Some(EncodingFormat::Float), + user: None, + dimensions: None, + }); + + // no way around blocking here. Expressions are not async + let res: datafusion::error::Result = + task::block_in_place(move || { + Handle::current().block_on(async move { + let values_builder = Float32Builder::new(); + let mut builder = FixedSizeListBuilder::new(values_builder, model_len); + + // We need to do them sequentially to maintain the order + // This should also help us abort early if there's an error + for req in reqs { + let res = embed.create(req).await; + let res = res.map_err(|e| DataFusionError::Execution(e.to_string()))?; + for Embedding { embedding, .. } in res.data.iter() { + builder.values().append_slice(embedding); + builder.append(true); + } + } + Ok(builder.finish()) + }) + }); + + let a: ArrayRef = Arc::new(res?); + Ok(ColumnarValue::Array(a)) + }); + let signature = Signature::exact(vec![DataType::Utf8], Volatility::Volatile); + let udf = ScalarUDF::new(Self::NAME, &signature, &return_type_fn, &scalar_fn_impl); + + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::new(udf), + vec![args.remove(idx)], + ))) + } +} diff --git a/crates/sqlbuiltins/src/functions/scalars/postgres.rs b/crates/sqlbuiltins/src/functions/scalars/postgres.rs index 2a3845df0..692ecc4ba 100644 --- a/crates/sqlbuiltins/src/functions/scalars/postgres.rs +++ b/crates/sqlbuiltins/src/functions/scalars/postgres.rs @@ -1,6 +1,8 @@ use std::sync::Arc; +use catalog::session_catalog::SessionCatalog; use datafusion::arrow::datatypes::{DataType, Field}; +use datafusion::error::Result as DataFusionResult; use datafusion::logical_expr::expr::ScalarFunction; use datafusion::logical_expr::{ BuiltinScalarFunction, @@ -41,7 +43,7 @@ impl ConstBuiltinFunction for PgGetUserById { } impl BuiltinScalarUDF for PgGetUserById { - fn as_expr(&self, args: Vec) -> Expr { + fn try_as_expr(&self, _: &SessionCatalog, args: Vec) -> DataFusionResult { let return_type_fn: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Utf8))); let scalar_fn_impl: ScalarFunctionImplementation = Arc::new(move |_| { Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( @@ -54,7 +56,11 @@ impl BuiltinScalarUDF for PgGetUserById { &return_type_fn, &scalar_fn_impl, ); - Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf), args)) + + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::new(udf), + args, + ))) } fn namespace(&self) -> FunctionNamespace { @@ -79,7 +85,7 @@ impl ConstBuiltinFunction for PgTableIsVisible { } impl BuiltinScalarUDF for PgTableIsVisible { - fn as_expr(&self, args: Vec) -> Expr { + fn try_as_expr(&self, _: &SessionCatalog, args: Vec) -> DataFusionResult { let return_type_fn: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Boolean))); let scalar_fn_impl: ScalarFunctionImplementation = Arc::new(move |input| { Ok(get_nth_scalar_value(input, 0, &|value| -> Result< @@ -100,7 +106,10 @@ impl BuiltinScalarUDF for PgTableIsVisible { &scalar_fn_impl, ); - Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf), args)) + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::new(udf), + args, + ))) } fn namespace(&self) -> FunctionNamespace { @@ -125,7 +134,7 @@ impl ConstBuiltinFunction for PgEncodingToChar { } impl BuiltinScalarUDF for PgEncodingToChar { - fn as_expr(&self, args: Vec) -> Expr { + fn try_as_expr(&self, _: &SessionCatalog, args: Vec) -> DataFusionResult { let return_type_fn: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Utf8))); let scalar_fn_impl: ScalarFunctionImplementation = Arc::new(move |input| { Ok(get_nth_scalar_value(input, 0, &|value| -> Result< @@ -145,7 +154,10 @@ impl BuiltinScalarUDF for PgEncodingToChar { &return_type_fn, &scalar_fn_impl, ); - Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf), args)) + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::new(udf), + args, + ))) } fn namespace(&self) -> FunctionNamespace { @@ -173,7 +185,7 @@ impl ConstBuiltinFunction for HasSchemaPrivilege { } impl BuiltinScalarUDF for HasSchemaPrivilege { - fn as_expr(&self, args: Vec) -> Expr { + fn try_as_expr(&self, _: &SessionCatalog, args: Vec) -> DataFusionResult { let return_type_fn: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Boolean))); let scalar_fn_impl: ScalarFunctionImplementation = Arc::new(move |_input| Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))))); @@ -183,7 +195,10 @@ impl BuiltinScalarUDF for HasSchemaPrivilege { &return_type_fn, &scalar_fn_impl, ); - Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf), args)) + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::new(udf), + args, + ))) } fn namespace(&self) -> FunctionNamespace { @@ -211,7 +226,7 @@ impl ConstBuiltinFunction for HasDatabasePrivilege { } impl BuiltinScalarUDF for HasDatabasePrivilege { - fn as_expr(&self, args: Vec) -> Expr { + fn try_as_expr(&self, _: &SessionCatalog, args: Vec) -> DataFusionResult { let return_type_fn: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Boolean))); let scalar_fn_impl: ScalarFunctionImplementation = Arc::new(move |_input| Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))))); @@ -221,7 +236,10 @@ impl BuiltinScalarUDF for HasDatabasePrivilege { &return_type_fn, &scalar_fn_impl, ); - Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf), args)) + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::new(udf), + args, + ))) } fn namespace(&self) -> FunctionNamespace { @@ -248,7 +266,7 @@ impl ConstBuiltinFunction for HasTablePrivilege { } impl BuiltinScalarUDF for HasTablePrivilege { - fn as_expr(&self, args: Vec) -> Expr { + fn try_as_expr(&self, _: &SessionCatalog, args: Vec) -> DataFusionResult { let return_type_fn: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Boolean))); let scalar_fn_impl: ScalarFunctionImplementation = Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))))); @@ -258,7 +276,10 @@ impl BuiltinScalarUDF for HasTablePrivilege { &return_type_fn, &scalar_fn_impl, ); - Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf), args)) + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::new(udf), + args, + ))) } fn namespace(&self) -> FunctionNamespace { @@ -286,7 +307,7 @@ impl ConstBuiltinFunction for CurrentSchemas { } impl BuiltinScalarUDF for CurrentSchemas { - fn as_expr(&self, args: Vec) -> Expr { + fn try_as_expr(&self, _: &SessionCatalog, args: Vec) -> DataFusionResult { // There's no good way to handle the `include_implicit` argument, // but since its a binary value (true/false), // we can just assign it to a different variable @@ -296,11 +317,11 @@ impl BuiltinScalarUDF for CurrentSchemas { "current_schemas".to_string() }; - Expr::ScalarVariable( + Ok(Expr::ScalarVariable( DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), vec![var_name], ) - .alias("current_schemas") + .alias("current_schemas")) } fn namespace(&self) -> FunctionNamespace { @@ -324,8 +345,8 @@ impl ConstBuiltinFunction for CurrentUser { } } impl BuiltinScalarUDF for CurrentUser { - fn as_expr(&self, _: Vec) -> Expr { - session_var("current_user") + fn try_as_expr(&self, _: &SessionCatalog, _: Vec) -> DataFusionResult { + Ok(session_var("current_user")) } } @@ -346,8 +367,8 @@ impl ConstBuiltinFunction for CurrentRole { } impl BuiltinScalarUDF for CurrentRole { - fn as_expr(&self, _: Vec) -> Expr { - session_var("current_role") + fn try_as_expr(&self, _: &SessionCatalog, _: Vec) -> DataFusionResult { + Ok(session_var("current_role")) } fn namespace(&self) -> FunctionNamespace { @@ -372,8 +393,8 @@ impl ConstBuiltinFunction for CurrentSchema { } impl BuiltinScalarUDF for CurrentSchema { - fn as_expr(&self, _: Vec) -> Expr { - session_var("current_schema") + fn try_as_expr(&self, _: &SessionCatalog, _: Vec) -> DataFusionResult { + Ok(session_var("current_schema")) } fn namespace(&self) -> FunctionNamespace { @@ -398,8 +419,8 @@ impl ConstBuiltinFunction for CurrentDatabase { } impl BuiltinScalarUDF for CurrentDatabase { - fn as_expr(&self, _: Vec) -> Expr { - session_var("current_database") + fn try_as_expr(&self, _: &SessionCatalog, _: Vec) -> DataFusionResult { + Ok(session_var("current_database")) } } @@ -420,8 +441,8 @@ impl ConstBuiltinFunction for CurrentCatalog { } impl BuiltinScalarUDF for CurrentCatalog { - fn as_expr(&self, _: Vec) -> Expr { - session_var("current_catalog") + fn try_as_expr(&self, _: &SessionCatalog, _: Vec) -> DataFusionResult { + Ok(session_var("current_catalog")) } fn namespace(&self) -> FunctionNamespace { @@ -446,8 +467,8 @@ impl ConstBuiltinFunction for User { } impl BuiltinScalarUDF for User { - fn as_expr(&self, args: Vec) -> Expr { - CurrentUser.as_expr(args).alias("user") + fn try_as_expr(&self, ctx: &SessionCatalog, args: Vec) -> DataFusionResult { + Ok(CurrentUser.try_as_expr(ctx, args)?.alias("user")) } fn namespace(&self) -> FunctionNamespace { @@ -472,11 +493,11 @@ impl ConstBuiltinFunction for PgArrayToString { } impl BuiltinScalarUDF for PgArrayToString { - fn as_expr(&self, args: Vec) -> Expr { - Expr::ScalarFunction(ScalarFunction::new( + fn try_as_expr(&self, _: &SessionCatalog, args: Vec) -> DataFusionResult { + Ok(Expr::ScalarFunction(ScalarFunction::new( BuiltinScalarFunction::ArrayToString, args, - )) + ))) } fn namespace(&self) -> FunctionNamespace { @@ -506,10 +527,10 @@ impl ConstBuiltinFunction for PgVersion { } impl BuiltinScalarUDF for PgVersion { - fn as_expr(&self, _: Vec) -> Expr { - Expr::Literal(ScalarValue::Utf8(Some( + fn try_as_expr(&self, _: &SessionCatalog, _: Vec) -> DataFusionResult { + Ok(Expr::Literal(ScalarValue::Utf8(Some( server_version_with_build_info().to_string(), - ))) + )))) } fn namespace(&self) -> FunctionNamespace { diff --git a/crates/sqlexec/src/planner/context_builder.rs b/crates/sqlexec/src/planner/context_builder.rs index 05021f14f..f7e66f6d4 100644 --- a/crates/sqlexec/src/planner/context_builder.rs +++ b/crates/sqlexec/src/planner/context_builder.rs @@ -11,9 +11,11 @@ use datafusion::execution::context::SessionState; use datafusion::logical_expr::{AggregateUDF, TableSource, WindowUDF}; use datafusion::prelude::Expr; use datafusion::sql::TableReference; +use datafusion::variable::VarProvider; use datafusion_ext::functions::FuncParamValue; use datafusion_ext::planner::AsyncContextProvider; use datafusion_ext::runtime::table_provider::RuntimeAwareTableProvider; +use datafusion_ext::vars::CredentialsVarProvider; use protogen::metastore::types::catalog::{CatalogEntry, RuntimePreference}; use protogen::metastore::types::options::TableOptions; use protogen::rpcsrv::types::service::ResolvedTableReference; @@ -292,14 +294,21 @@ impl<'a> AsyncContextProvider for PartialContextProvider<'a> { .map_err(|e| DataFusionError::External(Box::new(e))) } - async fn get_function_meta(&mut self, name: &str, args: &[Expr]) -> Option { + async fn get_function_meta( + &mut self, + name: &str, + args: &[Expr], + ) -> DataFusionResult> { FUNCTION_REGISTRY .get_scalar_udf(name) - .map(|f| f.as_expr(args.to_vec())) + .map(|f| f.try_as_expr(self.ctx.get_session_catalog(), args.to_vec())) + .transpose() } - async fn get_variable_type(&mut self, _variable_names: &[String]) -> Option { - None + async fn get_variable_type(&mut self, var_names: &[String]) -> Option { + let catalog = self.ctx.get_session_catalog(); + let cred_var_provider = CredentialsVarProvider::new(catalog); + cred_var_provider.get_type(var_names) } async fn get_aggregate_meta(&mut self, _name: &str) -> Option> { diff --git a/crates/sqlexec/src/planner/session_planner.rs b/crates/sqlexec/src/planner/session_planner.rs index 79e01d954..33857e048 100644 --- a/crates/sqlexec/src/planner/session_planner.rs +++ b/crates/sqlexec/src/planner/session_planner.rs @@ -64,6 +64,7 @@ use protogen::metastore::types::options::{ CredentialsOptionsAzure, CredentialsOptionsDebug, CredentialsOptionsGcp, + CredentialsOptionsOpenAI, DatabaseOptions, DatabaseOptionsBigQuery, DatabaseOptionsCassandra, @@ -950,6 +951,17 @@ impl<'a> SessionPlanner<'a> { access_key, }) } + CredentialsOptions::OPENAI => { + let api_key = m.remove_required("api_key")?; + let api_base = m.remove_optional("api_base")?; + let org_id = m.remove_optional("org_id")?; + + CredentialsOptions::OpenAI(CredentialsOptionsOpenAI { + api_key, + api_base, + org_id, + }) + } other => return Err(internal!("unsupported credentials provider: {other}")), }; @@ -2269,7 +2281,8 @@ fn storage_options_with_credentials( creds: CredentialsOptions, ) { match creds { - CredentialsOptions::Debug(_) => {} // Nothing to do here + CredentialsOptions::Debug(_) => {} // Nothing to do here + CredentialsOptions::OpenAI(_) => {} // Nothing to do here. OpenAI is not a storage backend CredentialsOptions::Gcp(creds) => { storage_options.inner.insert( GoogleConfigKey::ServiceAccountKey.as_ref().to_string(), diff --git a/justfile b/justfile index 9203d453b..d1ffac5fa 100644 --- a/justfile +++ b/justfile @@ -101,7 +101,7 @@ lint: clippy fmt-check fix: protoc cargo clippy --fix --all --all-features --allow-staged --allow-dirty cargo fix --all --allow-staged --allow-dirty - cargo fmt --all + just fmt --all # Displays help message. help: