Skip to content

Commit

Permalink
Adding various fixes (#806)
Browse files Browse the repository at this point in the history
- set_log_event_callback
   - accepts undefined to clear itself
   - fixed memory leak issues when this was called more than once
- generator version error tells you what file to change
- during streaming, we only yield parseable values (this prevents
crashes that caused streams to hang!)
  • Loading branch information
hellovai authored Jul 19, 2024
1 parent 0a950e0 commit e8c1a61
Show file tree
Hide file tree
Showing 27 changed files with 189 additions and 117 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

use anyhow::Result;
use async_std::stream::StreamExt;
use baml_types::BamlValue;
Expand Down
23 changes: 19 additions & 4 deletions engine/baml-runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use baml_types::BamlValue;
use client_registry::ClientRegistry;
use indexmap::IndexMap;
use internal_baml_core::configuration::GeneratorOutputType;
use internal_core::configuration::Generator;
use on_log_event::LogEventCallbackSync;
use runtime::InternalBamlRuntime;

Expand Down Expand Up @@ -273,15 +274,15 @@ impl BamlRuntime {
) -> Result<Vec<internal_baml_codegen::GenerateOutput>> {
use internal_baml_codegen::GenerateClient;

let client_types: Vec<(GeneratorOutputType, internal_baml_codegen::GeneratorArgs)> = self
let client_types: Vec<(&Generator, internal_baml_codegen::GeneratorArgs)> = self
.inner
.ir()
.configuration()
.generators
.iter()
.map(|(generator, _)| {
Ok((
generator.output_type.clone(),
generator,
internal_baml_codegen::GeneratorArgs::new(
generator.output_dir(),
generator.baml_src.clone(),
Expand All @@ -296,7 +297,18 @@ impl BamlRuntime {

client_types
.iter()
.map(|(client_type, args)| client_type.generate_client(self.inner.ir(), args))
.map(|(generator, args)| {
generator
.output_type
.generate_client(self.inner.ir(), args)
.map_err(|e| {
let ((line, col), _) = generator.span.line_and_column();
anyhow::anyhow!(
"Error in file {}:{line}:{col} {e}",
generator.span.file.path()
)
})
})
.collect()
}
}
Expand Down Expand Up @@ -376,7 +388,10 @@ impl ExperimentalTracingInterface for BamlRuntime {
}

#[cfg(not(target_arch = "wasm32"))]
fn set_log_event_callback(&self, log_event_callback: LogEventCallbackSync) -> Result<()> {
fn set_log_event_callback(
&self,
log_event_callback: Option<LogEventCallbackSync>,
) -> Result<()> {
self.tracer.set_log_event_callback(log_event_callback);
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion engine/baml-runtime/src/runtime_interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ pub trait ExperimentalTracingInterface {
fn drain_stats(&self) -> crate::InnerTraceStats;

#[cfg(not(target_arch = "wasm32"))]
fn set_log_event_callback(&self, callback: LogEventCallbackSync) -> Result<()>;
fn set_log_event_callback(&self, callback: Option<LogEventCallbackSync>) -> Result<()>;
}

pub trait InternalClientLookup<'a> {
Expand Down
2 changes: 1 addition & 1 deletion engine/baml-runtime/src/tracing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ impl BamlTracer {
}

#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn set_log_event_callback(&self, log_event_callback: LogEventCallbackSync) {
pub(crate) fn set_log_event_callback(&self, log_event_callback: Option<LogEventCallbackSync>) {
if let Some(tracer) = &self.tracer {
tracer.set_log_event_callback(log_event_callback);
}
Expand Down
4 changes: 2 additions & 2 deletions engine/baml-runtime/src/tracing/threaded_tracer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,11 @@ impl ThreadedTracer {
anyhow::bail!("BatchProcessor worker thread did not finish in time")
}

pub fn set_log_event_callback(&self, log_event_callback: LogEventCallbackSync) {
pub fn set_log_event_callback(&self, log_event_callback: Option<LogEventCallbackSync>) {
// Get a mutable lock on the log_event_callback
let mut callback_lock = self.log_event_callback.lock().unwrap();

*callback_lock = Some(log_event_callback);
*callback_lock = log_event_callback;
}

pub fn submit(&self, mut event: LogSchema) -> Result<()> {
Expand Down
2 changes: 1 addition & 1 deletion engine/baml-schema-wasm/src/runtime_wasm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -947,7 +947,7 @@ impl WasmRuntime {
generator_language,
is_diagnostic,
)
.map(|error| error.msg)
.map(|error| error.msg())
}

#[wasm_bindgen]
Expand Down
2 changes: 1 addition & 1 deletion engine/language-client-codegen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ fn version_check_with_error(
true,
);
match res {
Some(e) => Err(anyhow::anyhow!("Version mismatch: {}", e.msg)),
Some(e) => Err(anyhow::anyhow!("{}", e.msg())),
None => Ok(()),
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@ export class BamlSyncClient {
private runtime: BamlRuntime
private ctx_manager: BamlCtxManager

constructor(runtime: BamlRuntime, ctx_manager: BamlCtxManager) {
this.runtime = runtime
this.ctx_manager = ctx_manager
this.stream_client = new BamlStreamClient(runtime, ctx_manager)
}
constructor(private runtime: BamlRuntime, private ctx_manager: BamlCtxManager) {}

/*
* @deprecated NOT IMPLEMENTED as streaming must by async. We
Expand Down
16 changes: 11 additions & 5 deletions engine/language-client-codegen/src/version_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@ use semver::Version;

#[derive(Debug, PartialEq)]
pub struct VersionCheckError {
pub msg: String,
msg: String,
}

impl VersionCheckError {
pub fn msg(&self) -> String {
format!("Version mismatch: {}", self.msg)
}
}

#[derive(Debug, PartialEq)]
Expand Down Expand Up @@ -58,7 +64,7 @@ pub fn check_version(

if gen_version.major != runtime_version.major || gen_version.minor != runtime_version.minor {
let base_message = format!(
"Version mismatch: Generator version ({}) does not match the {} version ({}). Major and minor versions must match.",
"Generator version ({}) does not match the {} version ({}). Major and minor versions must match.",
gen_version,
match generator_type {
GeneratorType::VSCode => "VSCode extension",
Expand Down Expand Up @@ -155,7 +161,7 @@ mod tests {
fn test_mismatched_major_version_cli_python() {
let result = check_version("2.0.0", "1.0.0", GeneratorType::CLI, VersionCheckMode::Strict, GeneratorOutputType::PythonPydantic, false);
assert!(result.is_some());
let error_msg = result.unwrap().msg;
let error_msg = result.unwrap().msg();
assert!(error_msg.contains("Version mismatch"));
assert!(error_msg.contains("installed BAML CLI"));
assert!(error_msg.contains("pip install --upgrade baml-py==2.0.0"));
Expand All @@ -165,7 +171,7 @@ mod tests {
fn test_mismatched_minor_version_vscode_typescript() {
let result = check_version("1.3.0", "1.2.0", GeneratorType::VSCode, VersionCheckMode::Strict, GeneratorOutputType::Typescript, false);
assert!(result.is_some());
let error_msg = result.unwrap().msg;
let error_msg = result.unwrap().msg();
println!("{}", error_msg);
assert!(error_msg.contains("Version mismatch"));
assert!(error_msg.contains("VSCode extension"));
Expand All @@ -176,7 +182,7 @@ mod tests {
fn test_older_vscode_version_ruby() {
let result = check_version("1.3.0", "1.2.0", GeneratorType::VSCodeCLI, VersionCheckMode::Strict, GeneratorOutputType::RubySorbet, false);
assert!(result.is_some());
let error_msg = result.unwrap().msg;
let error_msg = result.unwrap().msg();
assert!(error_msg.contains("Version mismatch"));
assert!(error_msg.contains("baml package"));
assert!(error_msg.contains("gem install baml -v 1.3.0"));
Expand Down
4 changes: 3 additions & 1 deletion engine/language_client_python/python_src/baml_py/baml_py.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ class FunctionResult:

def __str__(self) -> str: ...
def parsed(self) -> Any: ...
# Returns True if the function call was successful, False otherwise
def is_ok(self) -> bool: ...

class FunctionResultStream:
"""The result of a BAML function stream.
Expand Down Expand Up @@ -106,7 +108,7 @@ class BamlRuntime:
def flush(self) -> None: ...
def drain_stats(self) -> TraceStats: ...
def set_log_event_callback(
self, handler: Callable[[BamlLogEvent], None]
self, handler: Optional[Callable[[BamlLogEvent], None]]
) -> None: ...

class LogEventMetadata:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def end_trace(self, span: BamlSpan, response: typing.Any) -> None:
def flush(self) -> None:
self.rt.flush()

def on_log_event(self, handler: typing.Callable[[BamlLogEvent], None]) -> None:
def on_log_event(
self, handler: typing.Optional[typing.Callable[[BamlLogEvent], None]]
) -> None:
self.rt.set_log_event_callback(handler)

def trace_fn(self, func: F) -> F:
Expand Down
26 changes: 10 additions & 16 deletions engine/language_client_python/python_src/baml_py/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
SyncFunctionResultStream,
RuntimeContextManager,
)
from typing import Callable, Generic, Optional, TypeVar, Union
from typing import Callable, Generic, Optional, TypeVar
import threading
import asyncio
import concurrent.futures
Expand Down Expand Up @@ -73,7 +73,8 @@ async def __aiter__(self):
event = self.__event_queue.get()
if event is None:
break
yield self.__partial_coerce(event.parsed())
if event.is_ok():
yield self.__partial_coerce(event.parsed())

async def get_final_response(self):
final = self.__drive_to_completion_in_bg()
Expand Down Expand Up @@ -107,21 +108,17 @@ def __init__(
self.__exception = None

def __enqueue(self, data: FunctionResult) -> None:
print("Enqueuing data")
self.__event_queue.put_nowait(data)

def __drive_to_completion(self) -> FunctionResult:
try:
print(f"Driving to completion: {type(self.__ffi_stream)}")
retval = self.__ffi_stream.done(self.__ctx_manager)
print(f"Setting result: {type(retval)}")
self.__result = retval
return retval
except Exception as e:
self.__exception = e
raise e
finally:
print("Putting None in queue")
self.__event_queue.put_nowait(None)

def __drive_to_completion_in_bg(self):
Expand All @@ -139,24 +136,21 @@ def __iter__(self):
while True:
event = self.__event_queue.get()
if event is None:
print("Breaking out of loop")
break
yield self.__partial_coerce(event.parsed())
if event.is_ok():
yield self.__partial_coerce(event.parsed())

def get_final_response(self):
print("Getting final response")
self.__drive_to_completion_in_bg()
if self.__task is not None:
print("Waiting for task to complete")
self.__task.join()
else:
print("Task is None")

if self.__exception is not None:
print("Raising exception")
raise self.__exception

if self.__result is None:
raise Exception("BAML Internal error: Stream did not complete successfully. Please report this issue.")

raise Exception(
"BAML Internal error: Stream did not complete successfully. Please report this issue."
)

return self.__final_coerce(self.__result.parsed())
61 changes: 34 additions & 27 deletions engine/language_client_python/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,36 +268,43 @@ impl BamlRuntime {
}

#[pyo3()]
fn set_log_event_callback(&self, callback: PyObject) -> PyResult<()> {
fn set_log_event_callback(&self, callback: Option<PyObject>) -> PyResult<()> {
let callback = callback.clone();
let baml_runtime = self.inner.clone();

baml_runtime
.as_ref()
.set_log_event_callback(Box::new(move |log_event| {
Python::with_gil(|py| {
match callback.call1(
py,
(BamlLogEvent {
metadata: LogEventMetadata {
event_id: log_event.metadata.event_id.clone(),
parent_id: log_event.metadata.parent_id.clone(),
root_event_id: log_event.metadata.root_event_id.clone(),
},
prompt: log_event.prompt.clone(),
raw_output: log_event.raw_output.clone(),
parsed_output: log_event.parsed_output.clone(),
start_time: log_event.start_time.clone(),
},),
) {
Ok(_) => Ok(()),
Err(e) => {
log::error!("Error calling log_event_callback: {:?}", e);
Err(anyhow::Error::new(e).into()) // Proper error handling
if let Some(callback) = callback {
baml_runtime
.as_ref()
.set_log_event_callback(Some(Box::new(move |log_event| {
Python::with_gil(|py| {
match callback.call1(
py,
(BamlLogEvent {
metadata: LogEventMetadata {
event_id: log_event.metadata.event_id.clone(),
parent_id: log_event.metadata.parent_id.clone(),
root_event_id: log_event.metadata.root_event_id.clone(),
},
prompt: log_event.prompt.clone(),
raw_output: log_event.raw_output.clone(),
parsed_output: log_event.parsed_output.clone(),
start_time: log_event.start_time.clone(),
},),
) {
Ok(_) => Ok(()),
Err(e) => {
log::error!("Error calling log_event_callback: {:?}", e);
Err(anyhow::Error::new(e).into()) // Proper error handling
}
}
}
})
}))
.map_err(BamlError::from_anyhow)
})
})))
.map_err(BamlError::from_anyhow)
} else {
baml_runtime
.as_ref()
.set_log_event_callback(None)
.map_err(BamlError::from_anyhow)
}
}
}
4 changes: 4 additions & 0 deletions engine/language_client_python/src/types/function_results.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ impl FunctionResult {
format!("{:#}", self.inner)
}

fn is_ok(&self) -> bool {
self.inner.parsed_content().is_ok()
}

fn parsed(&self, py: Python<'_>) -> PyResult<PyObject> {
let parsed = self
.inner
Expand Down
2 changes: 1 addition & 1 deletion engine/language_client_typescript/async_context_vars.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ export declare class BamlCtxManager {
startTrace(name: string, args: Record<string, any>): [RuntimeContextManager, BamlSpan];
endTrace(span: BamlSpan, response: any): void;
flush(): void;
onLogEvent(callback: (event: BamlLogEvent) => void): void;
onLogEvent(callback: ((event: BamlLogEvent) => void) | undefined): void;
traceFnSync<ReturnType, F extends (...args: any[]) => ReturnType>(name: string, func: F): F;
traceFnAsync<ReturnType, F extends (...args: any[]) => Promise<ReturnType>>(name: string, func: F): F;
}
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions engine/language_client_typescript/async_context_vars.js
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ class BamlCtxManager {
this.rt.flush();
}
onLogEvent(callback) {
if (!callback) {
this.rt.setLogEventCallback(undefined);
return;
}
this.rt.setLogEventCallback((error, param) => {
if (!error) {
callback(param);
Expand Down
Loading

0 comments on commit e8c1a61

Please sign in to comment.