diff --git a/engine/baml-runtime/src/type_builder/json_schema.rs b/engine/baml-runtime/src/type_builder/json_schema.rs index efbb4004d..56f557658 100644 --- a/engine/baml-runtime/src/type_builder/json_schema.rs +++ b/engine/baml-runtime/src/type_builder/json_schema.rs @@ -8,6 +8,7 @@ use internal_baml_jinja::types::{OutputFormatContent, RenderOptions}; use serde::Deserialize; use std::collections::HashMap; use std::collections::HashSet; +use std::f32::consts::E; use crate::internal::prompt_renderer; use crate::RuntimeContext; @@ -26,17 +27,16 @@ pub enum OutputFormatMode { // TODO: // - maps, unions, tuples +// - errors.is_empty() is a bad pattern, should use whether or not new errors were added as a signal +// - root def should use schema.title as the type name +// - handle inline types? need to figure out a schema for the refs #[derive(Debug, Deserialize)] pub struct JsonSchema { - #[serde(rename = "$defs")] + #[serde(default, rename = "$defs")] defs: HashMap, - /// Pydantic includes this by default. - #[serde(rename = "title")] - _title: Option, - #[serde(flatten)] - type_spec: TypeSpec, + type_spec_with_meta: TypeSpecWithMeta, } #[derive(Debug, Deserialize)] @@ -134,10 +134,16 @@ struct RefinedTypeResolver { } impl RefinedTypeResolver { + fn record_type(&mut self, position: &Vec, refined_type: RefinedType) { + self.refined.insert(position.join("/"), refined_type); + } + fn resolve_ref(&self, name: &str) -> Result { + // TODO: this does not handle inline-defined types + let type_name = name.strip_prefix("#/$defs/").unwrap_or(name); match self.refined.get(name) { - Some(RefinedType::Class) => Ok(FieldType::Class(name.to_string())), - Some(RefinedType::Enum) => Ok(FieldType::Enum(name.to_string())), + Some(RefinedType::Class) => Ok(FieldType::Class(type_name.to_string())), + Some(RefinedType::Enum) => Ok(FieldType::Enum(type_name.to_string())), None => anyhow::bail!("Unresolved ref: {}", name), } } @@ -338,14 +344,9 @@ impl Visit1 for JsonSchema { let _ = type_def.visit1(position, resolver, errors); } - let _ = self.type_spec.visit1(position.clone(), resolver, errors); - // for (name, prop) in self.properties.iter() { - // let mut position = position.clone(); - // position.push("properties".to_string()); - // position.push(name.clone()); - - // let _ = prop.type_spec.visit1(position, resolver, errors); - // } + let _ = self + .type_spec_with_meta + .visit1(position.clone(), resolver, errors); if !errors.is_empty() { return Err(()); @@ -356,27 +357,26 @@ impl Visit1 for JsonSchema { } fn position_to_type_name(position: &Vec) -> Result { - if position.len() == 2 { - if position[0] == "$defs" { - return Ok(position[1].clone()); - } + if position.len() == 3 && position[0] == "#" && position[1] == "$defs" { + return Ok(position[2].clone()); + } + + if position.len() == 1 && position[0] == "#" { + return Ok("#".to_string()); } - Ok(position.join("___")) + anyhow::bail!("Only top-level defs are supported: {:?}", position) } -impl Visit1 for TypeSpec { +impl Visit1 for TypeSpecWithMeta { fn visit1( &self, position: Vec, resolver: &mut RefinedTypeResolver, errors: &mut Vec, ) -> core::result::Result<(), ()> { - match self { + match &self.type_spec { TypeSpec::Inline(type_def) => { - let mut position = position.clone(); - position.push("???inline???".to_string()); - let _ = type_def.visit1(position, resolver, errors); } TypeSpec::Ref(_) => {} @@ -385,7 +385,7 @@ impl Visit1 for TypeSpec { let mut position = position.clone(); position.push(format!("anyOf[{}]", i)); - let _ = t.type_spec.visit1(position, resolver, errors); + let _ = t.visit1(position, resolver, errors); } } } @@ -405,49 +405,32 @@ impl Visit1 for TypeDef { ) -> core::result::Result<(), ()> { match self { TypeDef::StringOrEnum(StringOrEnumDef { r#enum: Some(_) }) => { - match position_to_type_name(&position) { - Ok(name) => { - resolver.refined.insert(name, RefinedType::Enum); - } - Err(e) => { - errors.push(SerializationError { - position: position.clone(), - message: format!("{:?}", e), - }); - } - } + resolver.record_type(&position, RefinedType::Enum); + Ok(()) } TypeDef::Class(class_def) => { - match position_to_type_name(&position) { - Ok(name) => { - resolver.refined.insert(name, RefinedType::Class); - } - Err(e) => { - errors.push(SerializationError { - position: position.clone(), - message: format!("{:?}", e), - }); - } - } + resolver.record_type(&position, RefinedType::Class); + + let mut ret = Ok(()); for (field_name, field_type) in class_def.properties.iter() { let mut position = position.clone(); - position.push(format!("properties_{}", field_name)); + position.push(format!("properties:{}", field_name)); - let _ = field_type.type_spec.visit1(position, resolver, errors); + if let Err(field_err) = field_type.visit1(position, resolver, errors) { + ret = Err(field_err); + } } + + ret } TypeDef::Array(array_def) => { let mut position = position.clone(); position.push("items".to_string()); - let _ = array_def.items.type_spec.visit1(position, resolver, errors); + array_def.items.visit1(position, resolver, errors) } - _ => {} + _ => Ok(()), } - if !errors.is_empty() { - return Err(()); - } - Ok(()) } } @@ -459,56 +442,6 @@ struct TypeCollector { } impl TypeCollector { - // fn to_output_format(&self) -> Result { - // let (class_overrides, enum_overrides) = self.tb.to_overrides(); - // let ctx = RuntimeContext { - // env: HashMap::new(), - // tags: HashMap::new(), - // class_override: class_overrides, - // enum_overrides: enum_overrides, - // }; - - // let ir = IntermediateRepr::create_empty(); - - // let output_format = FieldType::null(); - - // let output_format = prompt_renderer::render_output_format(&ir, &ctx, &output_format) - // .context("Failed to render output format")?; - - // match output_format.render(RenderOptions::default()) { - // Ok(Some(s)) => Ok(s), - // Ok(None) => anyhow::bail!("Failed to render output format"), - // Err(e) => anyhow::bail!("Failed to render output format: {:?}", e), - // } - // } - // fn schema_to_field_type(&self, type_spec: &TypeSpec) -> Result { - // Ok(match type_spec { - // TypeSpec::Inline(ref type_def) => match type_def { - // TypeDef::StringOrEnum(StringOrEnumDef { r#enum: None }) => { - // FieldType::Primitive(TypeValue::String) - // } - // TypeDef::StringOrEnum(StringOrEnumDef { r#enum: Some(_) }) => { - // anyhow::bail!("inline TypeDef for enum not allowed") - // } - // TypeDef::Int => FieldType::Primitive(TypeValue::Int), - // TypeDef::Float => FieldType::Primitive(TypeValue::Float), - // TypeDef::Bool => FieldType::Primitive(TypeValue::Bool), - // TypeDef::Null => FieldType::Primitive(TypeValue::Null), - // TypeDef::Array(array_def) => FieldType::List(Box::new( - // self.schema_to_field_type(&array_def.items.type_spec)?, - // )), - // TypeDef::Class(class_def) => anyhow::bail!("inline TypeDef for class not allowed"), - // }, - // TypeSpec::Ref(TypeRef { ref r#ref }) => self.resolver.resolve_ref(r#ref)?, - // TypeSpec::Union(UnionRef { ref any_of }) => FieldType::Union( - // any_of - // .iter() - // .map(|t| self.schema_to_field_type(&t.type_spec)) - // .collect::>()?, - // ), - // }) - // } - fn add_class( &self, position: &Vec, @@ -561,52 +494,19 @@ impl Visit2 for JsonSchema { let _ = type_def.visit2(position, v, errors); } - let cb_arc = v.tb.class("OutputFormat"); - let cb = cb_arc.lock().unwrap(); - - // for (name, prop) in self.properties.iter() { - // let mut position = position.clone(); - // position.push("properties".to_string()); - // position.push(name.clone()); - - // let _ = prop.type_spec.visit2(position.clone(), v, errors); - - // let cb_prop = cb.property(&name); - // match v.schema_to_field_type(&prop.type_spec) { - // Ok(t) => { - // cb_prop.lock().unwrap().r#type(t); - // } - // Err(e) => { - // errors.push(SerializationError { - // position: position, - // message: format!("{:?}", e), - // }); - // } - // } - // } - - if !errors.is_empty() { - return Err(()); - } - - Ok(FieldType::null()) + self.type_spec_with_meta.visit2(position.clone(), v, errors) } } -impl Visit2 for TypeSpec { +impl Visit2 for TypeSpecWithMeta { fn visit2( &self, position: Vec, v: &mut TypeCollector, errors: &mut Vec, ) -> core::result::Result { - match self { - TypeSpec::Inline(type_def) => { - let mut position = position.clone(); - position.push("???inline???".to_string()); - - type_def.visit2(position, v, errors) - } + match &self.type_spec { + TypeSpec::Inline(type_def) => type_def.visit2(position, v, errors), TypeSpec::Ref(TypeRef { ref r#ref }) => match v.resolver.resolve_ref(r#ref) { Ok(t) => Ok(t), Err(e) => { @@ -624,7 +524,7 @@ impl Visit2 for TypeSpec { let mut position = position.clone(); position.push(format!("anyOf[{}]", i)); - if let Ok(one_of) = t.type_spec.visit2(position, v, errors) { + if let Ok(one_of) = t.visit2(position, v, errors) { any_of.push(one_of); } } @@ -654,8 +554,15 @@ impl Visit2 for TypeDef { let mut position = position.clone(); position.push(format!("properties:{}", field_name)); - match field_type.type_spec.visit2(position, v, errors) { - Ok(t) => Ok((field_name.clone(), t)), + match field_type.visit2(position, v, errors) { + Ok(t) => Ok(( + field_name.clone(), + if class_def.required.contains(&field_name) { + t + } else { + FieldType::Optional(Box::new(t)) + }, + )), Err(()) => Err(()), } }) @@ -675,28 +582,26 @@ impl Visit2 for TypeDef { TypeDef::Array(array_def) => { let mut position = position.clone(); position.push("items".to_string()); - match array_def.items.type_spec.visit2(position, v, errors) { - Ok(t) => FieldType::List(Box::new(t)), - Err(()) => { - return Err(()); - } - } - } - TypeDef::StringOrEnum(StringOrEnumDef { r#enum: None }) => { - FieldType::Primitive(TypeValue::String) + array_def + .items + .visit2(position, v, errors) + .map(|t| FieldType::List(Box::new(t)))? } TypeDef::StringOrEnum(StringOrEnumDef { r#enum: Some(enum_values), }) => match v.add_enum(&position, enum_values.as_slice()) { - Ok(e) => e, + Ok(t) => t, Err(e) => { errors.push(SerializationError { position: position.clone(), - message: format!("Failed to add enum: {:?}", e), + message: format!("Failed to add class: {:?}", e), }); return Err(()); } }, + TypeDef::StringOrEnum(StringOrEnumDef { r#enum: None }) => { + FieldType::Primitive(TypeValue::String) + } TypeDef::Int => FieldType::Primitive(TypeValue::Int), TypeDef::Float => FieldType::Primitive(TypeValue::Float), TypeDef::Bool => FieldType::Primitive(TypeValue::Bool), @@ -705,205 +610,6 @@ impl Visit2 for TypeDef { } } -//---------------------------------------------------------------------- -// trait AddClassOrEnum { -// fn add_class(&self, name: &str, class_def: &ClassDef) -> Result<()>; -// fn add_enum(&self, name: &str, enum_values: &Vec) -> Result<()>; - -// /// Add refs to classes and enums -// fn visit2(&self) -> Result<()>; - -// fn to_field_type(&self, type_spec: &TypeSpecWithMeta) -> Result; -// fn resolve_ref(&self, name: &str) -> Result; -// } - -// impl AddClassOrEnum for TypeBuilder { -// fn add_class(&self, class_name: &str, class_def: &ClassDef) -> Result<()> { -// let class_builder = self.class(&class_name); -// let class_builder = class_builder.lock().unwrap(); -// for (property_name, property_type) in class_def.properties.iter() { -// class_builder -// .property(&property_name) -// .lock() -// .unwrap() -// .r#type(property_type.try_into()?); -// } - -// Ok(()) -// } - -// fn add_enum(&self, enum_name: &str, enum_values: &Vec) -> Result<()> { -// let enum_builder = self.r#enum(&enum_name); -// let enum_builder = enum_builder.lock().unwrap(); -// for v in enum_values.iter() { -// enum_builder.value(&v); -// } -// Ok(()) -// } - -// fn visit2(&self) -> Result<()> { -// todo!() -// } - -// fn to_field_type(&self, type_spec: &TypeSpecWithMeta) -> Result { -// Ok(match &type_spec.type_spec { -// TypeSpec::Inline(type_def) => match type_def { -// TypeDef::StringOrEnum(StringOrEnumDef { r#enum: None }) => { -// FieldType::Primitive(TypeValue::String) -// } -// TypeDef::StringOrEnum(StringOrEnumDef { r#enum: Some(_) }) => { -// anyhow::bail!("inline TypeDef for enum not allowed") -// } -// TypeDef::Int => FieldType::Primitive(TypeValue::Int), -// TypeDef::Float => FieldType::Primitive(TypeValue::Float), -// TypeDef::Bool => FieldType::Primitive(TypeValue::Bool), -// TypeDef::Null => FieldType::Primitive(TypeValue::Null), -// TypeDef::Array(array_def) => { -// FieldType::List(Box::new(self.to_field_type(&array_def.items)?)) -// } -// TypeDef::Class(class_def) => anyhow::bail!("inline TypeDef for class not allowed"), -// }, -// TypeSpec::Ref(TypeRef { r#ref }) => match r#ref.strip_prefix("#/$defs/") { -// Some(ref_name) => self.resolve_ref(ref_name)?, -// None => anyhow::bail!("Invalid ref: {}", r#ref), -// }, -// TypeSpec::Union(UnionRef { any_of }) => FieldType::Union( -// any_of -// .iter() -// .map(|t| self.to_field_type(t)) -// .collect::>()?, -// ), -// }) -// } -// fn resolve_ref(&self, name: &str) -> Result { -// let classes = self.classes.clone(); -// let classes = classes.lock().unwrap(); -// let enums = self.enums.clone(); -// let enums = enums.lock().unwrap(); - -// if classes.contains_key(name) { -// return Ok(FieldType::Class(name.to_string())); -// } -// if enums.contains_key(name) { -// return Ok(FieldType::Enum(name.to_string())); -// } - -// anyhow::bail!("Unknown ref: {}", name) -// } -// } - -// impl TryInto for &JsonSchema { -// type Error = anyhow::Error; - -// fn try_into(self) -> Result { -// log::debug!("Converting JsonSchema to TypeBuilder: {:#?}", self); - -// let t = TypeBuilder::new(); - -// for (type_name, type_def) in self.defs.iter() { -// match type_def { -// TypeDef::StringOrEnum(string_or_enum_def) => { -// if let Some(ref enum_values) = string_or_enum_def.r#enum { -// t.add_enum(type_name, enum_values)?; -// } -// } -// TypeDef::Class(class_def) => t.add_class(type_name, class_def)?, -// _ => {} -// } -// } - -// let output_type = t.class("OutputFormat"); -// let output_type = output_type.lock().unwrap(); -// for (property_name, property_type) in self.properties.iter() { -// output_type -// .property(&property_name) -// .lock() -// .unwrap() -// .r#type(property_type.try_into()?); -// } - -// Ok(t) -// } -// } - -// impl TryInto for &TypeSpecWithMeta { -// type Error = anyhow::Error; -// fn try_into(self) -> Result { -// Ok(match &self.type_spec { -// TypeSpec::Inline(type_def) => match type_def { -// TypeDef::StringOrEnum(StringOrEnumDef { r#enum: None }) => { -// FieldType::Primitive(TypeValue::String) -// } -// TypeDef::StringOrEnum(StringOrEnumDef { r#enum: Some(_) }) => { -// anyhow::bail!("inline TypeDef for enum not allowed") -// } -// TypeDef::Int => FieldType::Primitive(TypeValue::Int), -// TypeDef::Float => FieldType::Primitive(TypeValue::Float), -// TypeDef::Bool => FieldType::Primitive(TypeValue::Bool), -// TypeDef::Null => FieldType::Primitive(TypeValue::Null), -// TypeDef::Array(array_def) => { -// FieldType::List(Box::new((&array_def.items).try_into()?)) -// } -// TypeDef::Class(class_def) => anyhow::bail!("inline TypeDef for class not allowed"), -// }, -// TypeSpec::Ref(TypeRef { r#ref }) => match r#ref.strip_prefix("#/$defs/") { -// //Some(ref_name) => self.resolve_ref(ref_name)?, -// Some(ref_name) => todo!(), -// None => anyhow::bail!("Invalid ref: {}", r#ref), -// }, -// TypeSpec::Union(UnionRef { any_of }) => { -// FieldType::Union(any_of.iter().map(|t| t.try_into()).collect::>()?) -// } -// }) -// } -// } - -// impl Into for &JsonSchema { -// fn into(self) -> OutputFormatContent { -// let mut enums = vec![]; -// let mut classes = vec![]; - -// for (name, type_def) in self.defs.iter() { -// match type_def { -// TypeDef::StringOrEnum(string_or_enum_def) => { -// if let Some(enum_values) = &string_or_enum_def.r#enum { -// enums.push(jt::Enum { -// name: jt::Name::new(name.clone()), -// values: enum_values -// .iter() -// .map(|v| (jt::Name::new(v.clone()), None)) -// .collect(), -// }); -// } -// } -// TypeDef::Class(class_def) => { -// classes.push(jt::Class { -// name: jt::Name::new(name.clone()), -// fields: class_def -// .properties -// .iter() -// .map(|(field_name, field_type)| { -// (jt::Name::new(field_name.clone()), field_type.into(), None) -// }) -// .collect(), -// }); -// } -// _ => {} -// } -// } -// todo!() -// } -// } - -// pub fn create_output_format( -// from_schema: OutputFormatContent, -// mode: OutputFormatMode, -// ) -> Result { -// let rendered = from_schema -// .render(RenderOptions::default()) -// .context("Failed to render output format")?; -// Ok("".to_string()) -// } pub struct JsonSchemaType { inner: FieldType, } @@ -977,30 +683,15 @@ impl AddJsonSchema for TypeBuilder { .iter() .map(|(k, v)| (k.clone(), v.clone())), ); - // let json_schema: json_schema::JsonSchema = serde_json::from_str(&schema)?; - // json_schema.classes_and_enums()?; + println!("{:#?}", self); - // let other: TypeBuilder = (&json_schema).try_into()?; - - // self.classes - // .lock() - // .unwrap() - // .extend(other.classes.lock().unwrap().clone()); - // self.enums - // .lock() - // .unwrap() - // .extend(other.enums.lock().unwrap().clone()); - - // Ok(()) Ok(JsonSchemaType { inner: field_type }) } } #[cfg(test)] mod tests { - use infer::Type; - use super::*; #[test] @@ -1132,7 +823,7 @@ mod tests { "name", "age", "roles", - "primary_address", + //"primary_address", "secondary_addresses", "zebra_addresses", "gpa", @@ -1143,24 +834,27 @@ mod tests { "type": "object" }); - // let schema = JsonSchema::deserialize(&model_json_schema)?; - // println!("{:#?}", schema); - - // let mut resolver = RefinedTypeResolver { - // refined: HashMap::new(), - // }; - // let _ = schema.visit1(vec![], &mut resolver, &mut vec![]); + let tb = TypeBuilder::new(); + let output_format = tb + .add_json_schema(model_json_schema.to_string())? + .output_format(&tb)?; - // println!("{:#?}", resolver); + println!("{}", output_format); - // let mut tc = TypeCollector { - // tb: TypeBuilder::new(), - // resolver, - // }; + Ok(()) + } - // let _ = schema.visit2(vec![], &mut tc, &mut vec![]); - // println!("{:#?}", tc.tb); - // println!("{}", tc.to_output_format()?); + #[test] + fn test1() -> Result<()> { + let model_json_schema = serde_json::json!({ + "enum": [ + "admin", + "user", + "guest" + ], + "title": "Role", + "type": "string" + }); let tb = TypeBuilder::new(); let output_format = tb