Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add openai_embed function #2538

Merged
merged 18 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion crates/datafusion_ext/src/planner/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@ impl<'a, S: AsyncContextProvider> SqlQueryPlanner<'a, S> {
.await?;

// user-defined function (UDF) should have precedence in case it has the same name as a scalar built-in function
if let Some(expr) = self.context_provider.get_function_meta(&name, &args).await {
if let Some(expr) = self
.context_provider
.get_function_meta(&name, &args)
.await?
{
return Ok(expr);
}

Expand Down
2 changes: 1 addition & 1 deletion crates/datafusion_ext/src/planner/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ pub trait AsyncContextProvider: Send + Sync {
/// NOTE: This is a modified version of `get_function_meta` that takes
/// arguments and reutrns an `Expr` instead of `ScalarUDF`. This is so that
/// we can return any kind of Expr from scalar UDFs.
async fn get_function_meta(&mut self, name: &str, args: &[Expr]) -> Option<Expr>;
async fn get_function_meta(&mut self, name: &str, args: &[Expr]) -> Result<Option<Expr>>;
/// Getter for a UDAF description
async fn get_aggregate_meta(&mut self, name: &str) -> Option<Arc<AggregateUDF>>;
/// Getter for a UDWF
Expand Down
55 changes: 55 additions & 0 deletions crates/datafusion_ext/src/vars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::fmt::Display;
use std::str::FromStr;
use std::sync::Arc;

use catalog::session_catalog::SessionCatalog;
use constants::IMPLICIT_SCHEMAS;
use datafusion::arrow::array::{ListBuilder, StringBuilder};
use datafusion::arrow::datatypes::{DataType, Field};
Expand All @@ -20,6 +21,7 @@ pub use inner::{Dialect, SessionVarsInner};
use once_cell::sync::Lazy;
use parking_lot::{RwLock, RwLockReadGuard};
use pgrepr::notice::NoticeSeverity;
use protogen::metastore::types::options::{CredentialsOptions, CredentialsOptionsOpenAI};
use utils::split_comma_delimited;
use uuid::Uuid;

Expand Down Expand Up @@ -288,3 +290,56 @@ impl VarProvider for SessionVars {
}
}
}

#[derive(Debug, Clone)]
pub struct CredentialsVarProvider<'a> {
pub catalog: &'a SessionCatalog,
}

impl<'a> CredentialsVarProvider<'a> {
const CREDS_PREFIX: &'static str = "@creds";
const CREDS_OPENAI_PREFIX: &'static str = "openai";

pub fn new(catalog: &'a SessionCatalog) -> Self {
Self { catalog }
}
}

// Currently only supports OpenAI credentials
// We can add more providers in the future if needed
impl VarProvider for CredentialsVarProvider<'_> {
fn get_value(&self, var_names: Vec<String>) -> datafusion::error::Result<ScalarValue> {
let var_names: Vec<&str> = var_names.iter().map(|s| s.as_str()).collect();
match var_names.as_slice() {
[Self::CREDS_PREFIX, Self::CREDS_OPENAI_PREFIX, value] => {
let openai_cred = self.catalog.resolve_credentials(value).ok_or_else(|| {
datafusion::error::DataFusionError::Internal(
"No openai credentials found".to_string(),
)
})?;
if let CredentialsOptions::OpenAI(opts) = openai_cred.options.clone() {
Ok(opts.into())
} else {
Err(datafusion::error::DataFusionError::Internal(
"Something went wrong. Expected openai credential, found other".to_string(),
))
}
}
_ => Err(datafusion::error::DataFusionError::Internal(
"unsupported variable".to_string(),
)),
}
}

fn get_type(&self, var_names: &[String]) -> Option<DataType> {
let first = var_names.first().map(|s| s.as_str());
let second = var_names.get(1).map(|s| s.as_str());

match (first, second) {
(Some(Self::CREDS_PREFIX), Some(Self::CREDS_OPENAI_PREFIX)) => {
Some(CredentialsOptionsOpenAI::data_type())
}
_ => None,
}
}
}
7 changes: 7 additions & 0 deletions crates/protogen/proto/metastore/options.proto
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ message CredentialsOptions {
CredentialsOptionsGcp gcp = 2;
CredentialsOptionsAws aws = 3;
CredentialsOptionsAzure azure = 4;
CredentialsOptionsOpenAI openai = 5;
}
}

Expand All @@ -295,3 +296,9 @@ message CredentialsOptionsAzure {
// TODO: We may want to allow the user to give us just the "connection string"
// which contains the account and access key.
}

message CredentialsOptionsOpenAI {
string api_key = 1;
optional string api_base = 2;
optional string org_id = 3;
}
Loading
Loading