diff --git a/engine/baml-lib/baml-core/src/ir/jinja_helpers.rs b/engine/baml-lib/baml-core/src/ir/jinja_helpers.rs index e5d42aea0..b5d3a636c 100644 --- a/engine/baml-lib/baml-core/src/ir/jinja_helpers.rs +++ b/engine/baml-lib/baml-core/src/ir/jinja_helpers.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use baml_types::{BamlValue, JinjaExpression}; +use minijinja::value::Value; use regex::Regex; pub fn get_env<'a>() -> minijinja::Environment<'a> { @@ -9,16 +10,42 @@ pub fn get_env<'a>() -> minijinja::Environment<'a> { env.set_trim_blocks(true); env.set_lstrip_blocks(true); env.add_filter("regex_match", regex_match); + env.add_filter("sum", sum_filter); env } fn regex_match(value: String, regex: String) -> bool { match Regex::new(®ex) { Err(_) => false, - Ok(re) => re.is_match(&value) + Ok(re) => re.is_match(&value), } } +fn sum_filter(value: Vec) -> Value { + let int_sum: Option = value + .iter() + .map(|v| ::try_from(v.clone()).ok()) + .collect::>>() + .map(|ints| ints.into_iter().sum()); + let float_sum: Option = value + .into_iter() + .map(|v| ::try_from(v).ok()) + .collect::>>() + .map(|floats| floats.into_iter().sum()); + // If we could downcast all the Values to ints, return an int. + // Otherwise, if we could downcast all the Values to floats, return the + // float. + // Otherwise, return 0. We rely on our jinja typechecker to make sure an + // erroneous 0 never makes it back to the user. + if int_sum.is_none() && float_sum.is_none() { + log::warn!("The `sum` jinja filter was run against non-numeric arguments") + } + int_sum.map_or( + float_sum.map_or(Value::from(0), |float| Value::from(float)), + |int| Value::from(int), + ) +} + /// Render a bare minijinaja expression with the given context. /// E.g. `"a|length > 2"` with context `{"a": [1, 2, 3]}` will return `"true"`. pub fn render_expression( @@ -50,16 +77,23 @@ pub fn evaluate_predicate( #[cfg(test)] mod tests { - use baml_types::BamlValue; use super::*; - + use baml_types::BamlValue; #[test] fn test_render_expressions() { - let ctx = vec![( - "a".to_string(), - BamlValue::List(vec![BamlValue::Int(1), BamlValue::Int(2), BamlValue::Int(3)].into()) - ), ("b".to_string(), BamlValue::String("(123)456-7890".to_string()))] + let ctx = vec![ + ( + "a".to_string(), + BamlValue::List( + vec![BamlValue::Int(1), BamlValue::Int(2), BamlValue::Int(3)].into(), + ), + ), + ( + "b".to_string(), + BamlValue::String("(123)456-7890".to_string()), + ), + ] .into_iter() .collect(); @@ -79,19 +113,55 @@ mod tests { #[test] fn test_render_regex_match() { - let ctx = vec![( - "a".to_string(), - BamlValue::List(vec![BamlValue::Int(1), BamlValue::Int(2), BamlValue::Int(3)].into()) - ), ("b".to_string(), BamlValue::String("(123)456-7890".to_string()))] + let ctx = vec![ + ( + "a".to_string(), + BamlValue::List( + vec![BamlValue::Int(1), BamlValue::Int(2), BamlValue::Int(3)].into(), + ), + ), + ( + "b".to_string(), + BamlValue::String("(123)456-7890".to_string()), + ), + ] .into_iter() .collect(); assert_eq!( - render_expression(&JinjaExpression(r##"b|regex_match("123")"##.to_string()), &ctx).unwrap(), + render_expression( + &JinjaExpression(r##"b|regex_match("123")"##.to_string()), + &ctx + ) + .unwrap(), "true" ); assert_eq!( - render_expression(&JinjaExpression(r##"b|regex_match("\\(?\\d{3}\\)?[-.\\s]?\\d{3}[-.\\s]?\\d{4}")"##.to_string()), &ctx).unwrap(), + render_expression( + &JinjaExpression( + r##"b|regex_match("\\(?\\d{3}\\)?[-.\\s]?\\d{3}[-.\\s]?\\d{4}")"##.to_string() + ), + &ctx + ) + .unwrap(), "true" ) } + + #[test] + fn test_sum_filter() { + let ctx = vec![].into_iter().collect(); + assert_eq!( + render_expression(&JinjaExpression( + r#"[1,2]|sum"#.to_string() + ), &ctx).unwrap(), + "3" + ); + + assert_eq!( + render_expression(&JinjaExpression( + r#"[1,2.5]|sum"#.to_string() + ), &ctx).unwrap(), + "3.5" + ); + } } diff --git a/engine/baml-lib/jinja/src/evaluate_type/expr.rs b/engine/baml-lib/jinja/src/evaluate_type/expr.rs index c52d9833f..d2c468e78 100644 --- a/engine/baml-lib/jinja/src/evaluate_type/expr.rs +++ b/engine/baml-lib/jinja/src/evaluate_type/expr.rs @@ -214,6 +214,7 @@ fn tracker_visit_expr<'a>( "slice", "sort", "split", + "sum", "title", "tojson", "json", @@ -291,6 +292,21 @@ fn tracker_visit_expr<'a>( "slice" => Type::Unknown, "sort" => Type::Unknown, "split" => Type::List(Box::new(Type::String)), + "sum" => match inner.clone() { + Type::List(elem_type) => + if elem_type.is_subtype_of(&Type::Float) { + Type::Float + } else if elem_type.is_subtype_of(&Type::Int) { + Type::Int + } else { + ensure_type("(int|float)[]"); + Type::String + } + _ => { + ensure_type("(int|float)[]"); + Type::Bool + }, + } "title" => Type::String, "tojson" | "json" => Type::String, "trim" => Type::String, diff --git a/engine/baml-lib/jinja/src/evaluate_type/test_expr.rs b/engine/baml-lib/jinja/src/evaluate_type/test_expr.rs index c1aad3199..497b5ecc8 100644 --- a/engine/baml-lib/jinja/src/evaluate_type/test_expr.rs +++ b/engine/baml-lib/jinja/src/evaluate_type/test_expr.rs @@ -274,3 +274,28 @@ fn test_output_format() { vec!["Function 'baml::OutputFormat' does not have an argument 'unknown'. Did you mean one of these: 'always_hoist_enums', 'enum_value_prefix', 'or_splitter'?"] ); } + + +#[test] +fn sum_filter() { + let types = PredefinedTypes::default(JinjaContext::Prompt); + assert_eq!( + assert_evaluates_to!(r#"[1,2,3]|sum"#, types), + Type::Int + ); + assert_eq!( + assert_evaluates_to!(r#"[1.1,2.1,3.2]|sum"#, types), + Type::Float + ); + // // This would be nice, but it doesn't work. + // // Type checker says this is a subtype of `int[]`. + // // BUG. + // assert_eq!( + // assert_evaluates_to!(r#"[1.1,2,3]|sum"#, types), + // Type::Float + // ); + assert_eq!( + assert_fails_to!(r#"["hi", 1]|sum"#, types), + vec![r#"'[hi,1]' is a list[(literal["hi"] | literal[1])], expected (int|float)[]"#] + ); +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 000000000..bbfd497e1 --- /dev/null +++ b/flake.nix @@ -0,0 +1,142 @@ +{ + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; + flake-utils.url = "github:numtide/flake-utils"; + fenix = { + url = "github:nix-community/fenix"; + inputs.nixpkgs.follows = "nixpkgs"; + }; + flake-compat = { + url = "github:edolstra/flake-compat"; + flake = false; + }; + }; + + outputs = { self, nixpkgs, flake-utils, fenix, ... }: + + let + + # buildTargets = { + # "x86_64-linux" = { + # crossSystemConfig = "x86_64-unknown-linux-musl"; + # rustTarget = "x86_64-unknown-linux-musl"; + # }; + # "aarch64-linux" = { + # crossSystemConfig = "x86_64-unknown-linux-musl"; + # rustTarget = "x86_64-unknown-linux-musl"; + # }; + # "aarch64-darwin" = {}; + # "wasm" = { + # crossSystemConfig = "wasm32-unknown-unknown"; + # rustTarget = "wasm32-unknown-unknown"; + # makeBuildPackageAttrs = pkgsCross: { + # OPENSSL_STATIC = null; + # OPENSSL_LIB_DIR = null; + # OPENSSL_INCLUDE_DIR = null; + # }; + # }; + # }; + + # mkPkgs = buildSystem: targetSystem: import nixpkgs ({ + # system = buildSystem; + # } // (if targetSystem == null then {} else { + # crossSystemcnofig = buildTargets.${targetSystem}.crossSystemConfig; + # })); + + # eachSystem = supportedSystems: callback: builtins.fold' + # (overall: system: overall // { ${system} = callback system; }) + # {} + # supportedSystems; + + in + + flake-utils.lib.eachDefaultSystem (system: + + let + pkgs = nixpkgs.legacyPackages.${system}; + clang = pkgs.llvmPackages_19.clang; + pythonEnv = pkgs.python3.withPackages (ps: []); + + toolchain = with fenix.packages.${system}; combine [ + minimal.cargo + minimal.rustc + latest.rust-std + ]; + + version = (builtins.fromTOML (builtins.readFile ./engine/Cargo.toml)).workspace.package.version; + + appleDeps = with pkgs.darwin.apple_sdk.frameworks; [ + CoreServices + SystemConfiguration + pkgs.libiconv-darwin + ]; + + rustPlatform = pkgs.makeRustPlatform { + inherit (fenix.packages.${system}.minimal) cargo rustc; + inherit (fenix.packages.${system}.latest) rust-std; + }; + + in + { + packages.default = rustPlatform.buildRustPackage { + pname = "baml-cli"; + version = version; + src = let + extraFiles = pkgs.copyPathToStore ./engine/baml-runtime/src/cli/initial_project/baml_src; + in pkgs.symlinkJoin { + name = "source"; + paths = [ ./engine extraFiles ]; + }; + LIBCLANG_PATH = pkgs.libclang.lib + "/lib/"; + BINDGEN_EXTRA_CLANG_ARGS = if pkgs.stdenv.isDarwin then + "-I${pkgs.llvmPackages_19.libclang.lib}/lib/clang/19/headers " + else + "-isystem ${pkgs.llvmPackages_19.libclang.lib}/lib/clang/19/include -isystem ${pkgs.glibc.dev}/include"; + + cargoLock = { lockFile = ./engine/Cargo.lock; outputHashes = { + "pyo3-asyncio-0.21.0" = "sha256-5ZLzWkxp3e2u0B4+/JJTwO9SYKhtmBpMBiyIsTCW5Zw="; + "serde_magnus-0.9.0" = "sha256-+iIHleftJ+Yl9QHEBVI91NOhBw9qtUZfgooHKoyY1w4="; + }; }; + + # Add build-time environment variables + RUSTFLAGS = "-C target-feature=+crt-static --cfg tracing_unstable"; + + # Modify the test phase to only run library tests + checkPhase = '' + runHook preCheck + echo "Running cargo test --lib" + cargo test --lib + runHook postCheck + ''; + + buildInputs = (with pkgs; [ + openssl + pkg-config + lld_19 + pythonEnv + ruby + maturin + nodePackages.pnpm + nodePackages.nodejs + ]) ++ (if pkgs.stdenv.isDarwin then appleDeps else []); + nativeBuildInputs = [ + pkgs.openssl + pkgs.pkg-config + pkgs.ruby + pythonEnv + pkgs.maturin + ]; + PYTHON_SYS_EXECUTABLE="${pythonEnv}/bin/python3"; + LD_LIBRARY_PATH="${pythonEnv}/lib"; + PYTHONPATH="${pythonEnv}/${pythonEnv.sitePackages}"; + CC="${clang}/bin/clang"; + + }; + devShell = pkgs.mkShell rec { + buildInputs = [toolchain]; + PATH="${clang}/bin:$PATH"; + LIBCLANG_PATH = pkgs.libclang.lib + "/lib/"; + }; + } + ); +}