Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

genai: design for automatic function calling #75

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 29 additions & 18 deletions genai/client_test.go
Original file line number Diff line number Diff line change
@@ -297,13 +297,9 @@ func TestLive(t *testing.T) {
})
t.Run("tools", func(t *testing.T) {

weatherChat := func(t *testing.T, s *Schema) {
weatherChat := func(t *testing.T, fd *FunctionDeclaration) {
weatherTool := &Tool{
FunctionDeclarations: []*FunctionDeclaration{{
Name: "CurrentWeather",
Description: "Get the current weather in a given location",
Parameters: s,
}},
FunctionDeclarations: []*FunctionDeclaration{fd},
}
model := client.GenerativeModel(*modelName)
model.SetTemperature(0)
@@ -341,20 +337,35 @@ func TestLive(t *testing.T) {
}

t.Run("direct", func(t *testing.T) {
weatherChat(t, &Schema{
Type: TypeObject,
Properties: map[string]*Schema{
"location": {
Type: TypeString,
Description: "The city and state, e.g. San Francisco, CA",
},
"unit": {
Type: TypeString,
Enum: []string{"celsius", "fahrenheit"},
fd := &FunctionDeclaration{
Name: "CurrentWeather",
Description: "Get the current weather in a given location",
Parameters: &Schema{
Type: TypeObject,
Properties: map[string]*Schema{
"location": {
Type: TypeString,
Description: "The city and state, e.g. San Francisco, CA",
},
"unit": {
Type: TypeString,
Enum: []string{"celsius", "fahrenheit"},
},
},
Required: []string{"location"},
},
Required: []string{"location"},
})
}
weatherChat(t, fd)
})
t.Run("inferred", func(t *testing.T) {
fds, err := NewCallableFunctionDeclaration(
"CurrentWeather",
"Get the current weather in a given location",
func(string) {}, "location")
if err != nil {
t.Fatal(err)
}
weatherChat(t, fds)
})
})
}
1 change: 0 additions & 1 deletion genai/config.yaml
Original file line number Diff line number Diff line change
@@ -117,7 +117,6 @@ types:

# Types for function calling
Tool:
FunctionDeclaration:
FunctionCall:
FunctionResponse:
Schema:
19 changes: 19 additions & 0 deletions genai/example_test.go
Original file line number Diff line number Diff line change
@@ -361,6 +361,25 @@ func ExampleTool() {
printResponse(res)
}

func ExampleNewCallableFunctionDeclaration() {
// Define a function to use as a tool.
weatherToday := func(city string) string {
return "comfortable, if you have the right clothing"
}
// You can construct the Schema for this function by hand, or
// let Go reflection figure it out.
// This also makes the function automatically callable.
// Reflection can't see parameter names, so provide those too.
fd, err := genai.NewCallableFunctionDeclaration("CurrentWeather", "Get the current weather in a given location", weatherToday, "city")
if err != nil {
// Not every type can be used in a tool function.
panic(err)
}

// Use the FunctionDeclaration to populate Model.Tools.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The example should probably be extended to show this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I started to do that. I can finish when I have an implementation.

_ = fd
}

func printResponse(resp *genai.GenerateContentResponse) {
for _, cand := range resp.Candidates {
if cand.Content != nil {
186 changes: 186 additions & 0 deletions genai/functions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
package genai

import (
"errors"
"fmt"
"reflect"

pb "cloud.google.com/go/ai/generativelanguage/apiv1beta/generativelanguagepb"
"github.com/google/generative-ai-go/internal/support"
)

// A Tool is a piece of code that enables the system to interact with
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's Tool in file functions.go but talking about functions - do we just use the two interchangeably? Aligning with how the Python SDK names things?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both OpenAI and Gemini group functions into "tools." The Python SDK follows that, and we do too.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, so maybe we should structure it like that too - tool is the high level concept, and functions are one of the tools available. So for example the Tool type would be in tools.go, not functions.go?

Also, is containment the right paradigm here? Is a function a kind of tool, or do tools contain functions and potentially other stuff?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed the file.
In the protos, a Tool contains FunctionDeclarations. Python keeps that structure.

// external systems to perform an action, or set of actions, outside of
// knowledge and scope of the model.
type Tool struct {
// A list of `FunctionDeclarations` available to the model that can
// be used for function calling.
//
// The model or system does not execute the function. Instead the defined
// function may be returned as a [FunctionCall]
// with arguments to the client side for execution. The model may decide to
// call a subset of these functions by populating
// [FunctionCall][content.part.function_call] in the response. The next
// conversation turn may contain a
// [FunctionResponse][content.part.function_response]
// with the [content.role] "function" generation context for the next model
// turn.
FunctionDeclarations []*FunctionDeclaration
}

func (v *Tool) toProto() *pb.Tool {
if v == nil {
return nil
}
return &pb.Tool{
FunctionDeclarations: support.TransformSlice(v.FunctionDeclarations, (*FunctionDeclaration).toProto),
}
}

func (Tool) fromProto(p *pb.Tool) *Tool {
if p == nil {
return nil
}
return &Tool{
FunctionDeclarations: support.TransformSlice(p.FunctionDeclarations, (FunctionDeclaration{}).fromProto),
}
}

// FunctionDeclaration is structured representation of a function declaration as defined by the
// [OpenAPI 3.03 specification](https://spec.openapis.org/oas/v3.0.3). Included
// in this declaration are the function name and parameters.
// Combine FunctionDeclarations into Tools for use in a [ChatSession].
type FunctionDeclaration struct {
// Required. The name of the function.
// Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum
// length of 63.
Name string
// Required. A brief description of the function.
Description string
// Optional. Describes the parameters to this function.
Parameters *Schema
// If set the Go function to call automatically. Its signature must match
// the schema. Call [NewCallableFunctionDeclaration] to create a FunctionDeclaration
// with schema inferred from the function itself.
Function any
}

func (v *FunctionDeclaration) toProto() *pb.FunctionDeclaration {
if v == nil {
return nil
}
return &pb.FunctionDeclaration{
Name: v.Name,
Description: v.Description,
Parameters: v.Parameters.toProto(),
}
}

func (FunctionDeclaration) fromProto(p *pb.FunctionDeclaration) *FunctionDeclaration {
if p == nil {
return nil
}
return &FunctionDeclaration{
Name: p.Name,
Description: p.Description,
Parameters: (Schema{}).fromProto(p.Parameters),
}
}

// NewCallableFunctionDeclaration creates a [FunctionDeclaration] from a Go
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the significance of "Callable" here? Are there also non-callable function declarations?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean "automatically callable." As in, the client will invoke the function for you, instead of returning a FunctionCall Part and having you provide the result. (The automatic calling isn't implemented yet.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So maybe removing the "Callable" can help make this a bit shorter without loss of meaning? After all, the type is FunctionDeclaration, so NewFunctionDeclaration for the constructor sounds logical

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After all, the type is FunctionDeclaration, so NewFunctionDeclaration for the constructor sounds logical

True, but usually a function named NewFoo does about the same thing as &Foo{...}. Here the New function also makes the function automatically callable.

// function. When added to a [ChatSession], the function will be called
// automatically when the model requests it.
//
// This function infers the schema ([FunctionDeclaration.Parameters]) from the
// function. Not all functions can be represented as Schemas.
// At present, variadic functions are not supported, and parameters
// must be of builtin, pointer, slice or array type.
// An error is returned if the schema cannot be inferred.
// It may still be possible to construct a usable schema for the function; if so,
// build a [FunctionDeclaration] by hand, setting its exported fields.
//
// Parameter names are not available to the program. They can be supplied
// as arguments. If omitted, the names "p0", "p1", ... are used.
func NewCallableFunctionDeclaration(name, description string, function any, paramNames ...string) (*FunctionDeclaration, error) {
schema, err := inferSchema(function, paramNames)
if err != nil {
return nil, err
}
return &FunctionDeclaration{
Name: name,
Description: description,
Parameters: schema,
Function: function,
}, nil
}

func inferSchema(function any, paramNames []string) (*Schema, error) {
t := reflect.TypeOf(function)
if t == nil || t.Kind() != reflect.Func {
return nil, fmt.Errorf("value of type %T is not a function", function)
}
if t.IsVariadic() {
return nil, errors.New("variadic functions not supported")
}
params := map[string]*Schema{}
var req []string
for i := 0; i < t.NumIn(); i++ {
var name string
if i < len(paramNames) {
name = paramNames[i]
} else {
name = fmt.Sprintf("p%d", i)
}
s, err := typeSchema(t.In(i))
if err != nil {
return nil, fmt.Errorf("param %s: %w", name, err)
}
params[name] = s
// All parameters are required.
req = append(req, name)
}
return &Schema{
Type: TypeObject,
Properties: params,
Required: req,
}, nil

}

func typeSchema(t reflect.Type) (_ *Schema, err error) {
defer func() {
if err != nil {
err = fmt.Errorf("%s: %w", t, err)
}
}()
switch t.Kind() {
case reflect.Bool:
return &Schema{Type: TypeBoolean}, nil
case reflect.String:
return &Schema{Type: TypeString}, nil
case reflect.Int, reflect.Int64, reflect.Uint32:
return &Schema{Type: TypeInteger, Format: "int64"}, nil
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16:
return &Schema{Type: TypeInteger, Format: "int32"}, nil
case reflect.Float32:
return &Schema{Type: TypeNumber, Format: "float"}, nil
case reflect.Float64, reflect.Uint, reflect.Uint64, reflect.Uintptr:
return &Schema{Type: TypeNumber, Format: "double"}, nil
case reflect.Slice, reflect.Array:
elemSchema, err := typeSchema(t.Elem())
if err != nil {
return nil, err
}
return &Schema{Type: TypeArray, Items: elemSchema}, nil
case reflect.Pointer:
// Treat a *T as a nullable T.
s, err := typeSchema(t.Elem())
if err != nil {
return nil, err
}
s.Nullable = true
return s, nil
default:
return nil, errors.New("not supported")
}
}
69 changes: 69 additions & 0 deletions genai/functions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package genai

import (
"reflect"
"testing"
)

var intSchema = &Schema{Type: TypeInteger, Format: "int64"}

func TestTypeSchema(t *testing.T) {
for _, test := range []struct {
in any
want *Schema
}{
{true, &Schema{Type: TypeBoolean}},
{"", &Schema{Type: TypeString}},
{1, intSchema},
{byte(1), &Schema{Type: TypeInteger, Format: "int32"}},
{1.2, &Schema{Type: TypeNumber, Format: "double"}},
{float32(1.2), &Schema{Type: TypeNumber, Format: "float"}},
{new(int), &Schema{Type: TypeInteger, Format: "int64", Nullable: true}},
{
[]int{},
&Schema{Type: TypeArray, Items: intSchema},
},
} {
got, err := typeSchema(reflect.TypeOf(test.in))
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(got, test.want) {
t.Errorf("%T:\ngot %+v\nwant %+v", test.in, got, test.want)
}
}
}

func TestInferSchema(t *testing.T) {
f := func(a int, b string, c float64) int { return 0 }
got, err := inferSchema(f, []string{"a", "b"})
if err != nil {
t.Fatal(err)
}
want := &Schema{
Type: TypeObject,
Properties: map[string]*Schema{
"a": intSchema,
"b": {Type: TypeString},
"p2": {Type: TypeNumber, Format: "double"},
},
Required: []string{"a", "b", "p2"},
}
if !reflect.DeepEqual(got, want) {
t.Errorf("\ngot %+v\nwant %+v", got, want)
}
}

func TestInferSchemaErrors(t *testing.T) {
for i, f := range []any{
nil,
3, // not a function
func(x ...int) {}, // variadic
func(x any) {}, // unsupported type
} {
_, err := inferSchema(f, nil)
if err == nil {
t.Errorf("#%d: got nil, want error", i)
}
}
}
Loading
Loading