Skip to content

Commit

Permalink
Feature/padw 106 add transformer related gucs (#24)
Browse files Browse the repository at this point in the history
* feat (PADW-92): Add new GUC to GUC.rs

* feat (PADW-92): Add support for GUC transformer_server_max_retries

* feat (PADW-92): Add support for GUC transformer_server_wait_duration
  • Loading branch information
analyzer1 authored Dec 6, 2024
1 parent 0a3df11 commit 957b6f5
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 18 deletions.
37 changes: 28 additions & 9 deletions extension/src/controller/bgw_transformer_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ use crate::utility::transformer_client;
use crate::utility::guc;
use regex::Regex;

const MAX_TRANSFORMER_RETRIES: u64 = 3; // TODO: Set in GUC

#[pg_guard]
#[no_mangle]
pub extern "C" fn background_worker_transformer_client(_arg: pg_sys::Datum) {

let max_transformer_retries = guc::get_guc(guc::PgAutoDWGuc::TransformerServerMaxRetries).unwrap();
let max_transformer_retries: u64 = max_transformer_retries.parse().expect("TransformerServerMaxRetries Not Valid u64");

let database_name_string = guc::get_guc(guc::PgAutoDWGuc::DatabaseName);
let database_name_o: Option<&str> = database_name_string.as_deref();

Expand Down Expand Up @@ -77,7 +78,7 @@ pub extern "C" fn background_worker_transformer_client(_arg: pg_sys::Datum) {
let mut retries = 0;
let mut hints = String::new();

while retries < MAX_TRANSFORMER_RETRIES {
while retries < max_transformer_retries {
runtime.block_on(async {
generation_json_business_key_component_identification =
match transformer_client::send_request(
Expand All @@ -102,7 +103,13 @@ pub extern "C" fn background_worker_transformer_client(_arg: pg_sys::Datum) {

if generation_json_business_key_component_identification.is_none() {
retries += 1;
continue; // Skip to the next iteration

if retries >= max_transformer_retries {
panic!("Max Transformer Retries Reached - restart backgrounder.")
}

// Skip to the next iteration
continue;
}

match serde_json::from_value::<BusinessKeyComponentIdentification>(generation_json_business_key_component_identification.clone().unwrap()) {
Expand All @@ -128,7 +135,7 @@ pub extern "C" fn background_worker_transformer_client(_arg: pg_sys::Datum) {
Some(bkci) => {
if bkci.business_key_component_identification.is_business_key_component {
// Identify BK Name
while retries < MAX_TRANSFORMER_RETRIES {
while retries < max_transformer_retries {
runtime.block_on(async {
generation_json_business_key_name =
match transformer_client::send_request(table_details_json_str.as_str(), prompt_template::PromptTemplate::BKName, &column, &hints).await {
Expand All @@ -149,7 +156,13 @@ pub extern "C" fn background_worker_transformer_client(_arg: pg_sys::Datum) {

if generation_json_business_key_name.is_none() {
retries += 1;
continue; // Skip to the next iteration

if retries >= max_transformer_retries {
panic!("Max Transformer Retries Reached - restart backgrounder.")
}

// Skip to the next iteration
continue;
}

match serde_json::from_value::<BusinessKeyName>(generation_json_business_key_name.clone().unwrap()) {
Expand Down Expand Up @@ -177,9 +190,9 @@ pub extern "C" fn background_worker_transformer_client(_arg: pg_sys::Datum) {
let mut descriptors_sensitive: HashMap<&u32, DescriptorSensitive> = HashMap::new();
let mut generation_json_descriptor_sensitive: Option<serde_json::Value> = None;
for column in &columns {
let mut retries: u64 = 0;
let mut retries = 0;
let mut hints = String::new();
while retries < MAX_TRANSFORMER_RETRIES {
while retries < max_transformer_retries {
// Run the async block
runtime.block_on(async {
// Get Generation
Expand Down Expand Up @@ -207,7 +220,13 @@ pub extern "C" fn background_worker_transformer_client(_arg: pg_sys::Datum) {

if generation_json_descriptor_sensitive.is_none() {
retries += 1;
continue; // Skip to the next iteration

if retries >= max_transformer_retries {
panic!("Max Transformer Retries Reached - restart backgrounder.")
}

// Skip to the next iteration
continue;
}

match serde_json::from_value::<DescriptorSensitive>(generation_json_descriptor_sensitive.clone().unwrap()) {
Expand Down
39 changes: 37 additions & 2 deletions extension/src/utility/guc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ pub static PG_AUTO_DW_MODEL: GucSetting<Option<&CStr>> = GucSetting::<Option<&CS
// The accepted, self-described confidence level of the transformer - default 0.8.
pub static PG_AUTO_DW_ACCEPTED_TRANSFORMER_CONFIDENCE_LEVEL: GucSetting<f64> = GucSetting::<f64>::new(0.8);

// Number of times the transformer can be given the same request if a failure is recognized - default 3.
// Registered in init_guc as pg_auto_dw.transformer_server_max_retries with a permitted range of 1 to 10.
pub static PG_AUTO_DW_TRANSFORMER_SERVER_MAX_RETRIES: GucSetting<i32> = GucSetting::<i32>::new(3);

// Number of seconds to wait for the transformer server's response - default 60 sec.
// Registered in init_guc as pg_auto_dw.transformer_server_wait_duration with a permitted range of 1 to 360.
pub static PG_AUTO_DW_TRANSFORMER_SERVER_WAIT_DURATION: GucSetting<i32> = GucSetting::<i32>::new(60);

pub fn init_guc() {
// Register the GUCs
GucRegistry::define_string_guc(
Expand Down Expand Up @@ -88,7 +94,7 @@ pub fn init_guc() {

GucRegistry::define_float_guc(
"pg_auto_dw.accepted_transformer_confidence_level",
"Transformers generated confidence level for the pg_auto_dw extension.",
"Transformer generated confidence level for the pg_auto_dw extension.",
"Specifies the confidence level threshold generated by the transformer model for the operations performed by the pg_auto_dw extension.",
&PG_AUTO_DW_ACCEPTED_TRANSFORMER_CONFIDENCE_LEVEL,
0.0, // min value
Expand All @@ -97,6 +103,28 @@ pub fn init_guc() {
GucFlags::default(),
);

GucRegistry::define_int_guc(
"pg_auto_dw.transformer_server_max_retries",
"Maximum Transformer Retries",
"Specifies the number of retry attempts the pg_auto_dw extension can make for a transformer request in case of failure.",
&PG_AUTO_DW_TRANSFORMER_SERVER_MAX_RETRIES,
1, // min value
10, // max value
GucContext::Suset,
GucFlags::default(),
);

GucRegistry::define_int_guc(
"pg_auto_dw.transformer_server_wait_duration",
"Maximum Transformer Server Wait Time",
"Specifies the maximum number of seconds the pg_auto_dw extension will wait for a response from the transformer server.",
&PG_AUTO_DW_TRANSFORMER_SERVER_WAIT_DURATION,
1, // min value
360, // max value
GucContext::Suset,
GucFlags::default(),
);

}

// For handling of GUCs that can be error prone
Expand All @@ -107,11 +135,12 @@ pub enum PgAutoDWGuc {
TransformerServerType,
TransformerServerUrl,
TransformerServerToken,
TransformerServerWaitDuration,
TransformerServerMaxRetries,
Model,
AcceptedTransformerConfidenceLevel,
}


// A convenience function to get this project's GUCs
pub fn get_guc(guc: PgAutoDWGuc) -> Option<String> {
match guc {
Expand All @@ -120,6 +149,8 @@ pub fn get_guc(guc: PgAutoDWGuc) -> Option<String> {
PgAutoDWGuc::TransformerServerType => cstr_option_to_string(PG_AUTO_DW_TRANSFORMER_SERVER_TYPE.get()),
PgAutoDWGuc::TransformerServerUrl => cstr_option_to_string(PG_AUTO_DW_TRANSFORMER_SERVER_URL.get()),
PgAutoDWGuc::TransformerServerToken => cstr_option_to_string(PG_AUTO_DW_TRANSFORMER_SERVER_TOKEN.get()),
PgAutoDWGuc::TransformerServerWaitDuration => cstr_from_int(PG_AUTO_DW_TRANSFORMER_SERVER_WAIT_DURATION.get()),
PgAutoDWGuc::TransformerServerMaxRetries => cstr_from_int(PG_AUTO_DW_TRANSFORMER_SERVER_MAX_RETRIES.get()),
PgAutoDWGuc::Model => cstr_option_to_string(PG_AUTO_DW_MODEL.get()),
PgAutoDWGuc::AcceptedTransformerConfidenceLevel => cstr_from_float(PG_AUTO_DW_ACCEPTED_TRANSFORMER_CONFIDENCE_LEVEL.get()),
}
Expand All @@ -134,3 +165,7 @@ fn cstr_from_float(val: f64) -> Option<String> {
Some(val.to_string())
}

/// Renders an integer GUC value as its decimal string.
///
/// The result is always `Some`; the `Option` wrapper keeps the return type
/// uniform with the other GUC conversion helpers used by `get_guc`.
fn cstr_from_int(val: i32) -> Option<String> {
    let rendered = val.to_string();
    Some(rendered)
}

4 changes: 2 additions & 2 deletions extension/src/utility/ollama_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ pub struct GenerateResponse {
pub done: bool,
}

pub async fn send_request(new_json: &str, template_type: PromptTemplate, col: &u32, hints: &str) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
pub async fn send_request(new_json: &str, template_type: PromptTemplate, col: &u32, hints: &str, timout_in_sec: u64) -> Result<serde_json::Value, Box<dyn std::error::Error>> {

let client = ClientBuilder::new().timeout(Duration::from_secs(180)).build()?; // 30 sec Default to short for some LLMS.
let client = ClientBuilder::new().timeout(Duration::from_secs(timout_in_sec)).build()?; // 30 sec Default to short for some LLMS.

let prompt_template = template_type.template();

Expand Down
7 changes: 4 additions & 3 deletions extension/src/utility/openai_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ pub struct Usage {
pub total_tokens: u32, // Total number of tokens used
}

pub async fn send_request(new_json: &str, template_type: PromptTemplate, col: &u32, hints: &str) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
pub async fn send_request(new_json: &str, template_type: PromptTemplate, col: &u32, hints: &str, timout_in_sec: u64) -> Result<serde_json::Value, Box<dyn std::error::Error>> {

let client = ClientBuilder::new().timeout(Duration::from_secs(60)).build()?; // 30 sec Default to short for some LLMS.
let client = ClientBuilder::new().timeout(Duration::from_secs(timout_in_sec)).build()?; // 30 sec Default to short for some LLMS.

let prompt_template = template_type.template();
// let prompt_template = PromptTemplate::Test.template();
Expand Down Expand Up @@ -93,6 +93,8 @@ pub async fn send_request(new_json: &str, template_type: PromptTemplate, col: &u
response_format,
};

log!("Request: {:#?}", request.messages[0]);

let raw_response = client
.post(&transformer_server_url)
.header("Authorization", format!("Bearer {}", transformer_server_token))
Expand Down Expand Up @@ -122,7 +124,6 @@ pub async fn send_request(new_json: &str, template_type: PromptTemplate, col: &u
.message
.content;

log!("Request: {:#?}", request.messages[0]);
log!("Response: {:#?}", content_str);

// Parse the content string into serde_json::Value
Expand Down
7 changes: 5 additions & 2 deletions extension/src/utility/transformer_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@ impl FromStr for TransformerServerType {
/// Routes a transformer request to the client matching the configured server type.
///
/// Reads the `TransformerServerType` and `TransformerServerWaitDuration` GUCs, then
/// forwards the request to either the OpenAI or the Ollama client, passing the wait
/// duration along as the request timeout in seconds.
///
/// # Errors
/// Returns an error when either GUC is not set, when the wait duration is not a valid
/// `u64`, when the server type cannot be parsed, or when the selected client fails.
pub async fn send_request(new_json: &str, template_type: PromptTemplate, col: &u32, hints: &str) -> Result<serde_json::Value, Box<dyn std::error::Error>> {

    let transformer_server_type_str = guc::get_guc(guc::PgAutoDWGuc::TransformerServerType).ok_or("GUC: Transformer Server Type is not set.")?;

    let transformer_server_wait_duration = guc::get_guc(guc::PgAutoDWGuc::TransformerServerWaitDuration).ok_or("GUC: Transformer Server Wait Duration is not set.")?;
    // Propagate a parse failure to the caller instead of panicking, and name the GUC
    // that was actually being parsed (the old message referred to the max-retries GUC).
    let timeout_in_sec = transformer_server_wait_duration.parse::<u64>()
        .map_err(|e| format!("Error parsing Transformer Server Wait Duration: {}", e))?;

    let transformer_server_type = transformer_server_type_str.parse::<TransformerServerType>()
        .map_err(|e| format!("Error parsing Transformer Server Type: {}", e))?;

    match transformer_server_type {
        OpenAI => openai_client::send_request(new_json, template_type, col, hints, timeout_in_sec).await,
        Ollama => ollama_client::send_request(new_json, template_type, col, hints, timeout_in_sec).await,
    }
}

0 comments on commit 957b6f5

Please sign in to comment.