diff --git a/genai/client_test.go b/genai/client_test.go index 57cf748..1bbc260 100644 --- a/genai/client_test.go +++ b/genai/client_test.go @@ -297,13 +297,9 @@ func TestLive(t *testing.T) { }) t.Run("tools", func(t *testing.T) { - weatherChat := func(t *testing.T, s *Schema) { + weatherChat := func(t *testing.T, fd *FunctionDeclaration) { weatherTool := &Tool{ - FunctionDeclarations: []*FunctionDeclaration{{ - Name: "CurrentWeather", - Description: "Get the current weather in a given location", - Parameters: s, - }}, + FunctionDeclarations: []*FunctionDeclaration{fd}, } model := client.GenerativeModel(*modelName) model.SetTemperature(0) @@ -341,20 +337,35 @@ func TestLive(t *testing.T) { } t.Run("direct", func(t *testing.T) { - weatherChat(t, &Schema{ - Type: TypeObject, - Properties: map[string]*Schema{ - "location": { - Type: TypeString, - Description: "The city and state, e.g. San Francisco, CA", - }, - "unit": { - Type: TypeString, - Enum: []string{"celsius", "fahrenheit"}, + fd := &FunctionDeclaration{ + Name: "CurrentWeather", + Description: "Get the current weather in a given location", + Parameters: &Schema{ + Type: TypeObject, + Properties: map[string]*Schema{ + "location": { + Type: TypeString, + Description: "The city and state, e.g. 
San Francisco, CA", + }, + "unit": { + Type: TypeString, + Enum: []string{"celsius", "fahrenheit"}, + }, }, + Required: []string{"location"}, }, - Required: []string{"location"}, - }) + } + weatherChat(t, fd) + }) + t.Run("inferred", func(t *testing.T) { + fds, err := NewCallableFunctionDeclaration( + "CurrentWeather", + "Get the current weather in a given location", + func(string) {}, "location") + if err != nil { + t.Fatal(err) + } + weatherChat(t, fds) }) }) } diff --git a/genai/config.yaml b/genai/config.yaml index f39be99..6e6accd 100644 --- a/genai/config.yaml +++ b/genai/config.yaml @@ -117,7 +117,6 @@ types: # Types for function calling Tool: - FunctionDeclaration: FunctionCall: FunctionResponse: Schema: diff --git a/genai/doc.go b/genai/doc.go index f3d7c90..2638fd4 100644 --- a/genai/doc.go +++ b/genai/doc.go @@ -23,6 +23,22 @@ // You will need an API key to use the service. // See the [setup tutorial] for details. // +// # Tools +// +// Gemini can call functions if you tell it about them. +// Create FunctionDeclarations, add them to a Tool, and install the Tool in a Model. +// When used in a ChatSession, the content returned from a model may include FunctionCall +// parts. Your code performs the requested call and sends back a FunctionResponse. +// See the example for Tool. +// +// To have the SDK call a Go function for you, assign it to the FunctionDeclaration.Function +// field. A ChatSession will look for FunctionCalls, invoke the function you supply, and reply +// with a FunctionResponse. Your code will see only the final result. +// +// The NewCallableFunctionDeclaration function will infer the schema for a function you supply, +// and create a FunctionDeclaration that exposes that function for automatic calling. +// See the example for NewCallableFunctionDeclaration. 
+// // # Errors // // [examples]: https://pkg.go.dev/github.com/google/generative-ai-go/genai#pkg-examples diff --git a/genai/example_test.go b/genai/example_test.go index 0f9d8f3..5a3c523 100644 --- a/genai/example_test.go +++ b/genai/example_test.go @@ -361,6 +361,42 @@ func ExampleTool() { printResponse(res) } +func ExampleNewCallableFunctionDeclaration() { + // Define a function to use as a tool. + weatherToday := func(city string) string { + return "comfortable, if you have the right clothing" + } + // You can construct the Schema for this function by hand, or + // let Go reflection figure it out. + // This also makes the function automatically callable. + // Reflection can't see parameter names, so provide those too. + fd, err := genai.NewCallableFunctionDeclaration("CurrentWeather", "Get the current weather in a given location", weatherToday, "city") + if err != nil { + // Not every type can be used in a tool function. + panic(err) + } + + ctx := context.Background() + client, err := genai.NewClient(ctx, option.WithAPIKey(os.Getenv("GEMINI_API_KEY"))) + if err != nil { + log.Fatal(err) + } + defer client.Close() + + // Use the FunctionDeclaration to populate Model.Tools. + model := client.GenerativeModel("gemini-1.0-pro") + + // Before initiating a conversation, we tell the model which tools it has + // at its disposal. + weatherTool := &genai.Tool{ + FunctionDeclarations: []*genai.FunctionDeclaration{fd}, + } + + model.Tools = []*genai.Tool{weatherTool} + + // Now use the model in a ChatSession; see [ExampleTool]. 
+} + func printResponse(resp *genai.GenerateContentResponse) { for _, cand := range resp.Candidates { if cand.Content != nil { diff --git a/genai/generativelanguagepb_veneer.gen.go b/genai/generativelanguagepb_veneer.gen.go index 191b179..f692768 100644 --- a/genai/generativelanguagepb_veneer.gen.go +++ b/genai/generativelanguagepb_veneer.gen.go @@ -400,47 +400,6 @@ func (FunctionCall) fromProto(p *pb.FunctionCall) *FunctionCall { } } -// FunctionDeclaration is structured representation of a function declaration as defined by the -// [OpenAPI 3.03 specification](https://spec.openapis.org/oas/v3.0.3). Included -// in this declaration are the function name and parameters. This -// FunctionDeclaration is a representation of a block of code that can be used -// as a `Tool` by the model and executed by the client. -type FunctionDeclaration struct { - // Required. The name of the function. - // Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum - // length of 63. - Name string - // Required. A brief description of the function. - Description string - // Optional. Describes the parameters to this function. Reflects the Open - // API 3.03 Parameter Object string Key: the name of the parameter. Parameter - // names are case sensitive. Schema Value: the Schema defining the type used - // for the parameter. 
- Parameters *Schema -} - -func (v *FunctionDeclaration) toProto() *pb.FunctionDeclaration { - if v == nil { - return nil - } - return &pb.FunctionDeclaration{ - Name: v.Name, - Description: v.Description, - Parameters: v.Parameters.toProto(), - } -} - -func (FunctionDeclaration) fromProto(p *pb.FunctionDeclaration) *FunctionDeclaration { - if p == nil { - return nil - } - return &FunctionDeclaration{ - Name: p.Name, - Description: p.Description, - Parameters: (Schema{}).fromProto(p.Parameters), - } -} - // FunctionResponse is the result output from a `FunctionCall` that contains a string // representing the `FunctionDeclaration.name` and a structured JSON // object containing any output from the function is used as context to @@ -1005,44 +964,6 @@ func (v TaskType) String() string { return fmt.Sprintf("TaskType(%d)", v) } -// Tool details that the model may use to generate response. -// -// A `Tool` is a piece of code that enables the system to interact with -// external systems to perform an action, or set of actions, outside of -// knowledge and scope of the model. -type Tool struct { - // Optional. A list of `FunctionDeclarations` available to the model that can - // be used for function calling. - // - // The model or system does not execute the function. Instead the defined - // function may be returned as a [FunctionCall][content.part.function_call] - // with arguments to the client side for execution. The model may decide to - // call a subset of these functions by populating - // [FunctionCall][content.part.function_call] in the response. The next - // conversation turn may contain a - // [FunctionResponse][content.part.function_response] - // with the [content.role] "function" generation context for the next model - // turn. 
- FunctionDeclarations []*FunctionDeclaration -} - -func (v *Tool) toProto() *pb.Tool { - if v == nil { - return nil - } - return &pb.Tool{ - FunctionDeclarations: support.TransformSlice(v.FunctionDeclarations, (*FunctionDeclaration).toProto), - } -} - -func (Tool) fromProto(p *pb.Tool) *Tool { - if p == nil { - return nil - } - return &Tool{ - FunctionDeclarations: support.TransformSlice(p.FunctionDeclarations, (FunctionDeclaration{}).fromProto), - } -} // Type contains the list of OpenAPI data types as defined by // https://spec.openapis.org/oas/v3.0.3#data-types diff --git a/genai/tools.go b/genai/tools.go new file mode 100644 index 0000000..da8dd39 --- /dev/null +++ b/genai/tools.go @@ -0,0 +1,186 @@ +package genai + +import ( + "errors" + "fmt" + "reflect" + + pb "cloud.google.com/go/ai/generativelanguage/apiv1beta/generativelanguagepb" + "github.com/google/generative-ai-go/internal/support" +) + +// A Tool is a piece of code that enables the system to interact with +// external systems to perform an action, or set of actions, outside of +// knowledge and scope of the model. +type Tool struct { + // A list of `FunctionDeclarations` available to the model that can + // be used for function calling. + // + // The model or system does not execute the function. Instead the defined + // function may be returned as a [FunctionCall] + // with arguments to the client side for execution. The model may decide to + // call a subset of these functions by populating + // [FunctionCall][content.part.function_call] in the response. The next + // conversation turn may contain a + // [FunctionResponse][content.part.function_response] + // with the [content.role] "function" generation context for the next model + // turn. 
+ FunctionDeclarations []*FunctionDeclaration +} + +func (v *Tool) toProto() *pb.Tool { + if v == nil { + return nil + } + return &pb.Tool{ + FunctionDeclarations: support.TransformSlice(v.FunctionDeclarations, (*FunctionDeclaration).toProto), + } +} + +func (Tool) fromProto(p *pb.Tool) *Tool { + if p == nil { + return nil + } + return &Tool{ + FunctionDeclarations: support.TransformSlice(p.FunctionDeclarations, (FunctionDeclaration{}).fromProto), + } +} + +// FunctionDeclaration is a structured representation of a function declaration as defined by the +// [OpenAPI 3.03 specification](https://spec.openapis.org/oas/v3.0.3). Included +// in this declaration are the function name and parameters. +// Combine FunctionDeclarations into Tools for use in a [ChatSession]. +type FunctionDeclaration struct { + // Required. The name of the function. + // Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum + // length of 63. + Name string + // Required. A brief description of the function. + Description string + // Optional. Describes the parameters to this function. + Parameters *Schema + // If set, the Go function to call automatically. Its signature must match + // the schema. Call [NewCallableFunctionDeclaration] to create a FunctionDeclaration + // with schema inferred from the function itself. + Function any +} + +func (v *FunctionDeclaration) toProto() *pb.FunctionDeclaration { + if v == nil { + return nil + } + return &pb.FunctionDeclaration{ + Name: v.Name, + Description: v.Description, + Parameters: v.Parameters.toProto(), + } +} + +func (FunctionDeclaration) fromProto(p *pb.FunctionDeclaration) *FunctionDeclaration { + if p == nil { + return nil + } + return &FunctionDeclaration{ + Name: p.Name, + Description: p.Description, + Parameters: (Schema{}).fromProto(p.Parameters), + } +} + +// NewCallableFunctionDeclaration creates a [FunctionDeclaration] from a Go +// function. 
When added to a [ChatSession], the function will be called +// automatically when the model requests it. +// +// This function infers the schema ([FunctionDeclaration.Parameters]) from the +// function. Not all functions can be represented as Schemas. +// At present, variadic functions are not supported, and parameters +// must be of builtin, pointer, slice or array type. +// An error is returned if the schema cannot be inferred. +// It may still be possible to construct a usable schema for the function; if so, +// build a [FunctionDeclaration] by hand, setting its exported fields. +// +// Parameter names are not available to the program. They can be supplied +// as arguments. If omitted, the names "p0", "p1", ... are used. +func NewCallableFunctionDeclaration(name, description string, function any, paramNames ...string) (*FunctionDeclaration, error) { + schema, err := inferSchema(function, paramNames) + if err != nil { + return nil, err + } + return &FunctionDeclaration{ + Name: name, + Description: description, + Parameters: schema, + Function: function, + }, nil +} + +func inferSchema(function any, paramNames []string) (*Schema, error) { + t := reflect.TypeOf(function) + if t == nil || t.Kind() != reflect.Func { + return nil, fmt.Errorf("value of type %T is not a function", function) + } + if t.IsVariadic() { + return nil, errors.New("variadic functions not supported") + } + params := map[string]*Schema{} + var req []string + for i := 0; i < t.NumIn(); i++ { + var name string + if i < len(paramNames) { + name = paramNames[i] + } else { + name = fmt.Sprintf("p%d", i) + } + s, err := typeSchema(t.In(i)) + if err != nil { + return nil, fmt.Errorf("param %s: %w", name, err) + } + params[name] = s + // All parameters are required. 
+ req = append(req, name) + } + return &Schema{ + Type: TypeObject, + Properties: params, + Required: req, + }, nil + +} + +func typeSchema(t reflect.Type) (_ *Schema, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("%s: %w", t, err) + } + }() + switch t.Kind() { + case reflect.Bool: + return &Schema{Type: TypeBoolean}, nil + case reflect.String: + return &Schema{Type: TypeString}, nil + case reflect.Int, reflect.Int64, reflect.Uint32: + return &Schema{Type: TypeInteger, Format: "int64"}, nil + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16: + return &Schema{Type: TypeInteger, Format: "int32"}, nil + case reflect.Float32: + return &Schema{Type: TypeNumber, Format: "float"}, nil + case reflect.Float64, reflect.Uint, reflect.Uint64, reflect.Uintptr: + return &Schema{Type: TypeNumber, Format: "double"}, nil + case reflect.Slice, reflect.Array: + elemSchema, err := typeSchema(t.Elem()) + if err != nil { + return nil, err + } + return &Schema{Type: TypeArray, Items: elemSchema}, nil + case reflect.Pointer: + // Treat a *T as a nullable T. 
+ s, err := typeSchema(t.Elem()) + if err != nil { + return nil, err + } + s.Nullable = true + return s, nil + default: + return nil, errors.New("not supported") + } +} diff --git a/genai/tools_test.go b/genai/tools_test.go new file mode 100644 index 0000000..8552895 --- /dev/null +++ b/genai/tools_test.go @@ -0,0 +1,69 @@ +package genai + +import ( + "reflect" + "testing" +) + +var intSchema = &Schema{Type: TypeInteger, Format: "int64"} + +func TestTypeSchema(t *testing.T) { + for _, test := range []struct { + in any + want *Schema + }{ + {true, &Schema{Type: TypeBoolean}}, + {"", &Schema{Type: TypeString}}, + {1, intSchema}, + {byte(1), &Schema{Type: TypeInteger, Format: "int32"}}, + {1.2, &Schema{Type: TypeNumber, Format: "double"}}, + {float32(1.2), &Schema{Type: TypeNumber, Format: "float"}}, + {new(int), &Schema{Type: TypeInteger, Format: "int64", Nullable: true}}, + { + []int{}, + &Schema{Type: TypeArray, Items: intSchema}, + }, + } { + got, err := typeSchema(reflect.TypeOf(test.in)) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, test.want) { + t.Errorf("%T:\ngot %+v\nwant %+v", test.in, got, test.want) + } + } +} + +func TestInferSchema(t *testing.T) { + f := func(a int, b string, c float64) int { return 0 } + got, err := inferSchema(f, []string{"a", "b"}) + if err != nil { + t.Fatal(err) + } + want := &Schema{ + Type: TypeObject, + Properties: map[string]*Schema{ + "a": intSchema, + "b": {Type: TypeString}, + "p2": {Type: TypeNumber, Format: "double"}, + }, + Required: []string{"a", "b", "p2"}, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("\ngot %+v\nwant %+v", got, want) + } +} + +func TestInferSchemaErrors(t *testing.T) { + for i, f := range []any{ + nil, + 3, // not a function + func(x ...int) {}, // variadic + func(x any) {}, // unsupported type + } { + _, err := inferSchema(f, nil) + if err == nil { + t.Errorf("#%d: got nil, want error", i) + } + } +}