Skip to content

Commit

Permalink
feat: add openai_embed function (#2538)
Browse files Browse the repository at this point in the history
todo: 

- [ ] integration tests & slts
  • Loading branch information
universalmind303 authored Jan 31, 2024
1 parent f059fd9 commit 56f8d38
Show file tree
Hide file tree
Showing 16 changed files with 707 additions and 59 deletions.
152 changes: 152 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion crates/datafusion_ext/src/planner/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@ impl<'a, S: AsyncContextProvider> SqlQueryPlanner<'a, S> {
.await?;

// user-defined function (UDF) should have precedence in case it has the same name as a scalar built-in function
if let Some(expr) = self.context_provider.get_function_meta(&name, &args).await {
if let Some(expr) = self
.context_provider
.get_function_meta(&name, &args)
.await?
{
return Ok(expr);
}

Expand Down
2 changes: 1 addition & 1 deletion crates/datafusion_ext/src/planner/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ pub trait AsyncContextProvider: Send + Sync {
/// NOTE: This is a modified version of `get_function_meta` that takes
/// arguments and reutrns an `Expr` instead of `ScalarUDF`. This is so that
/// we can return any kind of Expr from scalar UDFs.
async fn get_function_meta(&mut self, name: &str, args: &[Expr]) -> Option<Expr>;
async fn get_function_meta(&mut self, name: &str, args: &[Expr]) -> Result<Option<Expr>>;
/// Getter for a UDAF description
async fn get_aggregate_meta(&mut self, name: &str) -> Option<Arc<AggregateUDF>>;
/// Getter for a UDWF
Expand Down
55 changes: 55 additions & 0 deletions crates/datafusion_ext/src/vars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::fmt::Display;
use std::str::FromStr;
use std::sync::Arc;

use catalog::session_catalog::SessionCatalog;
use constants::IMPLICIT_SCHEMAS;
use datafusion::arrow::array::{ListBuilder, StringBuilder};
use datafusion::arrow::datatypes::{DataType, Field};
Expand All @@ -20,6 +21,7 @@ pub use inner::{Dialect, SessionVarsInner};
use once_cell::sync::Lazy;
use parking_lot::{RwLock, RwLockReadGuard};
use pgrepr::notice::NoticeSeverity;
use protogen::metastore::types::options::{CredentialsOptions, CredentialsOptionsOpenAI};
use utils::split_comma_delimited;
use uuid::Uuid;

Expand Down Expand Up @@ -288,3 +290,56 @@ impl VarProvider for SessionVars {
}
}
}

#[derive(Debug, Clone)]
pub struct CredentialsVarProvider<'a> {
pub catalog: &'a SessionCatalog,
}

impl<'a> CredentialsVarProvider<'a> {
const CREDS_PREFIX: &'static str = "@creds";
const CREDS_OPENAI_PREFIX: &'static str = "openai";

pub fn new(catalog: &'a SessionCatalog) -> Self {
Self { catalog }
}
}

// Currently only supports OpenAI credentials
// We can add more providers in the future if needed
impl VarProvider for CredentialsVarProvider<'_> {
fn get_value(&self, var_names: Vec<String>) -> datafusion::error::Result<ScalarValue> {
let var_names: Vec<&str> = var_names.iter().map(|s| s.as_str()).collect();
match var_names.as_slice() {
[Self::CREDS_PREFIX, Self::CREDS_OPENAI_PREFIX, value] => {
let openai_cred = self.catalog.resolve_credentials(value).ok_or_else(|| {
datafusion::error::DataFusionError::Internal(
"No openai credentials found".to_string(),
)
})?;
if let CredentialsOptions::OpenAI(opts) = openai_cred.options.clone() {
Ok(opts.into())
} else {
Err(datafusion::error::DataFusionError::Internal(
"Something went wrong. Expected openai credential, found other".to_string(),
))
}
}
_ => Err(datafusion::error::DataFusionError::Internal(
"unsupported variable".to_string(),
)),
}
}

fn get_type(&self, var_names: &[String]) -> Option<DataType> {
let first = var_names.first().map(|s| s.as_str());
let second = var_names.get(1).map(|s| s.as_str());

match (first, second) {
(Some(Self::CREDS_PREFIX), Some(Self::CREDS_OPENAI_PREFIX)) => {
Some(CredentialsOptionsOpenAI::data_type())
}
_ => None,
}
}
}
7 changes: 7 additions & 0 deletions crates/protogen/proto/metastore/options.proto
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ message CredentialsOptions {
CredentialsOptionsGcp gcp = 2;
CredentialsOptionsAws aws = 3;
CredentialsOptionsAzure azure = 4;
CredentialsOptionsOpenAI openai = 5;
}
}

Expand All @@ -295,3 +296,9 @@ message CredentialsOptionsAzure {
// TODO: We may want to allow the user to give us just the "connection string"
// which contains the account and access key.
}

message CredentialsOptionsOpenAI {
string api_key = 1;
optional string api_base = 2;
optional string org_id = 3;
}
Loading

0 comments on commit 56f8d38

Please sign in to comment.