genai: design for automatic function calling #75
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
package genai | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"reflect" | ||
|
||
pb "cloud.google.com/go/ai/generativelanguage/apiv1beta/generativelanguagepb" | ||
"github.com/google/generative-ai-go/internal/support" | ||
) | ||
|
||
// A Tool is a piece of code that enables the system to interact with | ||
It's Tool in file
Both OpenAI and Gemini group functions into "tools." The Python SDK follows that, and we do too.
Right, so maybe we should structure it like that too: tool is the high-level concept, and functions are one of the tools available. So for example the Also, is containment the right paradigm here? Is a function a kind of tool, or do tools contain functions and potentially other stuff?
Renamed the file.
||
// external systems to perform an action, or set of actions, outside the | ||
// knowledge and scope of the model. | ||
type Tool struct { | ||
// A list of `FunctionDeclarations` available to the model that can | ||
// be used for function calling. | ||
// | ||
// The model or system does not execute the function. Instead the defined | ||
// function may be returned as a [FunctionCall] | ||
// with arguments to the client side for execution. The model may decide to | ||
// call a subset of these functions by populating | ||
// [FunctionCall][content.part.function_call] in the response. The next | ||
// conversation turn may contain a | ||
// [FunctionResponse][content.part.function_response] | ||
// with the [content.role] "function" generation context for the next model | ||
// turn. | ||
FunctionDeclarations []*FunctionDeclaration | ||
} | ||
|
||
func (v *Tool) toProto() *pb.Tool { | ||
if v == nil { | ||
return nil | ||
} | ||
return &pb.Tool{ | ||
FunctionDeclarations: support.TransformSlice(v.FunctionDeclarations, (*FunctionDeclaration).toProto), | ||
} | ||
} | ||
|
||
func (Tool) fromProto(p *pb.Tool) *Tool { | ||
if p == nil { | ||
return nil | ||
} | ||
return &Tool{ | ||
FunctionDeclarations: support.TransformSlice(p.FunctionDeclarations, (FunctionDeclaration{}).fromProto), | ||
} | ||
} | ||
|
||
// FunctionDeclaration is a structured representation of a function declaration as defined by the | ||
// [OpenAPI 3.0.3 specification](https://spec.openapis.org/oas/v3.0.3). Included | ||
// in this declaration are the function name and parameters. | ||
// Combine FunctionDeclarations into Tools for use in a [ChatSession]. | ||
type FunctionDeclaration struct { | ||
// Required. The name of the function. | ||
// Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum | ||
// length of 63. | ||
Name string | ||
// Required. A brief description of the function. | ||
Description string | ||
// Optional. Describes the parameters to this function. | ||
Parameters *Schema | ||
// If set, the Go function to call automatically. Its signature must match | ||
// the schema. Call [NewCallableFunctionDeclaration] to create a FunctionDeclaration | ||
// with schema inferred from the function itself. | ||
Function any | ||
} | ||
|
||
func (v *FunctionDeclaration) toProto() *pb.FunctionDeclaration { | ||
if v == nil { | ||
return nil | ||
} | ||
return &pb.FunctionDeclaration{ | ||
Name: v.Name, | ||
Description: v.Description, | ||
Parameters: v.Parameters.toProto(), | ||
} | ||
} | ||
|
||
func (FunctionDeclaration) fromProto(p *pb.FunctionDeclaration) *FunctionDeclaration { | ||
if p == nil { | ||
return nil | ||
} | ||
return &FunctionDeclaration{ | ||
Name: p.Name, | ||
Description: p.Description, | ||
Parameters: (Schema{}).fromProto(p.Parameters), | ||
} | ||
} | ||
|
||
// NewCallableFunctionDeclaration creates a [FunctionDeclaration] from a Go | ||
What's the significance of "Callable" here? Are there also non-callable function declarations?
I mean "automatically callable." As in, the client will invoke the function for you, instead of returning a FunctionCall Part and having you provide the result. (The automatic calling isn't implemented yet.)
So maybe removing the "Callable" can help make this a bit shorter without loss of meaning? After all, the type is
True, but usually a function named
||
// function. When added to a [ChatSession], the function will be called | ||
// automatically when the model requests it. | ||
// | ||
// This function infers the schema ([FunctionDeclaration.Parameters]) from the | ||
// function. Not all functions can be represented as Schemas. | ||
// At present, variadic functions are not supported, and parameters | ||
// must be of builtin, pointer, slice or array type. | ||
// An error is returned if the schema cannot be inferred. | ||
// It may still be possible to construct a usable schema for the function; if so, | ||
// build a [FunctionDeclaration] by hand, setting its exported fields. | ||
// | ||
// Parameter names are not available to the program. They can be supplied | ||
// as arguments. If omitted, the names "p0", "p1", ... are used. | ||
func NewCallableFunctionDeclaration(name, description string, function any, paramNames ...string) (*FunctionDeclaration, error) { | ||
schema, err := inferSchema(function, paramNames) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return &FunctionDeclaration{ | ||
Name: name, | ||
Description: description, | ||
Parameters: schema, | ||
Function: function, | ||
}, nil | ||
} | ||
|
||
func inferSchema(function any, paramNames []string) (*Schema, error) { | ||
t := reflect.TypeOf(function) | ||
if t == nil || t.Kind() != reflect.Func { | ||
return nil, fmt.Errorf("value of type %T is not a function", function) | ||
} | ||
if t.IsVariadic() { | ||
return nil, errors.New("variadic functions not supported") | ||
} | ||
params := map[string]*Schema{} | ||
var req []string | ||
for i := 0; i < t.NumIn(); i++ { | ||
var name string | ||
if i < len(paramNames) { | ||
name = paramNames[i] | ||
} else { | ||
name = fmt.Sprintf("p%d", i) | ||
} | ||
s, err := typeSchema(t.In(i)) | ||
if err != nil { | ||
return nil, fmt.Errorf("param %s: %w", name, err) | ||
} | ||
params[name] = s | ||
// All parameters are required. | ||
req = append(req, name) | ||
} | ||
return &Schema{ | ||
Type: TypeObject, | ||
Properties: params, | ||
Required: req, | ||
}, nil | ||
|
||
} | ||
|
||
func typeSchema(t reflect.Type) (_ *Schema, err error) { | ||
defer func() { | ||
if err != nil { | ||
err = fmt.Errorf("%s: %w", t, err) | ||
} | ||
}() | ||
switch t.Kind() { | ||
case reflect.Bool: | ||
return &Schema{Type: TypeBoolean}, nil | ||
case reflect.String: | ||
return &Schema{Type: TypeString}, nil | ||
case reflect.Int, reflect.Int64, reflect.Uint32: | ||
return &Schema{Type: TypeInteger, Format: "int64"}, nil | ||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16: | ||
return &Schema{Type: TypeInteger, Format: "int32"}, nil | ||
case reflect.Float32: | ||
return &Schema{Type: TypeNumber, Format: "float"}, nil | ||
case reflect.Float64, reflect.Uint, reflect.Uint64, reflect.Uintptr: | ||
return &Schema{Type: TypeNumber, Format: "double"}, nil | ||
case reflect.Slice, reflect.Array: | ||
elemSchema, err := typeSchema(t.Elem()) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return &Schema{Type: TypeArray, Items: elemSchema}, nil | ||
case reflect.Pointer: | ||
// Treat a *T as a nullable T. | ||
s, err := typeSchema(t.Elem()) | ||
if err != nil { | ||
return nil, err | ||
} | ||
s.Nullable = true | ||
return s, nil | ||
default: | ||
return nil, errors.New("not supported") | ||
} | ||
} |
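To make the inference rules above easier to try out, here is a minimal standalone sketch. It is not the PR's code: `Schema` is a reduced stand-in for the package's type, the type names are plain strings instead of the `Type*` constants, and only a subset of kinds is handled. It shows the default `p0`, `p1`, ... naming, the "all parameters are required" rule, and a `*T` parameter becoming a nullable `T`.

```go
package main

import (
	"errors"
	"fmt"
	"reflect"
)

// Schema is a hypothetical, reduced stand-in for the package's Schema type.
type Schema struct {
	Type       string
	Format     string
	Nullable   bool
	Items      *Schema
	Properties map[string]*Schema
	Required   []string
}

// typeSchema maps a small subset of Go types to a Schema.
func typeSchema(t reflect.Type) (*Schema, error) {
	switch t.Kind() {
	case reflect.Bool:
		return &Schema{Type: "boolean"}, nil
	case reflect.String:
		return &Schema{Type: "string"}, nil
	case reflect.Int, reflect.Int64:
		return &Schema{Type: "integer", Format: "int64"}, nil
	case reflect.Float64:
		return &Schema{Type: "number", Format: "double"}, nil
	case reflect.Slice, reflect.Array:
		elem, err := typeSchema(t.Elem())
		if err != nil {
			return nil, err
		}
		return &Schema{Type: "array", Items: elem}, nil
	case reflect.Pointer:
		// Treat a *T as a nullable T.
		s, err := typeSchema(t.Elem())
		if err != nil {
			return nil, err
		}
		s.Nullable = true
		return s, nil
	default:
		return nil, fmt.Errorf("%s: not supported", t)
	}
}

// inferSchema builds an object schema with one required property per parameter.
// Parameters without a supplied name are called "p0", "p1", and so on.
func inferSchema(function any, paramNames ...string) (*Schema, error) {
	t := reflect.TypeOf(function)
	if t == nil || t.Kind() != reflect.Func {
		return nil, fmt.Errorf("value of type %T is not a function", function)
	}
	if t.IsVariadic() {
		return nil, errors.New("variadic functions not supported")
	}
	props := map[string]*Schema{}
	var req []string
	for i := 0; i < t.NumIn(); i++ {
		name := fmt.Sprintf("p%d", i)
		if i < len(paramNames) {
			name = paramNames[i]
		}
		s, err := typeSchema(t.In(i))
		if err != nil {
			return nil, fmt.Errorf("param %s: %w", name, err)
		}
		props[name] = s
		req = append(req, name)
	}
	return &Schema{Type: "object", Properties: props, Required: req}, nil
}

func main() {
	// Only two of the three parameter names are supplied.
	weather := func(city string, days int, units *string) string { return "" }
	s, err := inferSchema(weather, "city", "days")
	if err != nil {
		panic(err)
	}
	for _, name := range s.Required {
		p := s.Properties[name]
		fmt.Printf("%s: type=%s format=%q nullable=%t\n", name, p.Type, p.Format, p.Nullable)
	}
}
```

Because Go reflection does not expose parameter names, the third parameter is reported as `p2`; callers who care about the names the model sees would need to pass all of them.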
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
package genai | ||
|
||
import ( | ||
"reflect" | ||
"testing" | ||
) | ||
|
||
var intSchema = &Schema{Type: TypeInteger, Format: "int64"} | ||
|
||
func TestTypeSchema(t *testing.T) { | ||
for _, test := range []struct { | ||
in any | ||
want *Schema | ||
}{ | ||
{true, &Schema{Type: TypeBoolean}}, | ||
{"", &Schema{Type: TypeString}}, | ||
{1, intSchema}, | ||
{byte(1), &Schema{Type: TypeInteger, Format: "int32"}}, | ||
{1.2, &Schema{Type: TypeNumber, Format: "double"}}, | ||
{float32(1.2), &Schema{Type: TypeNumber, Format: "float"}}, | ||
{new(int), &Schema{Type: TypeInteger, Format: "int64", Nullable: true}}, | ||
{ | ||
[]int{}, | ||
&Schema{Type: TypeArray, Items: intSchema}, | ||
}, | ||
} { | ||
got, err := typeSchema(reflect.TypeOf(test.in)) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
if !reflect.DeepEqual(got, test.want) { | ||
t.Errorf("%T:\ngot %+v\nwant %+v", test.in, got, test.want) | ||
} | ||
} | ||
} | ||
|
||
func TestInferSchema(t *testing.T) { | ||
f := func(a int, b string, c float64) int { return 0 } | ||
got, err := inferSchema(f, []string{"a", "b"}) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
want := &Schema{ | ||
Type: TypeObject, | ||
Properties: map[string]*Schema{ | ||
"a": intSchema, | ||
"b": {Type: TypeString}, | ||
"p2": {Type: TypeNumber, Format: "double"}, | ||
}, | ||
Required: []string{"a", "b", "p2"}, | ||
} | ||
if !reflect.DeepEqual(got, want) { | ||
t.Errorf("\ngot %+v\nwant %+v", got, want) | ||
} | ||
} | ||
|
||
func TestInferSchemaErrors(t *testing.T) { | ||
for i, f := range []any{ | ||
nil, | ||
3, // not a function | ||
func(x ...int) {}, // variadic | ||
func(x any) {}, // unsupported type | ||
} { | ||
_, err := inferSchema(f, nil) | ||
if err == nil { | ||
t.Errorf("#%d: got nil, want error", i) | ||
} | ||
} | ||
} |
The example should probably be extended to show this?
I started to do that. I can finish when I have an implementation.
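Since automatic calling is not implemented in this PR, the following is only a sketch of one way the invocation step could work, using reflection to call a Go function with arguments looked up by parameter name. The helper `callWithArgs` is hypothetical and not part of the PR. It also assumes that numeric arguments arrive as `float64` (as they would if decoded from JSON into `any` by `encoding/json`) and converts them to the declared parameter type.

```go
package main

import (
	"fmt"
	"reflect"
)

// isNumeric reports whether k is an integer or floating-point kind.
func isNumeric(k reflect.Kind) bool {
	switch k {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
		reflect.Float32, reflect.Float64:
		return true
	}
	return false
}

// callWithArgs is a hypothetical helper: it calls fn, taking the i-th argument
// from args[paramNames[i]], and returns fn's results.
func callWithArgs(fn any, paramNames []string, args map[string]any) ([]any, error) {
	v := reflect.ValueOf(fn)
	if !v.IsValid() || v.Kind() != reflect.Func {
		return nil, fmt.Errorf("value of type %T is not a function", fn)
	}
	t := v.Type()
	if t.IsVariadic() {
		return nil, fmt.Errorf("variadic functions not supported")
	}
	if len(paramNames) != t.NumIn() {
		return nil, fmt.Errorf("got %d parameter names, want %d", len(paramNames), t.NumIn())
	}
	in := make([]reflect.Value, t.NumIn())
	for i, name := range paramNames {
		arg, ok := args[name]
		if !ok {
			return nil, fmt.Errorf("missing argument %q", name)
		}
		av := reflect.ValueOf(arg)
		want := t.In(i)
		switch {
		case !av.IsValid():
			return nil, fmt.Errorf("argument %q is nil", name)
		case av.Type().AssignableTo(want):
			in[i] = av
		case isNumeric(av.Kind()) && isNumeric(want.Kind()):
			// Numeric conversion only; this may truncate or lose precision.
			in[i] = av.Convert(want)
		default:
			return nil, fmt.Errorf("argument %q: cannot use %T as %s", name, arg, want)
		}
	}
	out := v.Call(in)
	results := make([]any, len(out))
	for i, o := range out {
		results[i] = o.Interface()
	}
	return results, nil
}

func main() {
	add := func(a int, b float64) float64 { return float64(a) + b }
	res, err := callWithArgs(add, []string{"a", "b"}, map[string]any{"a": float64(2), "b": 1.5})
	if err != nil {
		panic(err)
	}
	fmt.Println(res[0]) // prints 3.5
}
```

A real implementation would also need to decide how nullable pointer parameters, slices, and function errors are handled; none of that is settled by the discussion above.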