Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

logger callback #715

Merged
merged 5 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pub async fn make_request(
stream: bool,
) -> Result<(Response, web_time::SystemTime, web_time::Instant), LLMResponse> {
let (system_now, instant_now) = (web_time::SystemTime::now(), web_time::Instant::now());
log::info!("Making request using client {}", client.context().name);
log::debug!("Making request using client {}", client.context().name);

let req = match client.build_request(prompt, stream).build() {
Ok(req) => req,
Expand Down
8 changes: 8 additions & 0 deletions engine/baml-runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@ use baml_types::BamlMap;
use baml_types::BamlValue;
use indexmap::IndexMap;
use internal_baml_core::configuration::GeneratorOutputType;
use on_log_event::LogEventCallbackSync;
use runtime::InternalBamlRuntime;

#[cfg(not(target_arch = "wasm32"))]
pub use cli::CallerType;
use runtime_interface::ExperimentalTracingInterface;
use runtime_interface::RuntimeConstructor;
use runtime_interface::RuntimeInterface;
use tracing::api_wrapper::core_types::LogSchema;
use tracing::{BamlTracer, TracingSpan};
use type_builder::TypeBuilder;
pub use types::*;
Expand Down Expand Up @@ -329,4 +331,10 @@ impl ExperimentalTracingInterface for BamlRuntime {
fn flush(&self) -> Result<()> {
self.tracer.flush()
}

#[cfg(not(target_arch = "wasm32"))]
fn set_log_event_callback(&self, log_event_callback: LogEventCallbackSync) -> Result<()> {
self.tracer.set_log_event_callback(log_event_callback);
Ok(())
}
}
5 changes: 5 additions & 0 deletions engine/baml-runtime/src/runtime_interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ use std::{collections::HashMap, sync::Arc};

use crate::internal::llm_client::llm_provider::LLMProvider;
use crate::internal::llm_client::orchestrator::{OrchestrationScope, OrchestratorNode};
use crate::tracing::api_wrapper::core_types::LogSchema;
use crate::tracing::{BamlTracer, TracingSpan};
use crate::type_builder::TypeBuilder;
use crate::types::on_log_event::LogEventCallbackSync;
use crate::RuntimeContextManager;
use crate::{
internal::{ir_features::IrFeatures, llm_client::retry_policy::CallablePolicy},
Expand Down Expand Up @@ -94,6 +96,9 @@ pub trait ExperimentalTracingInterface {
) -> Result<Option<uuid::Uuid>>;

fn flush(&self) -> Result<()>;

#[cfg(not(target_arch = "wasm32"))]
fn set_log_event_callback(&self, callback: LogEventCallbackSync) -> Result<()>;
}

pub trait InternalClientLookup<'a> {
Expand Down
54 changes: 27 additions & 27 deletions engine/baml-runtime/src/tracing/api_wrapper/core_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;

#[derive(Serialize, Debug)]
pub(crate) struct UpdateTestCase {
pub struct UpdateTestCase {
pub project_id: Option<String>,
pub test_cycle_id: String,
pub test_dataset_name: String,
Expand All @@ -15,8 +15,8 @@ pub(crate) struct UpdateTestCase {
pub error_data: Option<Value>, // Rust doesn't have a direct equivalent of Python's Any type, so we use serde_json::Value
}

#[derive(Serialize, Debug)]
pub(crate) struct LogSchema {
#[derive(Serialize, Debug, Clone)]
pub struct LogSchema {
pub project_id: Option<String>,
pub event_type: EventType,
pub root_event_id: String,
Expand All @@ -29,26 +29,26 @@ pub(crate) struct LogSchema {
}

#[derive(Serialize, Debug, Clone)]
pub(crate) struct IO {
pub struct IO {
pub(crate) input: Option<IOValue>,
pub(crate) output: Option<IOValue>,
}

#[derive(Serialize, Debug, Clone)]
pub(crate) struct IOValue {
pub struct IOValue {
pub(crate) value: ValueType,
pub(crate) r#override: Option<HashMap<String, Value>>,
pub(crate) r#type: TypeSchema,
}

#[derive(Serialize, Debug, Clone)]
pub(crate) struct TypeSchema {
pub struct TypeSchema {
pub(crate) name: TypeSchemaName,
pub(crate) fields: IndexMap<String, String>,
}

#[derive(Serialize, Debug, Clone)]
pub(crate) enum TypeSchemaName {
pub enum TypeSchemaName {
#[serde(rename = "single")]
Single,
#[serde(rename = "multi")]
Expand All @@ -57,7 +57,7 @@ pub(crate) enum TypeSchemaName {

#[derive(Serialize, Debug, Clone)]
#[serde(untagged)]
pub(crate) enum ValueType {
pub enum ValueType {
String(String),
// For mutli-args, we use a list of strings
List(Vec<String>),
Expand Down Expand Up @@ -98,7 +98,7 @@ pub enum EventType {
}

#[derive(Serialize, Debug, Clone)]
pub(crate) struct LogSchemaContext {
pub struct LogSchemaContext {
pub hostname: String,
pub process_id: String,
pub stage: Option<String>,
Expand All @@ -109,7 +109,7 @@ pub(crate) struct LogSchemaContext {
}

#[derive(Serialize, Debug, Clone)]
pub(crate) struct EventChain {
pub struct EventChain {
pub function_name: String,
pub variant_name: Option<String>,
}
Expand All @@ -122,30 +122,30 @@ pub(crate) struct Error {
pub r#override: Option<HashMap<String, Value>>,
}

#[derive(Serialize, Debug, Deserialize, Default)]
pub(crate) struct LLMOutputModelMetadata {
#[derive(Serialize, Debug, Deserialize, Default, Clone)]
pub struct LLMOutputModelMetadata {
pub logprobs: Option<Value>,
pub prompt_tokens: Option<i64>,
pub output_tokens: Option<i64>,
pub total_tokens: Option<i64>,
pub finish_reason: Option<String>,
}

#[derive(Serialize, Debug)]
pub(crate) struct LLMOutputModel {
#[derive(Serialize, Debug, Clone)]
pub struct LLMOutputModel {
pub raw_text: String,
pub metadata: LLMOutputModelMetadata,
pub r#override: Option<HashMap<String, Value>>,
}

#[derive(Serialize, Debug)]
#[derive(Serialize, Debug, Clone)]
pub(crate) struct LLMChat {
pub role: Role,
pub content: Vec<ContentPart>,
}

#[derive(Serialize, Debug)]
pub(crate) enum ContentPart {
#[derive(Serialize, Debug, Clone)]
pub enum ContentPart {
#[serde(rename = "text")]
Text(String),
#[serde(rename = "url_image")]
Expand All @@ -158,9 +158,9 @@ pub(crate) enum ContentPart {
B64Audio(String),
}

#[derive(Serialize, Debug, Deserialize)]
#[derive(Serialize, Debug, Deserialize, Clone)]
#[serde(untagged)]
pub(crate) enum Role {
pub enum Role {
#[serde(rename = "assistant")]
Assistant,
#[serde(rename = "user")]
Expand All @@ -170,35 +170,35 @@ pub(crate) enum Role {
Other(String),
}

#[derive(Serialize, Debug)]
#[derive(Serialize, Debug, Clone)]
pub(crate) struct LLMEventInput {
pub prompt: LLMEventInputPrompt,
pub invocation_params: HashMap<String, Value>,
}

#[derive(Serialize, Debug)]
pub(crate) struct LLMEventSchema {
#[derive(Serialize, Debug, Clone)]
pub struct LLMEventSchema {
pub model_name: String,
pub provider: String,
pub input: LLMEventInput,
pub output: Option<LLMOutputModel>,
pub error: Option<String>,
}

#[derive(Serialize, Debug)]
#[derive(Serialize, Debug, Clone)]
#[serde(untagged)]
pub(crate) enum MetadataType {
pub enum MetadataType {
Single(LLMEventSchema),
Multi(Vec<LLMEventSchema>),
}
#[derive(Serialize, Debug)]
pub(crate) struct LLMEventInputPrompt {
#[derive(Serialize, Debug, Clone)]
pub struct LLMEventInputPrompt {
pub template: Template,
pub template_args: HashMap<String, String>,
pub r#override: Option<HashMap<String, Value>>,
}

#[derive(Serialize, Debug)]
#[derive(Serialize, Debug, Clone)]
#[serde(untagged)]
#[allow(dead_code)]
pub enum Template {
Expand Down
8 changes: 8 additions & 0 deletions engine/baml-runtime/src/tracing/api_wrapper/env_setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,20 @@ pub struct Config {
pub stage: String,
#[serde(default = "default_host_name")]
pub host_name: String,
#[serde(default)] // default is false
pub log_redaction_enabled: bool,
#[serde(default = "default_redaction_placeholder")]
pub log_redaction_placeholder: String,
}

fn default_base_url() -> String {
"https://app.boundaryml.com/api".to_string()
}

fn default_redaction_placeholder() -> String {
"<BAML_LOG_REDACTED>".to_string()
}

fn default_sessions_id() -> String {
uuid::Uuid::new_v4().to_string()
}
Expand Down
45 changes: 42 additions & 3 deletions engine/baml-runtime/src/tracing/api_wrapper/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ mod env_setup;

use anyhow::Result;
pub(super) mod api_interface;
pub(super) mod core_types;
pub(crate) mod core_types;
use instant::Duration;
use serde::{de::DeserializeOwned, Deserialize};
use serde_json::{json, Value};
Expand All @@ -14,11 +14,11 @@ use self::core_types::{TestCaseStatus, UpdateTestCase};

#[derive(Debug, Clone)]
pub struct APIWrapper {
config: APIConfig,
pub(super) config: APIConfig,
}

#[derive(Debug, Clone)]
enum APIConfig {
pub(super) enum APIConfig {
LocalOnly(PartialAPIConfig),
Web(CompleteAPIConfig),
}
Expand Down Expand Up @@ -59,6 +59,20 @@ impl APIConfig {
}
}

pub fn log_redaction_enabled(&self) -> bool {
match self {
Self::LocalOnly(config) => config.log_redaction_enabled,
Self::Web(config) => config.log_redaction_enabled,
}
}

pub fn log_redaction_placeholder(&self) -> &str {
match self {
Self::LocalOnly(config) => &config.log_redaction_placeholder,
Self::Web(config) => &config.log_redaction_placeholder,
}
}

#[allow(dead_code)]
#[allow(clippy::too_many_arguments)]
pub(crate) fn copy_from(
Expand All @@ -69,6 +83,8 @@ impl APIConfig {
sessions_id: Option<&str>,
stage: Option<&str>,
host_name: Option<&str>,
log_redaction_enabled: Option<bool>,
log_redaction_placeholder: Option<String>,
_debug_level: Option<bool>,
) -> Self {
let base_url = base_url.unwrap_or(match self {
Expand All @@ -95,6 +111,14 @@ impl APIConfig {
Self::LocalOnly(config) => &config.host_name,
Self::Web(config) => &config.host_name,
});
let log_redaction_enabled = log_redaction_enabled.unwrap_or_else(|| match self {
Self::LocalOnly(config) => config.log_redaction_enabled,
Self::Web(config) => config.log_redaction_enabled,
});
let log_redaction_placeholder = log_redaction_placeholder.unwrap_or_else(|| match self {
Self::LocalOnly(config) => config.log_redaction_placeholder.clone(),
Self::Web(config) => config.log_redaction_placeholder.clone(),
});

match (api_key, project_id) {
(Some(api_key), Some(project_id)) => Self::Web(CompleteAPIConfig {
Expand All @@ -105,6 +129,8 @@ impl APIConfig {
sessions_id: sessions_id.to_string(),
host_name: host_name.to_string(),
client: create_client().unwrap(),
log_redaction_enabled,
log_redaction_placeholder,
}),
_ => Self::LocalOnly(PartialAPIConfig {
base_url: base_url.to_string(),
Expand All @@ -113,6 +139,8 @@ impl APIConfig {
stage: stage.to_string(),
sessions_id: sessions_id.to_string(),
host_name: host_name.to_string(),
log_redaction_enabled,
log_redaction_placeholder,
}),
}
}
Expand All @@ -126,6 +154,8 @@ pub(super) struct CompleteAPIConfig {
pub stage: String,
pub sessions_id: String,
pub host_name: String,
pub log_redaction_enabled: bool,
pub log_redaction_placeholder: String,

client: reqwest::Client,
}
Expand All @@ -140,6 +170,8 @@ pub(super) struct PartialAPIConfig {
stage: String,
sessions_id: String,
host_name: String,
log_redaction_enabled: bool,
log_redaction_placeholder: String,
}

impl CompleteAPIConfig {
Expand Down Expand Up @@ -318,6 +350,9 @@ impl BoundaryTestAPI for APIWrapper {
impl APIWrapper {
pub fn from_env_vars<T: AsRef<str>>(value: impl Iterator<Item = (T, T)>) -> Self {
let config = env_setup::Config::from_env_vars(value).unwrap();
if config.log_redaction_enabled {
log::info!("Redaction enabled: {}", config.log_redaction_enabled);
}
match (&config.secret, &config.project_id) {
(Some(api_key), Some(project_id)) => Self {
config: APIConfig::Web(CompleteAPIConfig {
Expand All @@ -328,6 +363,8 @@ impl APIWrapper {
sessions_id: config.sessions_id,
host_name: config.host_name,
client: create_client().unwrap(),
log_redaction_enabled: config.log_redaction_enabled,
log_redaction_placeholder: config.log_redaction_placeholder,
}),
},
_ => Self {
Expand All @@ -338,6 +375,8 @@ impl APIWrapper {
stage: config.stage,
sessions_id: config.sessions_id,
host_name: config.host_name,
log_redaction_enabled: config.log_redaction_enabled,
log_redaction_placeholder: config.log_redaction_placeholder,
}),
},
}
Expand Down
Loading
Loading