diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/request.rs b/engine/baml-runtime/src/internal/llm_client/primitive/request.rs index c288da7d4..3d480552b 100644 --- a/engine/baml-runtime/src/internal/llm_client/primitive/request.rs +++ b/engine/baml-runtime/src/internal/llm_client/primitive/request.rs @@ -34,7 +34,7 @@ pub async fn make_request( stream: bool, ) -> Result<(Response, web_time::SystemTime, web_time::Instant), LLMResponse> { let (system_now, instant_now) = (web_time::SystemTime::now(), web_time::Instant::now()); - log::info!("Making request using client {}", client.context().name); + log::debug!("Making request using client {}", client.context().name); let req = match client.build_request(prompt, stream).build() { Ok(req) => req, diff --git a/engine/baml-runtime/src/lib.rs b/engine/baml-runtime/src/lib.rs index 7548de77c..9445d637a 100644 --- a/engine/baml-runtime/src/lib.rs +++ b/engine/baml-runtime/src/lib.rs @@ -31,6 +31,7 @@ use baml_types::BamlMap; use baml_types::BamlValue; use indexmap::IndexMap; use internal_baml_core::configuration::GeneratorOutputType; +use on_log_event::LogEventCallbackSync; use runtime::InternalBamlRuntime; #[cfg(not(target_arch = "wasm32"))] @@ -38,6 +39,7 @@ pub use cli::CallerType; use runtime_interface::ExperimentalTracingInterface; use runtime_interface::RuntimeConstructor; use runtime_interface::RuntimeInterface; +use tracing::api_wrapper::core_types::LogSchema; use tracing::{BamlTracer, TracingSpan}; use type_builder::TypeBuilder; pub use types::*; @@ -329,4 +331,10 @@ impl ExperimentalTracingInterface for BamlRuntime { fn flush(&self) -> Result<()> { self.tracer.flush() } + + #[cfg(not(target_arch = "wasm32"))] + fn set_log_event_callback(&self, log_event_callback: LogEventCallbackSync) -> Result<()> { + self.tracer.set_log_event_callback(log_event_callback); + Ok(()) + } } diff --git a/engine/baml-runtime/src/runtime_interface.rs b/engine/baml-runtime/src/runtime_interface.rs index d74b664f1..bb6f87f5c 100644 --- a/engine/baml-runtime/src/runtime_interface.rs +++ b/engine/baml-runtime/src/runtime_interface.rs @@ -8,8 +8,10 @@ use std::{collections::HashMap, sync::Arc}; use crate::internal::llm_client::llm_provider::LLMProvider; use crate::internal::llm_client::orchestrator::{OrchestrationScope, OrchestratorNode}; +use crate::tracing::api_wrapper::core_types::LogSchema; use crate::tracing::{BamlTracer, TracingSpan}; use crate::type_builder::TypeBuilder; +use crate::types::on_log_event::LogEventCallbackSync; use crate::RuntimeContextManager; use crate::{ internal::{ir_features::IrFeatures, llm_client::retry_policy::CallablePolicy}, @@ -94,6 +96,9 @@ pub trait ExperimentalTracingInterface { ) -> Result>; fn flush(&self) -> Result<()>; + + #[cfg(not(target_arch = "wasm32"))] + fn set_log_event_callback(&self, callback: LogEventCallbackSync) -> Result<()>; } pub trait InternalClientLookup<'a> { diff --git a/engine/baml-runtime/src/tracing/api_wrapper/core_types.rs b/engine/baml-runtime/src/tracing/api_wrapper/core_types.rs index 4dd89d9da..94c0e5fe8 100644 --- a/engine/baml-runtime/src/tracing/api_wrapper/core_types.rs +++ b/engine/baml-runtime/src/tracing/api_wrapper/core_types.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; #[derive(Serialize, Debug)] -pub(crate) struct UpdateTestCase { +pub struct UpdateTestCase { pub project_id: Option, pub test_cycle_id: String, pub test_dataset_name: String, @@ -15,8 +15,8 @@ pub(crate) struct UpdateTestCase { pub error_data: Option, // Rust doesn't have a direct equivalent of Python's Any type, so we use serde_json::Value } -#[derive(Serialize, Debug)] -pub(crate) struct LogSchema { +#[derive(Serialize, Debug, Clone)] +pub struct LogSchema { pub project_id: Option, pub event_type: EventType, pub root_event_id: String, @@ -29,26 +29,26 @@ pub(crate) struct LogSchema { } #[derive(Serialize, Debug, Clone)] -pub(crate) struct IO { +pub struct IO { pub(crate) input: Option, pub(crate) output: Option, } #[derive(Serialize, Debug, Clone)] -pub(crate) struct IOValue { +pub struct IOValue { pub(crate) value: ValueType, pub(crate) r#override: Option>, pub(crate) r#type: TypeSchema, } #[derive(Serialize, Debug, Clone)] -pub(crate) struct TypeSchema { +pub struct TypeSchema { pub(crate) name: TypeSchemaName, pub(crate) fields: IndexMap, } #[derive(Serialize, Debug, Clone)] -pub(crate) enum TypeSchemaName { +pub enum TypeSchemaName { #[serde(rename = "single")] Single, #[serde(rename = "multi")] @@ -57,7 +57,7 @@ pub(crate) enum TypeSchemaName { #[derive(Serialize, Debug, Clone)] #[serde(untagged)] -pub(crate) enum ValueType { +pub enum ValueType { String(String), // For mutli-args, we use a list of strings List(Vec), @@ -98,7 +98,7 @@ pub enum EventType { } #[derive(Serialize, Debug, Clone)] -pub(crate) struct LogSchemaContext { +pub struct LogSchemaContext { pub hostname: String, pub process_id: String, pub stage: Option, @@ -109,7 +109,7 @@ pub(crate) struct LogSchemaContext { } #[derive(Serialize, Debug, Clone)] -pub(crate) struct EventChain { +pub struct EventChain { pub function_name: String, pub variant_name: Option, } @@ -122,8 +122,8 @@ pub(crate) struct Error { pub r#override: Option>, } -#[derive(Serialize, Debug, Deserialize, Default)] -pub(crate) struct LLMOutputModelMetadata { +#[derive(Serialize, Debug, Deserialize, Default, Clone)] +pub struct LLMOutputModelMetadata { pub logprobs: Option, pub prompt_tokens: Option, pub output_tokens: Option, @@ -131,21 +131,21 @@ pub(crate) struct LLMOutputModelMetadata { pub finish_reason: Option, } -#[derive(Serialize, Debug)] -pub(crate) struct LLMOutputModel { +#[derive(Serialize, Debug, Clone)] +pub struct LLMOutputModel { pub raw_text: String, pub metadata: LLMOutputModelMetadata, pub r#override: Option>, } -#[derive(Serialize, Debug)] +#[derive(Serialize, Debug, Clone)] pub(crate) struct LLMChat { pub role: Role, pub content: Vec, } -#[derive(Serialize, Debug)] -pub(crate) enum ContentPart { +#[derive(Serialize, Debug, Clone)] +pub enum ContentPart { #[serde(rename = "text")] Text(String), #[serde(rename = "url_image")] @@ -158,9 +158,9 @@ pub(crate) enum ContentPart { B64Audio(String), } -#[derive(Serialize, Debug, Deserialize)] +#[derive(Serialize, Debug, Deserialize, Clone)] #[serde(untagged)] -pub(crate) enum Role { +pub enum Role { #[serde(rename = "assistant")] Assistant, #[serde(rename = "user")] @@ -170,14 +170,14 @@ pub(crate) enum Role { Other(String), } -#[derive(Serialize, Debug)] +#[derive(Serialize, Debug, Clone)] pub(crate) struct LLMEventInput { pub prompt: LLMEventInputPrompt, pub invocation_params: HashMap, } -#[derive(Serialize, Debug)] -pub(crate) struct LLMEventSchema { +#[derive(Serialize, Debug, Clone)] +pub struct LLMEventSchema { pub model_name: String, pub provider: String, pub input: LLMEventInput, @@ -185,20 +185,20 @@ pub(crate) struct LLMEventSchema { pub error: Option, } -#[derive(Serialize, Debug)] +#[derive(Serialize, Debug, Clone)] #[serde(untagged)] -pub(crate) enum MetadataType { +pub enum MetadataType { Single(LLMEventSchema), Multi(Vec), } -#[derive(Serialize, Debug)] -pub(crate) struct LLMEventInputPrompt { +#[derive(Serialize, Debug, Clone)] +pub struct LLMEventInputPrompt { pub template: Template, pub template_args: HashMap, pub r#override: Option>, } -#[derive(Serialize, Debug)] +#[derive(Serialize, Debug, Clone)] #[serde(untagged)] #[allow(dead_code)] pub enum Template { diff --git a/engine/baml-runtime/src/tracing/api_wrapper/env_setup.rs b/engine/baml-runtime/src/tracing/api_wrapper/env_setup.rs index 0388003af..fbaf88582 100644 --- a/engine/baml-runtime/src/tracing/api_wrapper/env_setup.rs +++ b/engine/baml-runtime/src/tracing/api_wrapper/env_setup.rs @@ -13,12 +13,20 @@ pub struct Config { pub stage: String, #[serde(default = "default_host_name")] pub host_name: String, + #[serde(default)] // default is false + pub log_redaction_enabled: bool, + #[serde(default = "default_redaction_placeholder")] + pub log_redaction_placeholder: String, } fn default_base_url() -> String { "https://app.boundaryml.com/api".to_string() } +fn default_redaction_placeholder() -> String { + "".to_string() +} + fn default_sessions_id() -> String { uuid::Uuid::new_v4().to_string() } diff --git a/engine/baml-runtime/src/tracing/api_wrapper/mod.rs b/engine/baml-runtime/src/tracing/api_wrapper/mod.rs index 7dd70b6e7..92d890c68 100644 --- a/engine/baml-runtime/src/tracing/api_wrapper/mod.rs +++ b/engine/baml-runtime/src/tracing/api_wrapper/mod.rs @@ -2,7 +2,7 @@ mod env_setup; use anyhow::Result; pub(super) mod api_interface; -pub(super) mod core_types; +pub(crate) mod core_types; use instant::Duration; use serde::{de::DeserializeOwned, Deserialize}; use serde_json::{json, Value}; @@ -14,11 +14,11 @@ use self::core_types::{TestCaseStatus, UpdateTestCase}; #[derive(Debug, Clone)] pub struct APIWrapper { - config: APIConfig, + pub(super) config: APIConfig, } #[derive(Debug, Clone)] -enum APIConfig { +pub(super) enum APIConfig { LocalOnly(PartialAPIConfig), Web(CompleteAPIConfig), } @@ -59,6 +59,20 @@ impl APIConfig { } } + pub fn log_redaction_enabled(&self) -> bool { + match self { + Self::LocalOnly(config) => config.log_redaction_enabled, + Self::Web(config) => config.log_redaction_enabled, + } + } + + pub fn log_redaction_placeholder(&self) -> &str { + match self { + Self::LocalOnly(config) => &config.log_redaction_placeholder, + Self::Web(config) => &config.log_redaction_placeholder, + } + } + #[allow(dead_code)] #[allow(clippy::too_many_arguments)] pub(crate) fn copy_from( @@ -69,6 +83,8 @@ impl APIConfig { sessions_id: Option<&str>, stage: Option<&str>, host_name: Option<&str>, + log_redaction_enabled: Option, + log_redaction_placeholder: Option, _debug_level: Option, ) -> Self { let base_url = base_url.unwrap_or(match self { @@ -95,6 +111,14 @@ impl APIConfig { Self::LocalOnly(config) => &config.host_name, Self::Web(config) => &config.host_name, }); + let log_redaction_enabled = log_redaction_enabled.unwrap_or_else(|| match self { + Self::LocalOnly(config) => config.log_redaction_enabled, + Self::Web(config) => config.log_redaction_enabled, + }); + let log_redaction_placeholder = log_redaction_placeholder.unwrap_or_else(|| match self { + Self::LocalOnly(config) => config.log_redaction_placeholder.clone(), + Self::Web(config) => config.log_redaction_placeholder.clone(), + }); match (api_key, project_id) { (Some(api_key), Some(project_id)) => Self::Web(CompleteAPIConfig { @@ -105,6 +129,8 @@ impl APIConfig { sessions_id: sessions_id.to_string(), host_name: host_name.to_string(), client: create_client().unwrap(), + log_redaction_enabled, + log_redaction_placeholder, }), _ => Self::LocalOnly(PartialAPIConfig { base_url: base_url.to_string(), @@ -113,6 +139,8 @@ impl APIConfig { stage: stage.to_string(), sessions_id: sessions_id.to_string(), host_name: host_name.to_string(), + log_redaction_enabled, + log_redaction_placeholder, }), } } @@ -126,6 +154,8 @@ pub(super) struct CompleteAPIConfig { pub stage: String, pub sessions_id: String, pub host_name: String, + pub log_redaction_enabled: bool, + pub log_redaction_placeholder: String, client: reqwest::Client, } @@ -140,6 +170,8 @@ pub(super) struct PartialAPIConfig { stage: String, sessions_id: String, host_name: String, + log_redaction_enabled: bool, + log_redaction_placeholder: String, } impl CompleteAPIConfig { @@ -318,6 +350,9 @@ impl BoundaryTestAPI for APIWrapper { impl APIWrapper { pub fn from_env_vars>(value: impl Iterator) -> Self { let config = env_setup::Config::from_env_vars(value).unwrap(); + if config.log_redaction_enabled { + log::info!("Redaction enabled: {}", config.log_redaction_enabled); + } match (&config.secret, &config.project_id) { (Some(api_key), Some(project_id)) => Self { config: APIConfig::Web(CompleteAPIConfig { @@ -328,6 +363,8 @@ impl APIWrapper { sessions_id: config.sessions_id, host_name: config.host_name, client: create_client().unwrap(), + log_redaction_enabled: config.log_redaction_enabled, + log_redaction_placeholder: config.log_redaction_placeholder, }), }, _ => Self { @@ -338,6 +375,8 @@ impl APIWrapper { stage: config.stage, sessions_id: config.sessions_id, host_name: config.host_name, + log_redaction_enabled: config.log_redaction_enabled, + log_redaction_placeholder: config.log_redaction_placeholder, }), }, } diff --git a/engine/baml-runtime/src/tracing/mod.rs b/engine/baml-runtime/src/tracing/mod.rs index 09f874704..19220101e 100644 --- a/engine/baml-runtime/src/tracing/mod.rs +++ b/engine/baml-runtime/src/tracing/mod.rs @@ -1,9 +1,10 @@ -mod api_wrapper; +pub mod api_wrapper; #[cfg(not(target_arch = "wasm32"))] mod threaded_tracer; #[cfg(target_arch = "wasm32")] mod wasm_tracer; +use crate::on_log_event::LogEventCallbackSync; use anyhow::Result; use baml_types::{BamlMap, BamlMedia, BamlMediaType, BamlValue}; use colored::Colorize; @@ -72,6 +73,13 @@ impl BamlTracer { tracer } + #[cfg(not(target_arch = "wasm32"))] + pub(crate) fn set_log_event_callback(&self, log_event_callback: LogEventCallbackSync) { + if let Some(tracer) = &self.tracer { + tracer.set_log_event_callback(log_event_callback); + } + } + pub(crate) fn flush(&self) -> Result<()> { if let Some(tracer) = &self.tracer { tracer.flush() diff --git a/engine/baml-runtime/src/tracing/threaded_tracer.rs b/engine/baml-runtime/src/tracing/threaded_tracer.rs index 7871cd677..e16082bb0 100644 --- a/engine/baml-runtime/src/tracing/threaded_tracer.rs +++ b/engine/baml-runtime/src/tracing/threaded_tracer.rs @@ -1,9 +1,18 @@ -use std::sync::mpsc::{Receiver, Sender, TryRecvError}; +use std::{ + cell::RefCell, + sync::mpsc::{Receiver, Sender, TryRecvError}, +}; +// use crate::log_callback_event::LogEvent use anyhow::Result; use web_time::Duration; -use super::api_wrapper::{core_types::LogSchema, APIWrapper, BoundaryAPI}; +use crate::{ + on_log_event::{LogEvent, LogEventCallbackSync, LogEventMetadata}, + tracing::api_wrapper::core_types::{ContentPart, MetadataType, Template, ValueType}, +}; + +use super::api_wrapper::{core_types::LogSchema, APIConfig, APIWrapper, BoundaryAPI}; enum TxEventSignal { Stop, @@ -79,6 +88,7 @@ fn batch_processor( Ok(_) => {} Err(e) => { println!("Error sending flush signal: {:?}", e); + log::error!("Error sending flush signal: {:?}", e); } } } @@ -89,10 +99,12 @@ fn batch_processor( } pub(super) struct ThreadedTracer { + api_config: APIWrapper, tx: std::sync::Arc>>, rx: std::sync::Arc>>, #[allow(dead_code)] join_handle: std::thread::JoinHandle<()>, + log_event_callback: std::sync::Arc>>, } impl ThreadedTracer { @@ -115,9 +127,11 @@ impl ThreadedTracer { pub fn new(api_config: &APIWrapper, max_batch_size: usize) -> Self { let (tx, rx, join_handle) = Self::start_worker(api_config.clone(), max_batch_size); Self { + api_config: api_config.clone(), tx: std::sync::Arc::new(std::sync::Mutex::new(tx)), rx: std::sync::Arc::new(std::sync::Mutex::new(rx)), join_handle, + log_event_callback: std::sync::Arc::new(std::sync::Mutex::new(None)), } } @@ -143,8 +157,89 @@ impl ThreadedTracer { } } - pub fn submit(&self, event: LogSchema) -> Result<()> { - log::info!("Submitting work {}", event.event_id); + pub fn set_log_event_callback(&self, log_event_callback: LogEventCallbackSync) { + // Get a mutable lock on the log_event_callback + let mut callback_lock = self.log_event_callback.lock().unwrap(); + + *callback_lock = Some(log_event_callback); + } + + pub fn submit(&self, mut event: LogSchema) -> Result<()> { + log::debug!("Submitting work {:#?}", event.event_id); + + let callback = self.log_event_callback.lock().unwrap(); + if let Some(ref callback) = *callback { + let event = event.clone(); + let llm_output_model = event.metadata.as_ref().and_then(|m| match m { + MetadataType::Single(llm_event) => Some(llm_event), + // take the last element in the vector + MetadataType::Multi(llm_events) => llm_events.last().clone(), + }); + + let log_event_result = callback(LogEvent { + metadata: LogEventMetadata { + event_id: event.event_id.clone(), + parent_id: event.parent_event_id.clone(), + root_event_id: event.root_event_id.clone(), + }, + prompt: llm_output_model.and_then(|llm_event| { + match llm_event.clone().input.prompt.template { + Template::Single(text) => Some(text), + Template::Multiple(chat_prompt) => { + serde_json::to_string_pretty(&chat_prompt).ok().or_else(|| { + log::info!( + "Failed to serialize chat prompt for event {}", + event.event_id + ); + None + }) + } + } + }), + raw_output: llm_output_model.and_then(|llm_event| { + llm_event + .clone() + .output + .and_then(|output| Some(output.raw_text)) + }), + parsed_output: event.io.output.and_then(|output| match output.value { + // so the string value looks something like: + // '"[\"d\", \"e\", \"f\"]"' + // so we need to unescape it once and turn it into a normal json + // and then stringify it to get: + // '["d", "e", "f"]' + ValueType::String(value) => serde_json::from_str::(&value) + .ok() + .and_then(|json_value| json_value.as_str().map(|s| s.to_string())) + .or_else(|| Some(value)), + _ => serde_json::to_string_pretty(&output.value) + .ok() + .or_else(|| { + log::info!( + "Failed to serialize output value for event {}", + event.event_id + ); + None + }), + }), + start_time: event.context.start_time, + }); + + if log_event_result.is_err() { + log::error!( + "Error calling log_event_callback for event id: {}", + event.event_id + ); + } + + log_event_result?; + } + + // TODO: do the redaction + + // Redact the event + event = redact_event(event, &self.api_config.config); + let tx = self .tx .lock() @@ -153,3 +248,81 @@ impl ThreadedTracer { Ok(()) } } + +fn redact_event(mut event: LogSchema, api_config: &APIConfig) -> LogSchema { + let redaction_enabled = api_config.log_redaction_enabled(); + let placeholder = api_config.log_redaction_placeholder(); + + if !redaction_enabled { + return event; + } + + let placeholder = placeholder + .replace("{root_event.id}", &event.root_event_id) + .replace("{event.id}", &event.event_id); + + // Redact LLMOutputModel raw_text + if let Some(metadata) = &mut event.metadata { + match metadata { + MetadataType::Single(llm_event) => { + if let Some(output) = &mut llm_event.output { + output.raw_text = placeholder.clone(); + } + } + MetadataType::Multi(llm_events) => { + for llm_event in llm_events { + if let Some(output) = &mut llm_event.output { + output.raw_text = placeholder.clone(); + } + } + } + } + } + + // Redact input IO + if let Some(input) = &mut event.io.input { + match &mut input.value { + ValueType::String(s) => *s = placeholder.clone(), + ValueType::List(v) => v.iter_mut().for_each(|s| *s = placeholder.clone()), + } + } + + // Redact output IO + if let Some(output) = &mut event.io.output { + match &mut output.value { + ValueType::String(s) => *s = placeholder.clone(), + ValueType::List(v) => v.iter_mut().for_each(|s| *s = placeholder.clone()), + } + } + + // Redact LLMEventInput Template + if let Some(metadata) = &mut event.metadata { + match metadata { + MetadataType::Single(llm_event) => { + redact_template(&mut llm_event.input.prompt.template, &placeholder); + } + MetadataType::Multi(llm_events) => { + for llm_event in llm_events { + redact_template(&mut llm_event.input.prompt.template, &placeholder); + } + } + } + } + + event +} + +fn redact_template(template: &mut Template, placeholder: &str) { + match template { + Template::Single(s) => *s = placeholder.to_string(), + Template::Multiple(chats) => { + for chat in chats { + for part in &mut chat.content { + if let ContentPart::Text(s) = part { + *s = placeholder.to_string(); + } + } + } + } + } +} diff --git a/engine/baml-runtime/src/tracing/wasm_tracer.rs b/engine/baml-runtime/src/tracing/wasm_tracer.rs index 3f7c25904..8046d5d88 100644 --- a/engine/baml-runtime/src/tracing/wasm_tracer.rs +++ b/engine/baml-runtime/src/tracing/wasm_tracer.rs @@ -4,10 +4,15 @@ use super::api_wrapper::{core_types::LogSchema, APIWrapper, BoundaryAPI}; pub(super) struct NonThreadedTracer { options: APIWrapper, + log_event_callback: Option Result<()> + Send>>, } impl NonThreadedTracer { - pub fn new(api_config: &APIWrapper, _max_batch_size: usize) -> Self { + pub fn new( + api_config: &APIWrapper, + _max_batch_size: usize, + log_event_callback: Option Result<()> + Send>>, + ) -> Self { Self { options: api_config.clone(), } diff --git a/engine/baml-runtime/src/types/mod.rs b/engine/baml-runtime/src/types/mod.rs index b904ddfdd..c7eefb381 100644 --- a/engine/baml-runtime/src/types/mod.rs +++ b/engine/baml-runtime/src/types/mod.rs @@ -1,5 +1,6 @@ mod context_manager; mod expression_helper; +pub mod on_log_event; mod response; pub(crate) mod runtime_context; mod stream; diff --git a/engine/baml-runtime/src/types/on_log_event.rs b/engine/baml-runtime/src/types/on_log_event.rs new file mode 100644 index 000000000..2d2bd1b91 --- /dev/null +++ b/engine/baml-runtime/src/types/on_log_event.rs @@ -0,0 +1,22 @@ +use anyhow::Error; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct LogEvent { + pub metadata: LogEventMetadata, + pub prompt: Option, + pub raw_output: Option, + // json structure or a string + pub parsed_output: Option, + pub start_time: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] + +pub struct LogEventMetadata { + pub event_id: String, + pub parent_id: Option, + pub root_event_id: String, +} + +pub type LogEventCallbackSync = Box Result<(), Error> + Send + Sync>; diff --git a/engine/language-client-codegen/src/python/templates/tracing.py.j2 b/engine/language-client-codegen/src/python/templates/tracing.py.j2 index 9dba027cf..07a25cf84 100644 --- a/engine/language-client-codegen/src/python/templates/tracing.py.j2 +++ b/engine/language-client-codegen/src/python/templates/tracing.py.j2 @@ -3,6 +3,7 @@ from .globals import DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX trace = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.trace_fn set_tags = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.upsert_tags flush = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.flush +on_log_event = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.on_log_event -__all__ = ['trace', 'set_tags', "flush"] +__all__ = ['trace', 'set_tags', "flush", "on_log_event"] diff --git a/engine/language-client-codegen/src/typescript/templates/tracing.ts.j2 b/engine/language-client-codegen/src/typescript/templates/tracing.ts.j2 index c60bc422b..6fe8887c5 100644 --- a/engine/language-client-codegen/src/typescript/templates/tracing.ts.j2 +++ b/engine/language-client-codegen/src/typescript/templates/tracing.ts.j2 @@ -4,5 +4,7 @@ const traceAsync = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.trac const traceSync = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.traceFnSync.bind(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX) const setTags = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.upsertTags.bind(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX) const flush = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.flush.bind(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX) +const onLogEvent = (...args: Parameters) => + DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.onLogEvent(...args); -export { traceAsync, traceSync, setTags, flush } \ No newline at end of file +export { traceAsync, traceSync, setTags, flush, onLogEvent } \ No newline at end of file diff --git a/engine/language_client_python/python_src/baml_py/baml_py.pyi b/engine/language_client_python/python_src/baml_py/baml_py.pyi index cd14454d3..9c2ff24a4 100644 --- a/engine/language_client_python/python_src/baml_py/baml_py.pyi +++ b/engine/language_client_python/python_src/baml_py/baml_py.pyi @@ -79,6 +79,27 @@ class BamlRuntime: ) -> FunctionResultStream: ... def create_context_manager(self) -> RuntimeContextManager: ... def flush(self) -> None: ... + def set_log_event_callback(self, handler: Callable[[BamlLogEvent], None]) -> None: ... + + +class LogEventMetadata: + event_id: str + parent_id: Optional[str] + root_event_id: str + + def __init__(self, event_id: str, parent_id: Optional[str], root_event_id: str) -> None: + ... + +class BamlLogEvent: + metadata: LogEventMetadata + prompt: Optional[str] + raw_output: Optional[str] + parsed_output: Optional[str] + start_time: str + + def __init__(self, metadata: LogEventMetadata, prompt: Optional[str], raw_output: Optional[str], parsed_output: Optional[str], start_time: str) -> None: + ... + class BamlSpan: @staticmethod diff --git a/engine/language_client_python/python_src/baml_py/ctx_manager.py b/engine/language_client_python/python_src/baml_py/ctx_manager.py index 2230c23e9..01d1fd461 100644 --- a/engine/language_client_python/python_src/baml_py/ctx_manager.py +++ b/engine/language_client_python/python_src/baml_py/ctx_manager.py @@ -65,6 +65,9 @@ def end_trace(self, span: BamlSpan, response: typing.Any) -> None: def flush(self) -> None: self.rt.flush() + def on_log_event(self, handler: typing.Callable[[str], None]) -> None: + self.rt.set_log_event_callback(handler) + def trace_fn(self, func: F) -> F: func_name = func.__name__ signature = inspect.signature(func).parameters diff --git a/engine/language_client_python/src/types/runtime.rs b/engine/language_client_python/src/types/runtime.rs index c01761fcb..25140f042 100644 --- a/engine/language_client_python/src/types/runtime.rs +++ b/engine/language_client_python/src/types/runtime.rs @@ -8,12 +8,68 @@ use super::type_builder::TypeBuilder; use baml_runtime::runtime_interface::ExperimentalTracingInterface; use baml_runtime::BamlRuntime as CoreBamlRuntime; use pyo3::prelude::{pymethods, PyResult}; -use pyo3::{PyObject, Python, ToPyObject}; +use pyo3::types::{PyDict, PyTuple}; +use pyo3::{pyclass, PyObject, Python, ToPyObject}; use std::collections::HashMap; use std::path::PathBuf; crate::lang_wrapper!(BamlRuntime, CoreBamlRuntime, clone_safe); +#[derive(Debug, Clone)] +#[pyclass] +pub struct BamlLogEvent { + pub metadata: LogEventMetadata, + pub prompt: Option, + pub raw_output: Option, + // json structure or a string + pub parsed_output: Option, + pub start_time: String, +} + +#[derive(Debug, Clone)] +#[pyclass] +pub struct LogEventMetadata { + pub event_id: String, + pub parent_id: Option, + pub root_event_id: String, +} + +#[pymethods] +impl BamlLogEvent { + fn __repr__(&self) -> String { + format!( + "BamlLogEvent {{\n metadata: {:?},\n prompt: {:?},\n raw_output: {:?},\n parsed_output: {:?},\n start_time: {:?}\n}}", + self.metadata, self.prompt, self.raw_output, self.parsed_output, self.start_time + ) + } + + fn __str__(&self) -> String { + let prompt = self + .prompt + .as_ref() + .map_or("None".to_string(), |p| format!("\"{p}\"")); + let raw_output = self + .raw_output + .as_ref() + .map_or("None".to_string(), |r| format!("\"{r}\"")); + let parsed_output = self + .parsed_output + .as_ref() + .map_or("None".to_string(), |p| format!("\"{p}\"")); + + format!( + "BamlLogEvent {{\n metadata: {{\n event_id: \"{}\",\n parent_id: {},\n root_event_id: \"{}\"\n }},\n prompt: {},\n raw_output: {},\n parsed_output: {},\n start_time: \"{}\"\n}}", + self.metadata.event_id, + self.metadata.parent_id.as_ref().map_or("None".to_string(), |id| format!("\"{}\"", id)), + self.metadata.root_event_id, + prompt, + raw_output, + parsed_output, + self.start_time + ) + } +} + #[pymethods] impl BamlRuntime { #[staticmethod] @@ -122,4 +178,39 @@ impl BamlRuntime { fn flush(&self) -> PyResult<()> { self.inner.flush().map_err(BamlError::from_anyhow) } + + #[pyo3()] + fn set_log_event_callback(&self, callback: PyObject) -> PyResult<()> { + let callback = callback.clone(); + let baml_runtime = self.inner.clone(); + + let res = baml_runtime + .as_ref() + .set_log_event_callback(Box::new(move |log_event| { + Python::with_gil(|py| { + match callback.call1( + py, + (BamlLogEvent { + metadata: LogEventMetadata { + event_id: log_event.metadata.event_id.clone(), + parent_id: log_event.metadata.parent_id.clone(), + root_event_id: log_event.metadata.root_event_id.clone(), + }, + prompt: log_event.prompt.clone(), + raw_output: log_event.raw_output.clone(), + parsed_output: log_event.parsed_output.clone(), + start_time: log_event.start_time.clone(), + },), + ) { + Ok(_) => Ok(()), + Err(e) => { + log::error!("Error calling log_event_callback: {:?}", e); + Err(anyhow::Error::new(e).into()) // Proper error handling + } + } + }) + })); + + Ok(()) + } } diff --git a/engine/language_client_typescript/async_context_vars.d.ts b/engine/language_client_typescript/async_context_vars.d.ts index 36f3f1ebb..377f6ca7f 100644 --- a/engine/language_client_typescript/async_context_vars.d.ts +++ b/engine/language_client_typescript/async_context_vars.d.ts @@ -1,3 +1,4 @@ +import { BamlLogEvent } from '../native'; import { BamlSpan, RuntimeContextManager, BamlRuntime } from './native'; export declare class CtxManager { private rt; @@ -9,6 +10,7 @@ export declare class CtxManager { startTraceAsync(name: string, args: Record): BamlSpan; endTrace(span: BamlSpan, response: any): void; flush(): void; + onLogEvent(callback: (error: any, event: BamlLogEvent) => void): void; traceFnSync ReturnType>(name: string, func: F): F; traceFnAync Promise>(name: string, func: F): F; } diff --git a/engine/language_client_typescript/async_context_vars.d.ts.map b/engine/language_client_typescript/async_context_vars.d.ts.map index 7642518da..007d895f7 100644 --- a/engine/language_client_typescript/async_context_vars.d.ts.map +++ b/engine/language_client_typescript/async_context_vars.d.ts.map @@ -1 +1 @@ -{"version":3,"file":"async_context_vars.d.ts","sourceRoot":"","sources":["typescript_src/async_context_vars.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,QAAQ,EAAE,qBAAqB,EAAE,WAAW,EAAE,MAAM,UAAU,CAAA;AAGvE,qBAAa,UAAU;IACrB,OAAO,CAAC,EAAE,CAAa;IACvB,OAAO,CAAC,GAAG,CAA0C;gBAEzC,EAAE,EAAE,WAAW;IAS3B,UAAU,CAAC,IAAI,EAAE,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,GAAG,IAAI;IAK9C,GAAG,IAAI,qBAAqB;IAS5B,cAAc,CAAC,IAAI,EAAE,MAAM,EAAE,IAAI,EAAE,MAAM,CAAC,MAAM,EAAE,GAAG,CAAC,GAAG,QAAQ;IAOjE,eAAe,CAAC,IAAI,EAAE,MAAM,EAAE,IAAI,EAAE,MAAM,CAAC,MAAM,EAAE,GAAG,CAAC,GAAG,QAAQ;IAOlE,QAAQ,CAAC,IAAI,EAAE,QAAQ,EAAE,QAAQ,EAAE,GAAG,GAAG,IAAI;IAS7C,KAAK,IAAI,IAAI;IAIb,WAAW,CAAC,UAAU,EAAE,CAAC,SAAS,CAAC,GAAG,IAAI,EAAE,GAAG,EAAE,KAAK,UAAU,EAAE,IAAI,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,GAAG,CAAC;IAsB3F,WAAW,CAAC,UAAU,EAAE,CAAC,SAAS,CAAC,GAAG,IAAI,EAAE,GAAG,EAAE,KAAK,OAAO,CAAC,UAAU,CAAC,EAAE,IAAI,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,GAAG,CAAC;CAqBrG"} \ No newline at end of file +{"version":3,"file":"async_context_vars.d.ts","sourceRoot":"","sources":["typescript_src/async_context_vars.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,YAAY,EAAE,MAAM,WAAW,CAAA;AACxC,OAAO,EAAE,QAAQ,EAAE,qBAAqB,EAAE,WAAW,EAAE,MAAM,UAAU,CAAA;AAGvE,qBAAa,UAAU;IACrB,OAAO,CAAC,EAAE,CAAa;IACvB,OAAO,CAAC,GAAG,CAA0C;gBAEzC,EAAE,EAAE,WAAW;IAS3B,UAAU,CAAC,IAAI,EAAE,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,GAAG,IAAI;IAK9C,GAAG,IAAI,qBAAqB;IAS5B,cAAc,CAAC,IAAI,EAAE,MAAM,EAAE,IAAI,EAAE,MAAM,CAAC,MAAM,EAAE,GAAG,CAAC,GAAG,QAAQ;IAOjE,eAAe,CAAC,IAAI,EAAE,MAAM,EAAE,IAAI,EAAE,MAAM,CAAC,MAAM,EAAE,GAAG,CAAC,GAAG,QAAQ;IAOlE,QAAQ,CAAC,IAAI,EAAE,QAAQ,EAAE,QAAQ,EAAE,GAAG,GAAG,IAAI;IAS7C,KAAK,IAAI,IAAI;IAIb,UAAU,CAAC,QAAQ,EAAE,CAAC,KAAK,EAAE,GAAG,EAAE,KAAK,EAAE,YAAY,KAAK,IAAI,GAAG,IAAI;IAIrE,WAAW,CAAC,UAAU,EAAE,CAAC,SAAS,CAAC,GAAG,IAAI,EAAE,GAAG,EAAE,KAAK,UAAU,EAAE,IAAI,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,GAAG,CAAC;IAsB3F,WAAW,CAAC,UAAU,EAAE,CAAC,SAAS,CAAC,GAAG,IAAI,EAAE,GAAG,EAAE,KAAK,OAAO,CAAC,UAAU,CAAC,EAAE,IAAI,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,GAAG,CAAC;CAqBrG"} \ No newline at end of file diff --git a/engine/language_client_typescript/async_context_vars.js b/engine/language_client_typescript/async_context_vars.js index 62cc8a2f1..709146da8 100644 --- a/engine/language_client_typescript/async_context_vars.js +++ b/engine/language_client_typescript/async_context_vars.js @@ -49,6 +49,9 @@ class CtxManager { flush() { this.rt.flush(); } + onLogEvent(callback) { + this.rt.setLogEventCallback(callback); + } traceFnSync(name, func) { return ((...args) => { const params = args.reduce((acc, arg, i) => ({ diff --git a/engine/language_client_typescript/native.d.ts b/engine/language_client_typescript/native.d.ts index 3bc766bfd..494513884 100644 --- a/engine/language_client_typescript/native.d.ts +++ b/engine/language_client_typescript/native.d.ts @@ -24,6 +24,7 @@ export class BamlRuntime { createContextManager(): RuntimeContextManager callFunction(functionName: string, args: { [string]: any }, ctx: RuntimeContextManager, tb?: TypeBuilder | undefined | null): Promise streamFunction(functionName: string, args: { [string]: any }, cb: (err: any, param: FunctionResult) => void, ctx: RuntimeContextManager, tb?: TypeBuilder | undefined | null): FunctionResultStream + setLogEventCallback(func: (err: any, param: BamlLogEvent) => void): void flush(): void } @@ -87,5 +88,19 @@ export class TypeBuilder { null(): FieldType } +export interface BamlLogEvent { + metadata: LogEventMetadata + prompt?: string + rawOutput?: string + parsedOutput?: string + startTime: string +} + export function invoke_runtime_cli(params: Array): void +export interface LogEventMetadata { + eventId: string + parentId?: string + rootEventId: string +} + diff --git a/engine/language_client_typescript/src/types/lang_wrappers.rs b/engine/language_client_typescript/src/types/lang_wrappers.rs index 31988fad5..bb99ee41c 100644 --- a/engine/language_client_typescript/src/types/lang_wrappers.rs +++ b/engine/language_client_typescript/src/types/lang_wrappers.rs @@ -17,6 +17,23 @@ macro_rules! lang_wrapper { } }; + ($name:ident, $type:ty, clone_safe, custom_finalize $(, $attr_name:ident : $attr_type:ty = $default:expr)*) => { + #[napi_derive::napi(custom_finalize)] + pub struct $name { + pub(crate) inner: std::sync::Arc<$type>, + $($attr_name: $attr_type),* + } + + impl From<$type> for $name { + fn from(inner: $type) -> Self { + Self { + inner: std::sync::Arc::new(inner), + $($attr_name: $default),* + } + } + } + }; + ($name:ident, $type:ty, sync_thread_safe $(, $attr_name:ident : $attr_type:ty)*) => { #[napi_derive::napi] pub struct $name { diff --git a/engine/language_client_typescript/src/types/runtime.rs b/engine/language_client_typescript/src/types/runtime.rs index 0bc331170..590830a2b 100644 --- a/engine/language_client_typescript/src/types/runtime.rs +++ b/engine/language_client_typescript/src/types/runtime.rs @@ -3,17 +3,45 @@ use super::runtime_ctx_manager::RuntimeContextManager; use super::type_builder::TypeBuilder; use crate::parse_ts_types; use crate::types::function_results::FunctionResult; +use baml_runtime::on_log_event::{LogEvent, LogEventCallbackSync}; use baml_runtime::runtime_interface::ExperimentalTracingInterface; use baml_runtime::BamlRuntime as CoreRuntime; use baml_types::BamlValue; -use napi::Env; +use napi::bindgen_prelude::ObjectFinalize; +use napi::threadsafe_function::{ThreadSafeCallContext, ThreadsafeFunctionCallMode}; use napi::JsFunction; use napi::JsObject; +use napi::{Env, JsUndefined}; use napi_derive::napi; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::path::PathBuf; -crate::lang_wrapper!(BamlRuntime, CoreRuntime, clone_safe); +crate::lang_wrapper!(BamlRuntime, + CoreRuntime, + clone_safe, + custom_finalize, + callback: Option> = None +); + +#[napi(object)] +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct LogEventMetadata { + pub event_id: String, + pub parent_id: Option, + pub root_event_id: String, +} + +#[napi(object)] +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct BamlLogEvent { + pub metadata: LogEventMetadata, + pub prompt: Option, + pub raw_output: Option, + // json structure or a string + pub parsed_output: Option, + pub start_time: String, +} #[napi] impl BamlRuntime { @@ -123,9 +151,97 @@ impl BamlRuntime { } #[napi] - pub fn flush(&self) -> napi::Result<()> { - self.inner + pub fn set_log_event_callback( + &mut self, + env: Env, + #[napi(ts_arg_type = "(err: any, param: BamlLogEvent) => void")] func: JsFunction, + ) -> napi::Result { + let cb = env.create_reference(func)?; + // let prev = self.callback.take(); + // if let Some(mut old_cb) = prev { + // old_cb.unref(env)?; + // } + self.callback = Some(cb); + + let res = match &self.callback { + Some(callback_ref) => { + let cb = env.get_reference_value::(callback_ref)?; + let mut tsfn = env.create_threadsafe_function( + &cb, + 0, + |ctx: ThreadSafeCallContext| { + Ok(vec![BamlLogEvent::from(ctx.value)]) + }, + )?; + let tsfn_clone = tsfn.clone(); + + let res = self + .inner + .set_log_event_callback(Box::new(move |event: LogEvent| { + // let env = callback.env; + let event = BamlLogEvent { + metadata: LogEventMetadata { + event_id: event.metadata.event_id, + parent_id: event.metadata.parent_id, + root_event_id: event.metadata.root_event_id, + }, + prompt: event.prompt, + raw_output: event.raw_output, + parsed_output: event.parsed_output, + start_time: event.start_time, + }; + + let res = tsfn_clone.call(Ok(event), ThreadsafeFunctionCallMode::Blocking); + if res != napi::Status::Ok { + log::error!("Error calling on_log_event callback: {:?}", res); + } + + Ok(()) + })) + .map_err(|e| napi::Error::new(napi::Status::GenericFailure, e.to_string())); + let _ = tsfn.unref(&env); + + match res { + Ok(_) => Ok(()), + Err(e) => { + log::error!("Error setting log_event_callback: {:?}", e); + Err(e) + } + } + } + None => Ok(()), + }; + + let _ = match res { + Ok(_) => Ok(env.get_undefined()?), + Err(e) => { + log::error!("Error setting log_event_callback: {:?}", e); + Err(e) + } + }; + + env.get_undefined() + } + + #[napi] + pub fn flush(&mut self, env: Env) -> napi::Result<()> { + let res = self + .inner .flush() - .map_err(|e| napi::Error::new(napi::Status::GenericFailure, e.to_string())) + .map_err(|e| napi::Error::new(napi::Status::GenericFailure, e.to_string())); + + res + } +} + +impl ObjectFinalize for BamlRuntime { + fn finalize(mut self, env: Env) -> napi::Result<()> { + if let Some(mut cb) = self.callback.take() { + match cb.unref(env) { + Ok(_) => (), + Err(e) => log::error!("Error unrefing callback: {:?}", e), + } + } + Ok(()) } } diff --git a/engine/language_client_typescript/typescript_src/async_context_vars.ts b/engine/language_client_typescript/typescript_src/async_context_vars.ts index c2a31d129..076e5c3c8 100644 --- a/engine/language_client_typescript/typescript_src/async_context_vars.ts +++ b/engine/language_client_typescript/typescript_src/async_context_vars.ts @@ -1,3 +1,4 @@ +import { BamlLogEvent } from '../native' import { BamlSpan, RuntimeContextManager, BamlRuntime } from './native' import { AsyncLocalStorage } from 'async_hooks' @@ -55,6 +56,10 @@ export class CtxManager { this.rt.flush() } + onLogEvent(callback: (error: any, event: BamlLogEvent) => void): void { + this.rt.setLogEventCallback(callback) + } + traceFnSync ReturnType>(name: string, func: F): F { return ((...args: any[]) => { const params = args.reduce( diff --git a/integ-tests/python/baml_client/tracing.py b/integ-tests/python/baml_client/tracing.py index b536ee8eb..3ac276994 100644 --- a/integ-tests/python/baml_client/tracing.py +++ b/integ-tests/python/baml_client/tracing.py @@ -18,6 +18,7 @@ trace = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.trace_fn set_tags = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.upsert_tags flush = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.flush +on_log_event = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.on_log_event -__all__ = ['trace', 'set_tags', "flush"] \ No newline at end of file +__all__ = ['trace', 'set_tags', "flush", "on_log_event"] \ No newline at end of file diff --git a/integ-tests/python/test_functions.py b/integ-tests/python/test_functions.py index 8372daa9d..38fe15078 100644 --- a/integ-tests/python/test_functions.py +++ b/integ-tests/python/test_functions.py @@ -7,7 +7,7 @@ import baml_py from baml_client import b from baml_client.types import NamedArgsSingleEnumList, NamedArgsSingleClass -from baml_client.tracing import trace, set_tags, flush +from baml_client.tracing import trace, set_tags, flush, on_log_event from baml_client.type_builder import TypeBuilder import datetime @@ -15,7 +15,7 @@ @pytest.mark.asyncio async def test_should_work_for_all_inputs(): res = await b.TestFnNamedArgsSingleBool(True) - assert res == "true" + assert res res = await b.TestFnNamedArgsSingleStringList(["a", "b", "c"]) assert "a" in res and "b" in res and "c" in res @@ -169,12 +169,11 @@ async def test_streaming(): assert len(final) > 0, "Expected non-empty final but got empty." assert len(msgs) > 0, "Expected at least one streamed response but got none." for prev_msg, msg in zip(msgs, msgs[1:]): - assert msg.startswith(prev_msg), ( - "Expected messages to be continuous, but prev was %r and next was %r" - % ( - prev_msg, - msg, - ) + assert msg.startswith( + prev_msg + ), "Expected messages to be continuous, but prev was %r and next was %r" % ( + prev_msg, + msg, ) assert msgs[-1] == final, "Expected last stream message to match final response." @@ -198,12 +197,11 @@ async def test_streaming_claude(): assert len(final) > 0, "Expected non-empty final but got empty." assert len(msgs) > 0, "Expected at least one streamed response but got none." for prev_msg, msg in zip(msgs, msgs[1:]): - assert msg.startswith(prev_msg), ( - "Expected messages to be continuous, but prev was %r and next was %r" - % ( - prev_msg, - msg, - ) + assert msg.startswith( + prev_msg + ), "Expected messages to be continuous, but prev was %r and next was %r" % ( + prev_msg, + msg, ) print("msgs:") print(msgs[-1]) @@ -223,12 +221,11 @@ async def test_streaming_gemini(): assert len(final) > 0, "Expected non-empty final but got empty." assert len(msgs) > 0, "Expected at least one streamed response but got none." for prev_msg, msg in zip(msgs, msgs[1:]): - assert msg.startswith(prev_msg), ( - "Expected messages to be continuous, but prev was %r and next was %r" - % ( - prev_msg, - msg, - ) + assert msg.startswith( + prev_msg + ), "Expected messages to be continuous, but prev was %r and next was %r" % ( + prev_msg, + msg, ) print("msgs:") print(msgs[-1]) @@ -483,3 +480,14 @@ async def test_nested_class_streaming(): assert len(msgs) > 0, "Expected at least one streamed response but got none." print("final ", final.model_dump(mode="json")) + + +@pytest.mark.asyncio +async def test_event_log_hook(): + def event_log_hook(event): + print("Event log hook1: ") + print("Event log event ", event) + + on_log_event(event_log_hook) + res = await b.TestFnNamedArgsSingleStringList(["a", "b", "c"]) + assert res diff --git a/integ-tests/typescript/baml_client/tracing.ts b/integ-tests/typescript/baml_client/tracing.ts index 6b6b708bd..7facf3689 100644 --- a/integ-tests/typescript/baml_client/tracing.ts +++ b/integ-tests/typescript/baml_client/tracing.ts @@ -21,5 +21,7 @@ const traceAsync = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.trac const traceSync = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.traceFnSync.bind(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX) const setTags = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.upsertTags.bind(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX) const flush = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.flush.bind(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX) +const onLogEvent = (...args: Parameters) => + DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.onLogEvent(...args); -export { traceAsync, traceSync, setTags, flush } \ No newline at end of file +export { traceAsync, traceSync, setTags, flush, onLogEvent } \ No newline at end of file diff --git a/integ-tests/typescript/package.json b/integ-tests/typescript/package.json index 7a9c721d9..7762bb3fa 100644 --- a/integ-tests/typescript/package.json +++ b/integ-tests/typescript/package.json @@ -5,9 +5,10 @@ "main": "index.js", "scripts": { "test": "jest", - "build": "cd ../../clients/ts && npm run build && cd - && pnpm i", - "integ-tests": "BAML_LOG=baml_events infisical run --env=test -- pnpm test -- --silent false --testTimeout 30000", - "integ-tests:dotenv": "BAML_LOG=baml_events dotenv -e ../.env -- pnpm test -- --silent false --testTimeout 30000", + "build:debug": "cd ../../engine/language_client_typescript && pnpm run build:debug && cd - && pnpm i", + "build": "cd ../../engine/language_client_typescript && npm run build && cd - && pnpm i", + "integ-tests": "BAML_LOG=info infisical run --env=dev -- pnpm test -- --silent false --testTimeout 30000", + "integ-tests:dotenv": "BAML_LOG=info dotenv -e ../.env -- pnpm test -- --silent false --testTimeout 30000", "generate": "baml-cli generate --from ../baml_src" }, "keywords": [], diff --git a/integ-tests/typescript/tests/integ-tests.test.ts b/integ-tests/typescript/tests/integ-tests.test.ts index 1810e47d1..3e17d644f 100644 --- a/integ-tests/typescript/tests/integ-tests.test.ts +++ b/integ-tests/typescript/tests/integ-tests.test.ts @@ -2,9 +2,18 @@ import assert from 'assert' import { image_b64, audio_b64 } from './base64_test_data' import { Image} from '@boundaryml/baml' import { Audio } from '@boundaryml/baml' -import { b, NamedArgsSingleEnumList, flush, traceAsync, traceSync, setTags, TestClassNested } from '../baml_client' -import TypeBuilder from "../baml_client/type_builder"; -import { RecursivePartialNull } from '../baml_client/client'; +import { + b, + NamedArgsSingleEnumList, + flush, + traceAsync, + traceSync, + setTags, + TestClassNested, + onLogEvent, +} from '../baml_client' +import TypeBuilder from '../baml_client/type_builder' +import { RecursivePartialNull } from '../baml_client/client' describe('Integ tests', () => { @@ -179,8 +188,6 @@ describe('Integ tests', () => { expect(msgs.at(-1)).toEqual(final) }) - - it('supports tracing sync', async () => { const blah = 'blah' @@ -231,30 +238,45 @@ describe('Integ tests', () => { }) it('should work with dynamics', async () => { - let tb = new TypeBuilder(); - tb.Person.addProperty("last_name", tb.string().optional()); - tb.Person.addProperty("height", tb.float().optional()).description("Height in meters"); - tb.Hobby.addValue("CHESS") + let tb = new TypeBuilder() + tb.Person.addProperty('last_name', tb.string().optional()) + tb.Person.addProperty('height', tb.float().optional()).description('Height in meters') + tb.Hobby.addValue('CHESS') tb.Hobby.listValues().map(([name, v]) => v.alias(name.toLowerCase())) - tb.Person.addProperty("hobbies", tb.Hobby.type().list().optional()).description("Some suggested hobbies they might be good at"); + tb.Person.addProperty('hobbies', tb.Hobby.type().list().optional()).description( + 'Some suggested hobbies they might be good at', + ) - const res = await b.ExtractPeople("My name is Harrison. My hair is black and I'm 6 feet tall. I'm pretty good around the hoop.", { tb }) + const res = await b.ExtractPeople( + "My name is Harrison. My hair is black and I'm 6 feet tall. I'm pretty good around the hoop.", + { tb }, + ) expect(res.length).toBeGreaterThan(0) console.log(res) }) it('should work with nested classes', async () => { - let stream = b.stream.FnOutputClassNested('hi!'); - let msgs: RecursivePartialNull = []; + let stream = b.stream.FnOutputClassNested('hi!') + let msgs: RecursivePartialNull = [] for await (const msg of stream) { console.log('msg', msg) - msgs.push(msg); + msgs.push(msg) } const final = await stream.getFinalResponse() expect(msgs.length).toBeGreaterThan(0) expect(msgs.at(-1)).toEqual(final) }) + + it("should work with 'onLogEvent'", async () => { + onLogEvent((error: any, param2) => { + console.log('msg', error, 'param2', param2) + }) + const res = await b.TestFnNamedArgsSingleStringList(['a', 'b', 'c']) + expect(res).toContain('a') + const res2 = await b.TestFnNamedArgsSingleStringList(['d', 'e', 'f']) + expect(res2).toContain('d') + }) }) function asyncDummyFunc(myArg: string): Promise {