Skip to content

Commit

Permalink
Fixes issues with BamlSpan for async vs sync thread.
Browse files Browse the repository at this point in the history
Now .finish_span is always sync on everything but WASM as it only pushes items into a queue.
On WASM builds, we still wait for the logs to be sent to complete the function
  • Loading branch information
hellovai committed May 31, 2024
1 parent 5cedf26 commit 0af64b2
Show file tree
Hide file tree
Showing 16 changed files with 219 additions and 190 deletions.
59 changes: 54 additions & 5 deletions engine/baml-runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ pub mod type_builder;
mod types;

use std::collections::HashMap;
use std::env;
use std::path::PathBuf;
use std::sync::Arc;

Expand Down Expand Up @@ -158,6 +157,12 @@ impl BamlRuntime {

let mut target_id = None;
if let Some(span) = span {
#[cfg(not(target_arch = "wasm32"))]
match self.tracer.finish_span(span, ctx, None) {
Ok(id) => target_id = id,
Err(e) => log::debug!("Error during logging: {}", e),
}
#[cfg(target_arch = "wasm32")]
match self.tracer.finish_span(span, ctx, None).await {
Ok(id) => target_id = id,
Err(e) => log::debug!("Error during logging: {}", e),
Expand All @@ -184,6 +189,12 @@ impl BamlRuntime {

let mut target_id = None;
if let Some(span) = span {
#[cfg(not(target_arch = "wasm32"))]
match self.tracer.finish_baml_span(span, ctx, &response) {
Ok(id) => target_id = id,
Err(e) => log::debug!("Error during logging: {}", e),
}
#[cfg(target_arch = "wasm32")]
match self.tracer.finish_baml_span(span, ctx, &response).await {
Ok(id) => target_id = id,
Err(e) => log::debug!("Error during logging: {}", e),
Expand Down Expand Up @@ -259,22 +270,60 @@ impl ExperimentalTracingInterface for BamlRuntime {
self.tracer.start_span(function_name, ctx, None, params)
}

#[cfg(not(target_arch = "wasm32"))]
fn finish_function_span(
&self,
span: Option<TracingSpan>,
result: &Result<FunctionResult>,
ctx: &RuntimeContextManager,
) -> Result<Option<uuid::Uuid>> {
if let Some(span) = span {
self.tracer.finish_baml_span(span, ctx, result)
} else {
Ok(None)
}
}

#[cfg(target_arch = "wasm32")]
async fn finish_function_span(
&self,
span: TracingSpan,
span: Option<TracingSpan>,
result: &Result<FunctionResult>,
ctx: &RuntimeContextManager,
) -> Result<Option<uuid::Uuid>> {
self.tracer.finish_baml_span(span, ctx, result).await
if let Some(span) = span {
self.tracer.finish_baml_span(span, ctx, result).await
} else {
Ok(None)
}
}

#[cfg(not(target_arch = "wasm32"))]
fn finish_span(
&self,
span: Option<TracingSpan>,
result: Option<BamlValue>,
ctx: &RuntimeContextManager,
) -> Result<Option<uuid::Uuid>> {
if let Some(span) = span {
self.tracer.finish_span(span, ctx, result)
} else {
Ok(None)
}
}

#[cfg(target_arch = "wasm32")]
async fn finish_span(
&self,
span: TracingSpan,
span: Option<TracingSpan>,
result: Option<BamlValue>,
ctx: &RuntimeContextManager,
) -> Result<Option<uuid::Uuid>> {
self.tracer.finish_span(span, ctx, result).await
if let Some(span) = span {
self.tracer.finish_span(span, ctx, result).await
} else {
Ok(None)
}
}

fn flush(&self) -> Result<()> {
Expand Down
22 changes: 20 additions & 2 deletions engine/baml-runtime/src/runtime_interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,36 @@ pub trait ExperimentalTracingInterface {
ctx: &RuntimeContextManager,
) -> (Option<TracingSpan>, RuntimeContext);

#[cfg(target_arch = "wasm32")]
#[allow(async_fn_in_trait)]
async fn finish_function_span(
&self,
span: TracingSpan,
span: Option<TracingSpan>,
result: &Result<FunctionResult>,
ctx: &RuntimeContextManager,
) -> Result<Option<uuid::Uuid>>;

#[cfg(not(target_arch = "wasm32"))]
fn finish_function_span(
&self,
span: Option<TracingSpan>,
result: &Result<FunctionResult>,
ctx: &RuntimeContextManager,
) -> Result<Option<uuid::Uuid>>;

#[cfg(target_arch = "wasm32")]
#[allow(async_fn_in_trait)]
async fn finish_span(
&self,
span: TracingSpan,
span: Option<TracingSpan>,
result: Option<BamlValue>,
ctx: &RuntimeContextManager,
) -> Result<Option<uuid::Uuid>>;

#[cfg(not(target_arch = "wasm32"))]
fn finish_span(
&self,
span: Option<TracingSpan>,
result: Option<BamlValue>,
ctx: &RuntimeContextManager,
) -> Result<Option<uuid::Uuid>>;
Expand Down
64 changes: 64 additions & 0 deletions engine/baml-runtime/src/tracing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ impl BamlTracer {
(Some(span), ctx.create_ctx(tb))
}

#[cfg(target_arch = "wasm32")]
pub(crate) async fn finish_span(
&self,
span: TracingSpan,
Expand Down Expand Up @@ -128,6 +129,34 @@ impl BamlTracer {
}
}

#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn finish_span(
&self,
span: TracingSpan,
ctx: &RuntimeContextManager,
response: Option<BamlValue>,
) -> Result<Option<uuid::Uuid>> {
let Some((span_id, event_chain, tags)) = ctx.exit() else {
anyhow::bail!(
"Attempting to finish a span {:#?} without first starting one. Current context {:#?}",
span,
ctx
);
};

if span.span_id != span_id {
anyhow::bail!("Span ID mismatch: {} != {}", span.span_id, span_id);
}

if let Some(tracer) = &self.tracer {
tracer.submit(response.to_log_schema(&self.options, event_chain, tags, span))?;
Ok(Some(span_id))
} else {
Ok(None)
}
}

#[cfg(target_arch = "wasm32")]
pub(crate) async fn finish_baml_span(
&self,
span: TracingSpan,
Expand Down Expand Up @@ -163,6 +192,41 @@ impl BamlTracer {
Ok(None)
}
}

#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn finish_baml_span(
&self,
span: TracingSpan,
ctx: &RuntimeContextManager,
response: &Result<FunctionResult>,
) -> Result<Option<uuid::Uuid>> {
let Some((span_id, event_chain, tags)) = ctx.exit() else {
anyhow::bail!("Attempting to finish a span without first starting one");
};

if span.span_id != span_id {
anyhow::bail!("Span ID mismatch: {} != {}", span.span_id, span_id);
}

if let Ok(response) = &response {
let name = event_chain.last().map(|s| s.name.as_str());
let is_ok = response.parsed().as_ref().is_some_and(|r| r.is_ok());
log::log!(
target: "baml_events",
if is_ok { log::Level::Info } else { log::Level::Warn },
"{}{}",
name.map(|s| format!("Function {}:\n", s)).unwrap_or_default().purple(),
response
);
}

if let Some(tracer) = &self.tracer {
tracer.submit(response.to_log_schema(&self.options, event_chain, tags, span))?;
Ok(Some(span_id))
} else {
Ok(None)
}
}
}

// Function to convert web_time::SystemTime to ISO 8601 string
Expand Down
2 changes: 1 addition & 1 deletion engine/baml-runtime/src/tracing/threaded_tracer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ impl ThreadedTracer {
}
}

pub async fn submit(&self, event: LogSchema) -> Result<()> {
pub fn submit(&self, event: LogSchema) -> Result<()> {
log::info!("Submitting work {}", event.event_id);
let tx = self
.tx
Expand Down
6 changes: 6 additions & 0 deletions engine/baml-runtime/src/types/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ impl FunctionResultStream {

let mut target_id = None;
if let Some(span) = span {
#[cfg(not(target_arch = "wasm32"))]
match self.tracer.finish_baml_span(span, ctx, &res) {
Ok(id) => target_id = id,
Err(e) => log::debug!("Error during logging: {}", e),
}
#[cfg(target_arch = "wasm32")]
match self.tracer.finish_baml_span(span, ctx, &res).await {
Ok(id) => target_id = id,
Err(e) => log::debug!("Error during logging: {}", e),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,8 @@ def start_trace_async(
self.ctx.set(cln)
return BamlSpan.new(self.rt, name, args, cln)

async def end_trace(self, span: BamlSpan, response: typing.Any) -> None:
await span.finish(response, self.ctx.get())

def end_trace_sync(self, span: BamlSpan, response: typing.Any) -> None:
span.finish_sync(response, self.ctx.get())
def end_trace(self, span: BamlSpan, response: typing.Any) -> None:
span.finish(response, self.ctx.get())

def flush(self) -> None:
self.rt.flush()
Expand All @@ -69,10 +66,10 @@ async def async_wrapper(
span = self.start_trace_async(func_name, params)
try:
response = await func(*args, **kwargs)
await self.end_trace(span, response)
self.end_trace(span, response)
return response
except Exception as e:
await self.end_trace(span, e)
self.end_trace(span, e)
raise e

return typing.cast(F, async_wrapper)
Expand All @@ -89,10 +86,10 @@ def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
span = self.start_trace_sync(func_name, params)
try:
response = func(*args, **kwargs)
self.end_trace_sync(span, response)
self.end_trace(span, response)
return response
except Exception as e:
self.end_trace_sync(span, e)
self.end_trace(span, e)
raise e

return typing.cast(F, wrapper)
19 changes: 8 additions & 11 deletions engine/language_client_python/python_src/baml_py/baml_py.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional, Tuple

class FunctionResult:
"""The result of a BAML function call.
Expand Down Expand Up @@ -32,16 +32,14 @@ class FunctionResultStream:
async def done(self, ctx: RuntimeContextManager) -> FunctionResult: ...

class BamlImagePy:
@staticmethod
def from_url(url: str) -> BamlImagePy: ...
@staticmethod
def from_base64(base64: str, media_type: str) -> BamlImagePy: ...
@property
def url(self) -> Optional[str]: ...
@url.setter
def url(self, value: Optional[str]) -> None: ...
@property
def base64(self) -> Optional[str]: ...
@base64.setter
def base64(self, value: Optional[str]) -> None: ...
def is_url(self) -> bool: ...
def is_base64(self) -> bool: ...
def as_url(self) -> str: ...
def as_base64(self) -> Tuple[str, str]: ...

class RuntimeContextManager:
def upsert_tags(self, tags: Dict[str, Any]) -> None: ...
Expand Down Expand Up @@ -80,8 +78,7 @@ class BamlSpan:
args: Dict[str, Any],
ctx: RuntimeContextManager,
) -> BamlSpan: ...
async def finish(self, result: Any, ctx: RuntimeContextManager) -> str | None: ...
def finish_sync(self, result: Any, ctx: RuntimeContextManager) -> str | None: ...
def finish(self, result: Any, ctx: RuntimeContextManager) -> str | None: ...

class TypeBuilder:
def __init__(self) -> None: ...
Expand Down
37 changes: 14 additions & 23 deletions engine/language_client_python/src/types/image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,33 +20,24 @@ impl BamlImagePy {
}
}

#[getter]
pub fn get_url(&self) -> PyResult<Option<String>> {
Ok(match &self.inner {
baml_types::BamlImage::Url(url) => Some(url.url.clone()),
_ => None,
})
}

#[getter]
pub fn get_base64(&self) -> PyResult<Option<(String, String)>> {
Ok(match &self.inner {
baml_types::BamlImage::Base64(base64) => {
Some((base64.base64.clone(), base64.media_type.clone()))
}
_ => None,
})
pub fn is_url(&self) -> bool {
matches!(&self.inner, baml_types::BamlImage::Url(_))
}

#[setter]
pub fn set_url(&mut self, url: String) {
self.inner = baml_types::BamlImage::Url(baml_types::ImageUrl::new(url));
pub fn as_url(&self) -> PyResult<String> {
match &self.inner {
baml_types::BamlImage::Url(url) => Ok(url.url.clone()),
_ => Err(crate::BamlError::new_err("Image is not a URL")),
}
}

#[setter]
pub fn set_base64(&mut self, base64: (String, String)) {
self.inner =
baml_types::BamlImage::Base64(baml_types::ImageBase64::new(base64.0, base64.1));
pub fn as_base64(&self) -> PyResult<Vec<String>> {
match &self.inner {
baml_types::BamlImage::Base64(base64) => {
Ok(vec![base64.base64.clone(), base64.media_type.clone()])
}
_ => Err(crate::BamlError::new_err("Image is not base64")),
}
}

pub fn __repr__(&self) -> String {
Expand Down
Loading

0 comments on commit 0af64b2

Please sign in to comment.