Skip to content

Commit

Permalink
Adding recursive defintions for nested classes (when streaming for pa…
Browse files Browse the repository at this point in the history
…rtials). (#676)

* Added some unit tests in all languages to ensure this continues to
work
  • Loading branch information
hellovai authored Jun 14, 2024
1 parent 2890b2c commit 31e275e
Show file tree
Hide file tree
Showing 22 changed files with 477 additions and 342 deletions.
24 changes: 15 additions & 9 deletions engine/language-client-codegen/src/python/generate_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ impl<'ir> From<ClassWalker<'ir>> for PartialPythonClass<'ir> {
f.elem.name.as_str(),
add_default_value(
&f.elem.r#type.elem,
&f.elem.r#type.elem.to_partial_type_ref(&c.db),
&f.elem.r#type.elem.to_partial_type_ref(&c.db, false),
),
)
})
Expand All @@ -159,7 +159,7 @@ pub fn add_default_value(node: &FieldType, type_str: &String) -> String {

trait ToTypeReferenceInTypeDefinition {
fn to_type_ref(&self, ir: &IntermediateRepr) -> String;
fn to_partial_type_ref(&self, ir: &IntermediateRepr) -> String;
fn to_partial_type_ref(&self, ir: &IntermediateRepr, wrapped: bool) -> String;
}

impl ToTypeReferenceInTypeDefinition for FieldType {
Expand Down Expand Up @@ -202,9 +202,15 @@ impl ToTypeReferenceInTypeDefinition for FieldType {
}
}

fn to_partial_type_ref(&self, ir: &IntermediateRepr) -> String {
fn to_partial_type_ref(&self, ir: &IntermediateRepr, wrapped: bool) -> String {
match self {
FieldType::Class(name) => format!("\"{name}\""),
FieldType::Class(name) => {
if wrapped {
format!("\"{name}\"")
} else {
format!("Optional[\"{name}\"]")
}
}
FieldType::Enum(name) => {
if ir
.find_enum(name)
Expand All @@ -216,32 +222,32 @@ impl ToTypeReferenceInTypeDefinition for FieldType {
format!("Optional[types.{name}]")
}
}
FieldType::List(inner) => format!("List[{}]", inner.to_partial_type_ref(ir)),
FieldType::List(inner) => format!("List[{}]", inner.to_partial_type_ref(ir, true)),
FieldType::Map(key, value) => {
format!(
"Dict[{}, {}]",
key.to_type_ref(ir),
value.to_partial_type_ref(ir)
value.to_partial_type_ref(ir, false)
)
}
FieldType::Primitive(r#type) => format!("Optional[{}]", r#type.to_python()),
FieldType::Union(inner) => format!(
"Optional[Union[{}]]",
inner
.iter()
.map(|t| t.to_partial_type_ref(ir))
.map(|t| t.to_partial_type_ref(ir, true))
.collect::<Vec<_>>()
.join(", ")
),
FieldType::Tuple(inner) => format!(
"Optional[Tuple[{}]]",
inner
.iter()
.map(|t| t.to_partial_type_ref(ir))
.map(|t| t.to_partial_type_ref(ir, false))
.collect::<Vec<_>>()
.join(", ")
),
FieldType::Optional(inner) => inner.to_partial_type_ref(ir),
FieldType::Optional(inner) => inner.to_partial_type_ref(ir, false),
}
}
}
90 changes: 45 additions & 45 deletions engine/language-client-codegen/src/typescript/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ struct TypescriptClient {
}
struct TypescriptFunction {
name: String,
partial_return_type: String,
// partial_return_type: String,
return_type: String,
args: Vec<(String, bool, String)>,
}
Expand Down Expand Up @@ -76,7 +76,7 @@ impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for TypescriptCli
Ok(TypescriptFunction {
name: f.name().to_string(),
return_type: f.elem().output().to_type_ref(ir),
partial_return_type: f.elem().output().to_partial_type_ref(ir),
// partial_return_type: f.elem().output().to_partial_type_ref(ir),
args: match f.inputs() {
either::Either::Left(_args) => anyhow::bail!("Typescript codegen does not support unnamed args: please add names to all arguments of BAML function '{}'", f.name().to_string()),
either::Either::Right(args) => args
Expand Down Expand Up @@ -144,52 +144,52 @@ impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for TypescriptIni
trait ToTypeReferenceInClientDefinition {
fn to_type_ref(&self, ir: &IntermediateRepr) -> String;

fn to_partial_type_ref(&self, ir: &IntermediateRepr) -> String;
// fn to_partial_type_ref(&self, ir: &IntermediateRepr) -> String;
}

impl ToTypeReferenceInClientDefinition for FieldType {
fn to_partial_type_ref(&self, ir: &IntermediateRepr) -> String {
match self {
FieldType::Enum(name) => {
if ir
.find_enum(name)
.map(|e| e.item.attributes.get("dynamic_type").is_some())
.unwrap_or(false)
{
format!("(string | {name} | null)")
} else {
format!("({name} | null)")
}
}
FieldType::Class(name) => format!("(Partial<{name}> | null)"),
FieldType::List(inner) => format!("{}[]", inner.to_partial_type_ref(ir)),
FieldType::Map(key, value) => {
format!(
"(Record<{}, {}> | null)",
key.to_type_ref(ir),
value.to_partial_type_ref(ir)
)
}
FieldType::Primitive(r#type) => format!("({} | null)", r#type.to_typescript()),
FieldType::Union(inner) => format!(
"({} | null)",
inner
.iter()
.map(|t| t.to_partial_type_ref(ir))
.collect::<Vec<_>>()
.join(" | ")
),
FieldType::Tuple(inner) => format!(
"([{}] | null)",
inner
.iter()
.map(|t| t.to_partial_type_ref(ir))
.collect::<Vec<_>>()
.join(", ")
),
FieldType::Optional(inner) => format!("({} | null)", inner.to_partial_type_ref(ir)),
}
}
// fn to_partial_type_ref(&self, ir: &IntermediateRepr) -> String {
// match self {
// FieldType::Enum(name) => {
// if ir
// .find_enum(name)
// .map(|e| e.item.attributes.get("dynamic_type").is_some())
// .unwrap_or(false)
// {
// format!("(string | {name} | null)")
// } else {
// format!("({name} | null)")
// }
// }
// FieldType::Class(name) => format!("(RecursivePartialNull<{name}>)"),
// FieldType::List(inner) => format!("{}[]", inner.to_partial_type_ref(ir)),
// FieldType::Map(key, value) => {
// format!(
// "(Record<{}, {}> | null)",
// key.to_type_ref(ir),
// value.to_partial_type_ref(ir)
// )
// }
// FieldType::Primitive(r#type) => format!("({} | null)", r#type.to_typescript()),
// FieldType::Union(inner) => format!(
// "({} | null)",
// inner
// .iter()
// .map(|t| t.to_partial_type_ref(ir))
// .collect::<Vec<_>>()
// .join(" | ")
// ),
// FieldType::Tuple(inner) => format!(
// "([{}] | null)",
// inner
// .iter()
// .map(|t| t.to_partial_type_ref(ir))
// .collect::<Vec<_>>()
// .join(", ")
// ),
// FieldType::Optional(inner) => format!("({} | null)", inner.to_partial_type_ref(ir)),
// }
// }

fn to_type_ref(&self, ir: &IntermediateRepr) -> String {
match self {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ import {
} from "./types"
import TypeBuilder from "./type_builder"

export type RecursivePartialNull<T> = T extends object
? {
[P in keyof T]?: RecursivePartialNull<T[P]>;
}
: T | null;

export class BamlClient {
private stream_client: BamlStreamClient

Expand Down Expand Up @@ -46,7 +52,7 @@ class BamlStreamClient {
{{name}}{% if optional %}?{% endif %}: {{type}},
{%- endfor %}
__baml_options__?: { tb?: TypeBuilder }
): BamlStream<{{ fn.partial_return_type }}, {{ fn.return_type }}> {
): BamlStream<RecursivePartialNull<{{ fn.return_type }}>, {{ fn.return_type }}> {
const raw = this.runtime.streamFunction(
"{{fn.name}}",
{
Expand All @@ -58,9 +64,9 @@ class BamlStreamClient {
this.ctx_manager.get(),
__baml_options__?.tb?.__tb(),
)
return new BamlStream<{{ fn.partial_return_type }}, {{ fn.return_type }}>(
return new BamlStream<RecursivePartialNull<{{ fn.return_type }}>, {{ fn.return_type }}>(
raw,
(a): a is {{ fn.partial_return_type }} => a,
(a): a is RecursivePartialNull<{{ fn.return_type }}> => a,
(a): a is {{ fn.return_type }} => a,
this.ctx_manager.get(),
__baml_options__?.tb?.__tb(),
Expand Down
1 change: 0 additions & 1 deletion integ-tests/baml_src/clients.baml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ client<llm> Ollama {
provider ollama
options {
model llama2
api_key ""
}
}

Expand Down
33 changes: 33 additions & 0 deletions integ-tests/baml_src/test-files/functions/output/class-nested.baml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
class TestClassNested {
prop1 string
prop2 InnerClass
}

class InnerClass {
prop1 string
prop2 string
inner InnerClass2
}

class InnerClass2 {
prop2 int
prop3 float
}

function FnOutputClassNested(input: string) -> TestClassNested {
client Ollama
prompt #"
Return a made up json blob that matches this schema:
{{ctx.output_format}}
---

JSON:
"#
}

test FnOutputClassNested {
functions [FnOutputClassNested]
args {
input "example input"
}
}
20 changes: 1 addition & 19 deletions integ-tests/baml_src/test-files/functions/output/class.baml
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,8 @@ function FnOutputClass(input: string) -> TestOutputClass {
}

test TestClass {
functions [FnOutputClass, FnOutputNestedClass]
functions [FnOutputClass]
args {
input "example input"
}
}



class TestOutputClassNested {
prop1 string
prop2 int
prop3 TestOutputClass
}

function FnOutputNestedClass(input: string) -> TestOutputClassNested {
client GPT35
prompt #"
Return a JSON blob with this schema:
{{ctx.output_format}}

JSON:
"#
}
Loading

0 comments on commit 31e275e

Please sign in to comment.