diff --git a/engine/Cargo.lock b/engine/Cargo.lock index 784420819..37608db55 100644 --- a/engine/Cargo.lock +++ b/engine/Cargo.lock @@ -2373,6 +2373,22 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hello" +version = "0.1.0" +dependencies = [ + "baml-cli", + "baml-runtime", + "baml-types", + "internal-baml-codegen", + "libc", + "once_cell", + "serde", + "serde_json", + "tokio", + "tokio-util", +] + [[package]] name = "hermit-abi" version = "0.3.9" diff --git a/engine/Cargo.toml b/engine/Cargo.toml index 9e5b0e68e..e3f058327 100644 --- a/engine/Cargo.toml +++ b/engine/Cargo.toml @@ -11,6 +11,7 @@ members = [ "language_client_python", "language_client_ruby/ext/ruby_ffi", "language_client_typescript", + "language_client_go/lib/baml", "sandbox", ] default-members = [ @@ -28,6 +29,7 @@ default-members = [ "language_client_python", "language_client_ruby/ext/ruby_ffi", "language_client_typescript", + "language_client_go/lib/baml", ] [workspace.dependencies] diff --git a/engine/language_client_go/.gitignore b/engine/language_client_go/.gitignore new file mode 100644 index 000000000..48d8ae47d --- /dev/null +++ b/engine/language_client_go/.gitignore @@ -0,0 +1,8 @@ +main +*.dylib +*.so +*.dll +*.lib +*.exp +*.def +*.res \ No newline at end of file diff --git a/engine/language_client_go/lib/baml.h b/engine/language_client_go/lib/baml.h new file mode 100644 index 000000000..953ae5023 --- /dev/null +++ b/engine/language_client_go/lib/baml.h @@ -0,0 +1,83 @@ +#ifndef RUST_CFFI_H +#define RUST_CFFI_H + +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include +#include + +/* + * Struct representing keyword arguments passed from C to Rust. + * - `len` is the number of key/value pairs. + * - `keys` is an array of null-terminated strings (the keys). + * - `values` is an array of null-terminated strings (the JSON-encoded values). + */ +typedef struct CKwargs { + size_t len; + const char **keys; + const char **values; +} CKwargs; + +/* + * Callback function type. + * The callback receives a pointer to a null-terminated C string containing the JSON result. + * Note: The returned string is allocated by Rust and must be freed using free_string. + */ +typedef void (*ResultCallback)(const char *result); + +/* + * Extern "C" functions exported from the Rust CFFI layer. + */ + +// Prints a hello message. `name` must be a null-terminated string. +void hello(const char *name); + +// Prints a whispered message. `message` must be a null-terminated string. +void whisper(const char *message); + +// Creates and returns a pointer to a Baml runtime instance. +const void* create_baml_runtime(void); + +// Destroys a previously created Baml runtime instance. +void destroy_baml_runtime(const void *runtime); + +/* + * Calls a function in the Baml runtime. + * + * Parameters: + * - runtime: a pointer to the runtime (as returned by create_baml_runtime). + * - function_name: the name of the function to call (null-terminated string). + * - kwargs: pointer to a CKwargs structure containing keyword arguments. + * - callback: a function to be called with the result. + * + * The callback receives a pointer to a C string (JSON) that must later be freed with free_string. + */ +void call_function_from_c(const void *runtime, + const char *function_name, + const CKwargs *kwargs, + ResultCallback callback); + +// Invokes the runtime CLI. `args` is a null-terminated array of null-terminated C strings. +void invoke_runtime_cli(const char * const* args); + +/* + * Frees a C string that was allocated by the Rust runtime (e.g., in call_function_from_c). + * Call this function on any string returned via a callback once it is no longer needed. + */ +void free_string(char *s); + + +// In baml.h +typedef void (*callback_func)(char*); +extern bool register_callback(uint32_t id, callback_func callback); +extern bool unregister_callback(uint32_t id); +extern bool trigger_callback(uint32_t id, char* message); + +#ifdef __cplusplus +} +#endif + +#endif // RUST_CFFI_H diff --git a/engine/language_client_go/lib/baml/Cargo.toml b/engine/language_client_go/lib/baml/Cargo.toml new file mode 100644 index 000000000..fc0cb426e --- /dev/null +++ b/engine/language_client_go/lib/baml/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "hello" +version = "0.1.0" +edition = "2021" + +[lib] +# If you only wanted dynamic library, you'd use only "cdylib". +# If you only wanted static library, you'd use only "staticlib". +# This demo shows both. See https://doc.rust-lang.org/reference/linkage.html +# for more information. +crate-type = ["cdylib", "staticlib"] + +[dependencies] +libc = "0.2.2" +baml-cli.workspace = true +baml-types.workspace = true +baml-runtime = { path = "../../../baml-runtime", default-features = false, features = [ + "internal", +] } +internal-baml-codegen.workspace = true +serde.workspace = true +serde_json.workspace = true +tokio = { version = "1", features = ["full"] } +tokio-util = { version = "0.7", features = ["full"] } +once_cell.workspace = true \ No newline at end of file diff --git a/engine/language_client_go/lib/baml/src/lib.rs b/engine/language_client_go/lib/baml/src/lib.rs new file mode 100644 index 000000000..56b1d858a --- /dev/null +++ b/engine/language_client_go/lib/baml/src/lib.rs @@ -0,0 +1,191 @@ +use std::{ffi::CStr, path::Path}; + +extern crate baml_runtime; +use baml_runtime::BamlRuntime; + +#[no_mangle] +pub extern "C" fn hello(name: *const libc::c_char) { + let name_cstr = unsafe { CStr::from_ptr(name) }; + let name = name_cstr.to_str().unwrap(); + println!("Hello {}!", name); +} + +#[no_mangle] +pub extern "C" fn whisper(message: *const libc::c_char) { + let message_cstr = unsafe { CStr::from_ptr(message) }; + let message = message_cstr.to_str().unwrap(); + println!("({})", message); +} + +#[no_mangle] +pub extern "C" fn create_baml_runtime() -> *const libc::c_void { + const BAML_DIR: &str = "/Users/vbv/repos/gloo-lang/integ-tests/baml_src"; + let env_vars = std::env::vars().into_iter().map(|(k, v)| (k.to_string(), v.to_string())).collect(); + let runtime = BamlRuntime::from_directory(&Path::new(BAML_DIR), env_vars); + Box::into_raw(Box::new(runtime)) as *const libc::c_void +} + +#[no_mangle] +pub extern "C" fn destroy_baml_runtime(runtime: *const libc::c_void) { + unsafe { + let _ = Box::from_raw(runtime as *mut BamlRuntime); + } +} + +#[no_mangle] +pub extern "C" fn invoke_runtime_cli(args: *const *const libc::c_char) { + // Safety: We assume `args` is a valid pointer to a null-terminated array of C strings. + let args_vec = unsafe { + // Ensure the pointer itself is not null. + if args.is_null() { + Vec::new() + } else { + let mut vec = Vec::new(); + let mut i = 0; + // Iterate until a null pointer is encountered. + while !(*args.add(i)).is_null() { + let c_str = CStr::from_ptr(*args.add(i)); + // Convert to Rust String (lossy conversion handles non-UTF8 gracefully). + vec.push(c_str.to_string_lossy().into_owned()); + i += 1; + } + vec + } + }; + baml_cli::run_cli( + args_vec, + baml_runtime::RuntimeCliDefaults { + output_type: baml_types::GeneratorOutputType::PythonPydantic, + }, + ) + .unwrap(); +} + +use std::ffi::CString; +use std::os::raw::c_char; +use std::ptr; + +use baml_types::{BamlMap, BamlValue}; + +#[repr(C)] +pub struct CKwargs { + pub len: libc::size_t, + pub keys: *const *const c_char, + pub values: *const *const c_char, +} + +/// Convert CKwargs to a BamlMap +unsafe fn ckwargs_to_map(kwargs: *const CKwargs) -> BamlMap { + let mut map = BamlMap::new(); + if kwargs.is_null() { + return map; + } + let kwargs_ref = &*kwargs; + for i in 0..(kwargs_ref.len as isize) { + let key_ptr = *kwargs_ref.keys.offset(i); + let value_ptr = *kwargs_ref.values.offset(i); + if let (Ok(key), Ok(value)) = ( + CStr::from_ptr(key_ptr).to_str(), + serde_json::from_str::(CStr::from_ptr(value_ptr).to_str().unwrap()), + ) { + map.insert(key.to_owned(), value.to_owned()); + } + } + map +} + +/// Type for the callback function. +/// The callback receives a pointer to a C string containing the JSON result. +pub type ResultCallback = extern "C" fn(result: *const c_char); + +/// Extern "C" function that returns immediately, scheduling the async call. +/// Once the asynchronous function completes, the provided callback is invoked. +#[no_mangle] +pub extern "C" fn call_function_from_c( + runtime: *const libc::c_void, + function_name: *const c_char, + kwargs: *const CKwargs, + callback: ResultCallback, +) { + // Safety: assume that the pointers provided are valid. + let runtime = unsafe { &*(runtime as *const BamlRuntime) }; + + // Convert the function name. + let func_name = match unsafe { CStr::from_ptr(function_name) }.to_str() { + Ok(s) => s.to_owned(), + Err(_) => { + callback(ptr::null()); + return; + } + }; + + // Convert keyword arguments. + let keyword_args = unsafe { ckwargs_to_map(kwargs) }; + + let ctx = runtime.create_ctx_manager(BamlValue::String("cffi".to_string()), None); + + // Spawn an async task to await the future and call the callback when done. + // Ensure that a Tokio runtime is running in your application. + tokio::spawn(async move { + let future = runtime.call_function(func_name, &keyword_args, &ctx, None, None); + let (result, _) = future.await; + let result_str = match result { + Ok(result) => result.to_string(), + Err(_) => String::new(), + }; + let c_result = CString::new(result_str).unwrap(); + callback(c_result.into_raw()); + // Note: Responsibility for freeing the returned string lies with the caller. + }); +} + + +// This is present so it's easy to test that the code works natively in Rust via `cargo test` +#[cfg(test)] +pub mod test { + + use std::ffi::CString; + use super::*; + + // This is meant to do the same stuff as the main function in the .go files + #[test] + fn simulated_main_function () { + hello(CString::new("world").unwrap().into_raw()); + whisper(CString::new("this is code from Rust").unwrap().into_raw()); + } +} + +// In your Rust code that becomes libhello.dylib +use std::collections::HashMap; +use std::sync::Mutex; +use once_cell::sync::Lazy; + +// Using a static HashMap to store multiple callbacks with an ID +static CALLBACKS: Lazy>> = + Lazy::new(|| Mutex::new(HashMap::new())); + +#[no_mangle] +pub extern "C" fn register_callback(id: u32, callback: extern "C" fn(*const libc::c_char)) -> bool { + let mut callbacks = CALLBACKS.lock().unwrap(); + callbacks.insert(id, callback); + true +} + +#[no_mangle] +pub extern "C" fn unregister_callback(id: u32) -> bool { + let mut callbacks = CALLBACKS.lock().unwrap(); + callbacks.remove(&id).is_some() +} + +#[no_mangle] +pub extern "C" fn trigger_callback(id: u32, message: *const libc::c_char) -> bool { + let callbacks = CALLBACKS.lock().unwrap(); + if let Some(callback) = callbacks.get(&id) { + unsafe { + callback(message); + } + true + } else { + false + } +} \ No newline at end of file diff --git a/engine/language_client_go/main.go b/engine/language_client_go/main.go new file mode 100644 index 000000000..8c4558ef4 --- /dev/null +++ b/engine/language_client_go/main.go @@ -0,0 +1,170 @@ +package main + +/* +#cgo LDFLAGS: ./lib/libhello.dylib -ldl +#include "./lib/baml.h" +#include + +// Declare the callback type +typedef void (*callback_func)(char*); + +// Export callbacks +extern void trampolineCallback1(char* result); +extern void trampolineCallback2(char* result); +extern void trampolineCallback3(char* result); +*/ +import "C" + +import ( + "os" + "strconv" + "sync" + "unsafe" +) + +// Map to store callbacks by ID +var ( + dynamicCallbacks = make(map[uint32]func(*C.char)) + callbackMutex sync.RWMutex + resultChan chan struct{} // Channel to signal completion +) + +//export trampolineCallback1 +func trampolineCallback1(result *C.char) { + handleCallback(1, result) +} + +//export trampolineCallback2 +func trampolineCallback2(result *C.char) { + handleCallback(2, result) +} + +//export trampolineCallback3 +func trampolineCallback3(result *C.char) { + handleCallback(3, result) +} + +func handleCallback(id uint32, result *C.char) { + callbackMutex.RLock() + callback, exists := dynamicCallbacks[id] + callbackMutex.RUnlock() + + if exists { + callback(result) + } +} + +func registerCallback(id uint32, callback func(*C.char)) bool { + callbackMutex.Lock() + dynamicCallbacks[id] = callback + callbackMutex.Unlock() + + // Register the appropriate trampoline based on ID + var success C.bool + switch id { + case 1: + success = C.register_callback(C.uint(id), (C.callback_func)(C.trampolineCallback1)) + case 2: + success = C.register_callback(C.uint(id), (C.callback_func)(C.trampolineCallback2)) + case 3: + success = C.register_callback(C.uint(id), (C.callback_func)(C.trampolineCallback3)) + default: + return false + } + + return bool(success) +} + +func unregisterCallback(id uint32) bool { + callbackMutex.Lock() + delete(dynamicCallbacks, id) + callbackMutex.Unlock() + + return bool(C.unregister_callback(C.uint(id))) +} + +func invokeRuntimeCli() { + args := os.Args + argc := len(args) + cArgs := make([]*C.char, argc+1) + for i, s := range args { + cArgs[i] = C.CString(s) + } + cArgs[argc] = nil + C.invoke_runtime_cli((**C.char)(unsafe.Pointer(&cArgs[0]))) + for i := 0; i < argc; i++ { + C.free(unsafe.Pointer(cArgs[i])) + } +} +func main() { + // Example: call invoke_runtime_cli with os.Args. + + // --- Now call TestOllama function --- + + // 1. Create the Baml runtime. + runtime := C.create_baml_runtime() + if runtime == nil { + println("Failed to create Baml runtime") + return + } + // Ensure the runtime is destroyed at the end. + defer C.destroy_baml_runtime(runtime) + + // 2. Prepare the function name "TestOllama". + funcName := C.CString("TestOllama") + defer C.free(unsafe.Pointer(funcName)) + + // 3. Prepare the argument for TestOllama. + // Assume TestOllama expects one parameter "input" of type string. + input := "Hello from Go" + // JSON-encode the input string (for proper deserialization in Rust). + jsonInput := strconv.Quote(input) // e.g., becomes "\"Hello from Go\"" + cValue := C.CString(jsonInput) + defer C.free(unsafe.Pointer(cValue)) + + // The key for our argument. + key := C.CString("input") + defer C.free(unsafe.Pointer(key)) + + // 4. Build arrays for keys and values. + keys := []*C.char{key} + values := []*C.char{cValue} + + // 5. Create the CKwargs struct. + var kwargs C.CKwargs + kwargs.len = 1 + kwargs.keys = (**C.char)(unsafe.Pointer(&keys[0])) + kwargs.values = (**C.char)(unsafe.Pointer(&values[0])) + + // 6. Prepare to wait for the callback. + resultChan = make(chan struct{}) + + // Use the callback registration system + callbackID := uint32(1) // Using ID 1 for this callback + registerCallback(callbackID, func(result *C.char) { + res := C.GoString(result) + println("Result from TestOllama:", res) + // Signal completion + resultChan <- struct{}{} + }) + + // Create a CString for the result callback function name + callback := C.CString("trampolineCallback1") + defer C.free(unsafe.Pointer(callback)) + + // 7. Call the function via the Rust CFFI layer using the registered callback + C.call_function_from_c( + runtime, + funcName, + &kwargs, + (C.ResultCallback)(C.callback_func(C.trampolineCallback1)), + ) + + // Wait until the callback signals completion. + <-resultChan + + // Clean up by unregistering the callback + unregisterCallback(callbackID) + + println("TestOllama function completed") +} \ No newline at end of file diff --git a/engine/language_client_go/temp-main.go b/engine/language_client_go/temp-main.go new file mode 100644 index 000000000..13847c8cd --- /dev/null +++ b/engine/language_client_go/temp-main.go @@ -0,0 +1,107 @@ +package main + +/* +#cgo LDFLAGS: ./lib/libhello.dylib -ldl +#include "./lib/baml.h" +#include + +// Declare the callback type +typedef void (*callback_func)(char*); + +// Export callbacks +extern void trampolineCallback1(char* result); +extern void trampolineCallback2(char* result); +extern void trampolineCallback3(char* result); +*/ +import "C" +import ( + "sync" + "unsafe" +) + +// Map to store callbacks by ID +var ( + dynamicCallbacks = make(map[uint32]func(*C.char)) + callbackMutex sync.RWMutex +) + +//export trampolineCallback1 +func trampolineCallback1(result *C.char) { + handleCallback(1, result) +} + +//export trampolineCallback2 +func trampolineCallback2(result *C.char) { + handleCallback(2, result) +} + +//export trampolineCallback3 +func trampolineCallback3(result *C.char) { + handleCallback(3, result) +} + +func handleCallback(id uint32, result *C.char) { + callbackMutex.RLock() + callback, exists := dynamicCallbacks[id] + callbackMutex.RUnlock() + + if exists { + callback(result) + } +} + +func registerCallback(id uint32, callback func(*C.char)) bool { + callbackMutex.Lock() + dynamicCallbacks[id] = callback + callbackMutex.Unlock() + + // Register the appropriate trampoline based on ID + var success C.bool + switch id { + case 1: + success = C.register_callback(C.uint(id), (C.callback_func)(C.trampolineCallback1)) + case 2: + success = C.register_callback(C.uint(id), (C.callback_func)(C.trampolineCallback2)) + case 3: + success = C.register_callback(C.uint(id), (C.callback_func)(C.trampolineCallback3)) + default: + return false + } + + return bool(success) +} + +func unregisterCallback(id uint32) bool { + callbackMutex.Lock() + delete(dynamicCallbacks, id) + callbackMutex.Unlock() + + return bool(C.unregister_callback(C.uint(id))) +} + +func main() { + // Register multiple callbacks + registerCallback(1, func(result *C.char) { + res := C.GoString(result) + println("Callback 1 received:", res) + }) + + registerCallback(2, func(result *C.char) { + res := C.GoString(result) + println("Callback 2 received:", res) + }) + + // Test triggering different callbacks + message1 := C.CString("Test message for callback 1") + defer C.free(unsafe.Pointer(message1)) + C.trigger_callback(C.uint(1), message1) + + message2 := C.CString("Test message for callback 2") + defer C.free(unsafe.Pointer(message2)) + C.trigger_callback(C.uint(2), message2) + + // Unregister a callback when done + unregisterCallback(1) + + println("All callbacks processed") +} \ No newline at end of file