diff --git a/README.md b/README.md index 09a8ed7..40c44ae 100644 --- a/README.md +++ b/README.md @@ -403,7 +403,7 @@ func main() { The following examlpe describes the statistics of an index by name. -```Go +```go package main import ( @@ -781,7 +781,7 @@ func main() { The following example deletes a vector by its ID value from `example-index` and `example-namespace`. You can pass a slice of vector IDs to `DeleteVectorsById`. -```Go +```go package main import ( @@ -829,7 +829,7 @@ func main() { The following example deletes vectors from `example-index` using a metadata filter. -```Go +```go package main import ( @@ -884,7 +884,7 @@ func main() { The following example deletes all vectors from `example-index` and `example-namespace`. -```Go +```go package main import ( @@ -1012,7 +1012,7 @@ func main() { The following example updates vectors by ID in `example-index` and `example-namespace`. -```Go +```go package main import ( @@ -1145,7 +1145,7 @@ Collections are only available for pod-based indexes. The following example creates a collection from a source index. -```Go +```go package main import ( @@ -1186,7 +1186,7 @@ func main() { The following example lists all collections in your Pinecone project. -```Go +```go package main import ( @@ -1231,7 +1231,7 @@ func main() { The following example describes a collection by name. -```Go +```go package main import ( @@ -1269,7 +1269,7 @@ func main() { The following example deletes a collection by name. -```Go +```go package main import ( @@ -1305,6 +1305,73 @@ func main() { } ``` +## Inference + +The `Client` object has an `Inference` namespace which allows interacting with Pinecone's [Inference API](https://docs.pinecone.io/reference/api/2024-07/inference/generate-embeddings). The Inference API is a service that gives you access to embedding models hosted on Pinecone's infrastructure. Read more at [Understanding Pinecone Inference](https://docs.pinecone.io/guides/inference/understanding-inference). + +**Notes:** + +Models currently supported: + +- [multilingual-e5-large](https://docs.pinecone.io/guides/inference/understanding-inference#embedding-models) + +### Create Embeddings + +Send text to Pinecone's inference API to generate embeddings for documents and queries. + +```go + ctx := context.Background() + + pc, err := pinecone.NewClient(pinecone.NewClientParams{ + ApiKey: "YOUR_API_KEY", + }) + if err != nil { + log.Fatalf("Failed to create Client: %v", err) + } + + embeddingModel := "multilingual-e5-large" + documents := []string{ + "Turkey is a classic meat to eat at American Thanksgiving." + "Many people enjoy the beautiful mosques in Turkey." + } + docParameters := pinecone.EmbedParameters{ + InputType: "passage", + Truncate: "END", + } + + docEmbeddingsResponse, err := pc.Inference.Embed(ctx, &pinecone.EmbedRequest{ + Model: embeddingModel, + TextInputs: documents, + Parameters: docParameters, + }) + if err != nil { + log.Fatalf("Failed to embed documents: %v", err) + } + fmt.Printf("docs embedding response: %+v", docEmbeddingsResponse) + + // << Upsert documents into Pinecone >> + + userQuery := []string{ + "How should I prepare my turkey?" + } + queryParameters := pinecone.EmbedParameters{ + InputType: "query", + Truncate: "END", + } + queryEmbeddingsResponse, err := pc.Inference.Embed(ctx, &pinecone.EmbedRequest{ + Model: embeddingModel, + TextInputs: userQuery, + Parameters: queryParameters + }) + if err != nil { + log.Fatalf("Failed to embed query: %v", err) + } + fmt.Printf("query embedding response: %+v", queryEmbeddingsResponse) + + // << Send query to Pinecone to retrieve similar documents >> + +``` + ## Support To get help using go-pinecone you can file an issue on [GitHub](https://github.com/pinecone-io/go-pinecone/issues), visit the [community forum](https://community.pinecone.io/), diff --git a/pinecone/client.go b/pinecone/client.go index b6b1381..261abba 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -24,15 +24,16 @@ import ( // Client holds the parameters for connecting to the Pinecone service. It is returned by the NewClient and NewClientBase // functions. To use Client, first build the parameters of the request using NewClientParams (or NewClientBaseParams). // Then, pass those parameters into the NewClient (or NewClientBase) function to create a new Client object. -// Once instantiated, you can use Client to execute control plane API requests (e.g. create an Index, list Indexes, -// etc.). Read more about different control plane API routes at [docs.pinecone.io/reference/api]. +// Once instantiated, you can use Client to execute Pinecone API requests (e.g. create an Index, list Indexes, +// etc.), and Inference API requests. Read more about different Pinecone API routes at [docs.pinecone.io/reference/api]. // // Note: Client methods are safe for concurrent use. // // Fields: -// - headers: An optional map of additional HTTP headers to include in each API request to the control plane, -// provided through NewClientParams.Headers or NewClientBaseParams.Headers. -// - restClient: Optional underlying *http.Client object used to communicate with the Pinecone control plane API, +// - Inference: An InferenceService object that exposes methods for interacting with the Pinecone [Inference API]. +// - headers: An optional map of HTTP headers to include in each API request, provided through +// NewClientParams.Headers or NewClientBaseParams.Headers. +// - restClient: Optional underlying *http.Client object used to communicate with the Pinecone API, // provided through NewClientParams.RestClient or NewClientBaseParams.RestClient. If not provided, // a default client is created for you. // - sourceTag: An optional string used to help Pinecone attribute API activity, provided through NewClientParams.SourceTag @@ -67,7 +68,9 @@ import ( // } // // [docs.pinecone.io/reference/api]: https://docs.pinecone.io/reference/api/control-plane/list_indexes +// [Inference API]: https://docs.pinecone.io/reference/api/2024-07/inference/generate-embeddings type Client struct { + Inference *InferenceService headers map[string]string restClient *control.Client sourceTag string @@ -76,12 +79,11 @@ type Client struct { // NewClientParams holds the parameters for creating a new Client instance while authenticating via an API key. // // Fields: -// - ApiKey: (Required) The API key used to authenticate with the Pinecone control plane API. +// - ApiKey: (Required) The API key used to authenticate with the Pinecone API. // This value must be passed by the user unless it is set as an environment variable ("PINECONE_API_KEY"). -// - Headers: An optional map of additional HTTP headers to include in each API request to the control plane. -// - Host: (Optional) The host URL of the Pinecone control plane API. If not provided, -// the default value is "https://api.pinecone.io". -// - RestClient: An optional HTTP client to use for communication with the control plane API. +// - Headers: (Optional) An optional map of HTTP headers to include in each API request. +// - Host: (Optional) The host URL of the Pinecone API. If not provided, the default value is "https://api.pinecone.io". +// - RestClient: An optional HTTP client to use for communication with the Pinecone API. // - SourceTag: An optional string used to help Pinecone attribute API activity. // // See Client for code example. @@ -94,15 +96,15 @@ type NewClientParams struct { } // NewClientBaseParams holds the parameters for creating a new Client instance while passing custom authentication -// headers. +// headers. If there is no API key or authentication provided through Headers, API calls will fail. // // Fields: -// - Headers: An optional map of additional HTTP headers to include in each API request to the control plane. +// - Headers: (Optional) A map of HTTP headers to include in each API request. // "Authorization" and "X-Project-Id" headers are required if authenticating using a JWT. -// - Host: (Optional) The host URL of the Pinecone control plane API. If not provided, +// - Host: (Optional) The host URL of the Pinecone API. If not provided, // the default value is "https://api.pinecone.io". -// - RestClient: An optional *http.Client object to use for communication with the control plane API. -// - SourceTag: An optional string used to help Pinecone attribute API activity. +// - RestClient: (Optional) An *http.Client object to use for communication with the Pinecone API. +// - SourceTag: (Optional) A string used to help Pinecone attribute API activity. // // See Client for code example. type NewClientBaseParams struct { @@ -117,8 +119,8 @@ type NewClientBaseParams struct { // Fields: // - Host: (Required) The host URL of the Pinecone index. To find your host url use the DescribeIndex or ListIndexes methods. // Alternatively, the host is displayed in the Pinecone web console. -// - Namespace: Optional index namespace to use for operations. If not provided, the default namespace of "" will be used. -// - AdditionalMetadata: Optional additional metadata to be sent with each RPC request. +// - Namespace: (Optional) The index namespace to use for operations. If not provided, the default namespace of "" will be used. +// - AdditionalMetadata: (Optional) Metadata to be sent with each RPC request. // // See Client.Index for code example. type NewIndexConnParams struct { @@ -128,13 +130,13 @@ type NewIndexConnParams struct { } // NewClient creates and initializes a new instance of Client. -// This function sets up the control plane client with the necessary configuration for authentication and communication. +// This function sets up the Pinecone client with the necessary configuration for authentication and communication. // // Parameters: // - in: A NewClientParams object. See NewClientParams for more information. // // Note: It is important to handle the error returned by this function to ensure that the -// control plane client has been created successfully before attempting to make API calls. +// Pinecone client has been created successfully before attempting to make API calls. // // Returns a pointer to an initialized Client instance or an error. // @@ -178,12 +180,12 @@ func NewClient(in NewClientParams) (*Client, error) { // NewClientBase creates and initializes a new instance of Client with custom authentication headers. // // Parameters: -// - in: A NewClientBaseParams object that includes the necessary configuration for the control plane client. See +// - in: A NewClientBaseParams object that includes the necessary configuration for the Pinecone client. See // NewClientBaseParams for more information. // // Notes: // - It is important to handle the error returned by this function to ensure that the -// control plane client has been created successfully before attempting to make API calls. +// Pinecone client has been created successfully before attempting to make API calls. // - A Pinecone API key is not required when using NewClientBase. // // Returns a pointer to an initialized Client instance or an error. @@ -223,7 +225,7 @@ func NewClientBase(in NewClientBaseParams) (*Client, error) { return nil, err } - c := Client{restClient: client, sourceTag: in.SourceTag, headers: in.Headers} + c := Client{Inference: &InferenceService{client: client}, restClient: client, sourceTag: in.SourceTag, headers: in.Headers} return &c, nil } @@ -569,7 +571,7 @@ func (c *Client) CreatePodIndex(ctx context.Context, in *CreatePodIndexRequest) // - Cloud: (Required) The public [cloud provider] where you would like your Index hosted. // For serverless Indexes, you define only the cloud and region where the Index should be hosted. // - Region: (Required) The [region] where you would like your Index to be created. -// - DeletionProtection: (Optional) determines whether [deletion protection] is "enabled" or "disabled" for the index. +// - DeletionProtection: (Optional) Determines whether [deletion protection] is "enabled" or "disabled" for the index. // When "enabled", the index cannot be deleted. Defaults to "disabled". // // To create a new Serverless Index, use the CreateServerlessIndex method on the Client object. @@ -1199,6 +1201,124 @@ func (c *Client) DeleteCollection(ctx context.Context, collectionName string) er return nil } +// EmbedRequest holds the parameters for generating embeddings for a list of input strings. +// +// Fields: +// - Model: (Required) The model to use for generating embeddings. +// - TextInputs: (Required) A list of strings to generate embeddings for. +// - Parameters: (Optional) EmbedParameters object that contains additional parameters to use when generating embeddings. +type EmbedRequest struct { + Model string + TextInputs []string + Parameters EmbedParameters +} + +// EmbedParameters contains model-specific parameters that can be used for generating embeddings. +// +// Fields: +// - InputType: (Optional) A common property used to distinguish between different types of data. For example, "passage", or "query". +// - Truncate: (Optional) How to handle inputs longer than those supported by the model. if "NONE", when the input exceeds +// the maximum input token length, an error will be returned. +type EmbedParameters struct { + InputType string + Truncate string +} + +// InferenceService is a struct which exposes methods for interacting with the Pinecone Inference API. InferenceService +// can be accessed via the Client object through the Client.Inference namespace. +// +// [Pinecone Inference API]: https://docs.pinecone.io/guides/inference/understanding-inference#embedding-models +type InferenceService struct { + client *control.Client +} + +// Embed generates embeddings for a list of inputs using the specified model and (optional) parameters. +// +// Parameters: +// - ctx: A context.Context object controls the request's lifetime, allowing for the request +// to be canceled or to timeout according to the context's deadline. +// - in: A pointer to an EmbedRequest object that contains the model t4o use for embedding generation, the +// list of input strings to generate embeddings for, and any additional parameters to use for generation. +// +// Returns a pointer to an EmbeddingsList object or an error. +// +// Example: +// +// ctx := context.Background() +// +// clientParams := pinecone.NewClientParams{ +// ApiKey: "YOUR_API_KEY", +// SourceTag: "your_source_identifier", // optional +// } +// +// pc, err := pinecone.NewClient(clientParams) +// +// if err != nil { +// log.Fatalf("Failed to create Client: %v", err) +// } else { +// fmt.Println("Successfully created a new Client object!") +// } +// +// in := &pinecone.EmbedRequest{ +// Model: "multilingual-e5-large", +// TextInputs: []string{"Who created the first computer?"}, +// Parameters: pinecone.EmbedParameters{ +// InputType: "passage", +// Truncate: "END", +// }, +// } +// +// res, err := pc.Inference.Embed(ctx, in) +// if err != nil { +// log.Fatalf("Failed to embed: %v", err) +// } else { +// fmt.Printf("Successfull generated embeddings: %+v", res) +// } +func (i *InferenceService) Embed(ctx context.Context, in *EmbedRequest) (*control.EmbeddingsList, error) { + + if len(in.TextInputs) == 0 { + return nil, fmt.Errorf("TextInputs must contain at least one value") + } + + // Convert text inputs to the expected type + convertedInputs := make([]struct { + Text *string `json:"text,omitempty"` + }, len(in.TextInputs)) + for i, input := range in.TextInputs { + convertedInputs[i] = struct { + Text *string `json:"text,omitempty"` + }{Text: &input} + } + + req := control.EmbedRequest{ + Model: in.Model, + Inputs: convertedInputs, + } + + // convert embedding parameters to expected type + if in.Parameters.InputType != "" || in.Parameters.Truncate != "" { + req.Parameters = &struct { + InputType *string `json:"input_type,omitempty"` + Truncate *string `json:"truncate,omitempty"` + }{ + InputType: pointerOrNil(in.Parameters.InputType), + Truncate: pointerOrNil(in.Parameters.Truncate), + } + } + + res, err := i.client.Embed(ctx, req) + if err != nil { + return nil, err + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return nil, handleErrorResponseBody(res, "failed to embed: ") + } + + return decodeEmbeddingsList(res.Body) +} + func (c *Client) extractAuthHeader() map[string]string { possibleAuthKeys := []string{ "api-key", @@ -1269,6 +1389,16 @@ func decodeIndex(resBody io.ReadCloser) (*Index, error) { return toIndex(&idx), nil } +func decodeEmbeddingsList(resBody io.ReadCloser) (*control.EmbeddingsList, error) { + var embeddingsList control.EmbeddingsList + err := json.NewDecoder(resBody).Decode(&embeddingsList) + if err != nil { + return nil, fmt.Errorf("failed to decode embeddings response: %w", err) + } + + return &embeddingsList, nil +} + func toCollection(cm *control.CollectionModel) *Collection { if cm == nil { return nil diff --git a/pinecone/client_test.go b/pinecone/client_test.go index 4f286fb..de27d83 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -10,6 +10,7 @@ import ( "reflect" "strings" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/pinecone-io/go-pinecone/internal/gen" @@ -40,7 +41,7 @@ func (ts *IntegrationTests) TestCreatePodIndex() { idx, err := ts.client.CreatePodIndex(context.Background(), &CreatePodIndexRequest{ Name: name, - Dimension: 10, + Dimension: 2, Metric: Cosine, Environment: "us-east1-gcp", PodType: "p1.x1", @@ -180,7 +181,7 @@ func (ts *IntegrationTests) TestConfigureIndexIllegalScaleDown() { _, err := ts.client.CreatePodIndex(context.Background(), &CreatePodIndexRequest{ Name: name, - Dimension: 10, + Dimension: 2, Metric: Cosine, Environment: "us-east1-gcp", PodType: "p1.x2", @@ -196,14 +197,9 @@ func (ts *IntegrationTests) TestConfigureIndexIllegalScaleDown() { func (ts *IntegrationTests) TestConfigureIndexScaleUpNoPods() { name := uuid.New().String() - defer func(ts *IntegrationTests, name string) { - err := ts.deleteIndex(name) - require.NoError(ts.T(), err) - }(ts, name) - _, err := ts.client.CreatePodIndex(context.Background(), &CreatePodIndexRequest{ Name: name, - Dimension: 10, + Dimension: 2, Metric: Cosine, Environment: "us-east1-gcp", PodType: "p1.x2", @@ -215,22 +211,22 @@ func (ts *IntegrationTests) TestConfigureIndexScaleUpNoPods() { _, err = ts.client.ConfigureIndex(context.Background(), name, ConfigureIndexParams{Replicas: 2}) require.NoError(ts.T(), err) - // Before moving on to another test, wait for the index to be done upgrading - _, err = WaitUntilIndexReady(ts, context.Background()) + // give index a bit of time to start upgrading before we poll + time.Sleep(500 * time.Millisecond) + + isReady, _ := WaitUntilIndexReady(ts, context.Background()) + require.True(ts.T(), isReady, "Expected index to be ready") + + err = ts.client.DeleteIndex(context.Background(), name) require.NoError(ts.T(), err) } func (ts *IntegrationTests) TestConfigureIndexScaleUpNoReplicas() { name := uuid.New().String() - defer func(ts *IntegrationTests, name string) { - err := ts.deleteIndex(name) - require.NoError(ts.T(), err) - }(ts, name) - _, err := ts.client.CreatePodIndex(context.Background(), &CreatePodIndexRequest{ Name: name, - Dimension: 10, + Dimension: 2, Metric: Cosine, Environment: "us-east1-gcp", PodType: "p1.x2", @@ -242,8 +238,13 @@ func (ts *IntegrationTests) TestConfigureIndexScaleUpNoReplicas() { _, err = ts.client.ConfigureIndex(context.Background(), name, ConfigureIndexParams{PodType: "p1.x4"}) require.NoError(ts.T(), err) - // Before moving on to another test, wait for the index to be done upgrading - _, err = WaitUntilIndexReady(ts, context.Background()) + // give index a bit of time to start upgrading before we poll + time.Sleep(500 * time.Millisecond) + + isReady, _ := WaitUntilIndexReady(ts, context.Background()) + require.True(ts.T(), isReady, "Expected index to be ready") + + err = ts.client.DeleteIndex(context.Background(), name) require.NoError(ts.T(), err) } @@ -262,7 +263,7 @@ func (ts *IntegrationTests) TestConfigureIndexHitPodLimit() { _, err := ts.client.CreatePodIndex(context.Background(), &CreatePodIndexRequest{ Name: name, - Dimension: 10, + Dimension: 2, Metric: Cosine, Environment: "us-east1-gcp", PodType: "p1.x2", @@ -275,6 +276,43 @@ func (ts *IntegrationTests) TestConfigureIndexHitPodLimit() { require.ErrorContainsf(ts.T(), err, "You've reached the max pods allowed", err.Error()) } +func (ts *IntegrationTests) TestGenerateEmbeddings() { + ctx := context.Background() + embeddingModel := "multilingual-e5-large" + embeddings, err := ts.client.Inference.Embed(ctx, &EmbedRequest{ + Model: embeddingModel, + TextInputs: []string{ + "The quick brown fox jumps over the lazy dog", + "Lorem ipsum", + }, + Parameters: EmbedParameters{ + InputType: "query", + Truncate: "END", + }, + }) + + require.NoError(ts.T(), err) + require.NotNil(ts.T(), embeddings, "Expected embedding to be non-nil") + require.Equal(ts.T(), embeddingModel, *embeddings.Model, "Expected model to be '%s', but got '%s'", embeddingModel, embeddings.Model) + require.Equal(ts.T(), 2, len(*embeddings.Data), "Expected 2 embeddings") + require.Equal(ts.T(), 1024, len(*(*embeddings.Data)[0].Values), "Expected embeddings to have length 1024") +} + +func (ts *IntegrationTests) TestGenerateEmbeddingsInvalidInputs() { + ctx := context.Background() + embeddingModel := "multilingual-e5-large" + _, err := ts.client.Inference.Embed(ctx, &EmbedRequest{ + Model: embeddingModel, + Parameters: EmbedParameters{ + InputType: "query", + Truncate: "END", + }, + }) + + require.Error(ts.T(), err) + require.Contains(ts.T(), err.Error(), "TextInputs must contain at least one value") +} + // Unit tests: func TestExtractAuthHeaderUnit(t *testing.T) { globalApiKey := os.Getenv("PINECONE_API_KEY") @@ -1162,6 +1200,9 @@ func TestBuildClientBaseOptionsUnit(t *testing.T) { // Helper functions: func (ts *IntegrationTests) deleteIndex(name string) error { + _, err := WaitUntilIndexReady(ts, context.Background()) + require.NoError(ts.T(), err) + return ts.client.DeleteIndex(context.Background(), name) } diff --git a/pinecone/test_suite.go b/pinecone/test_suite.go index 242b599..d1ecb11 100644 --- a/pinecone/test_suite.go +++ b/pinecone/test_suite.go @@ -126,27 +126,28 @@ func createCollection(ts *IntegrationTests, ctx context.Context) { } func WaitUntilIndexReady(ts *IntegrationTests, ctx context.Context) (bool, error) { - maxRetries := 24 + start := time.Now() delay := 5 * time.Second - totalSeconds := 0 + maxWaitTimeSeconds := 280 * time.Second - for i := 0; i < maxRetries; i++ { + for { index, err := ts.client.DescribeIndex(ctx, ts.idxName) - if err != nil { - fmt.Printf("Error describing index: %v\n", err) + require.NoError(ts.T(), err) + + if index.Status.Ready && index.Status.State == Ready { + fmt.Printf("Index \"%s\" is ready after %f seconds\n", ts.idxName, time.Since(start).Seconds()) + return true, err } - if index.Status.State == Ready && index.Status.Ready { - fmt.Printf("Index \"%s\" is ready!\n", ts.idxName) - return true, nil - } else { - fmt.Printf("Index \"%s\" not ready yet, retrying... (%d/%d)\n", ts.idxName, i, maxRetries) - time.Sleep(delay) - totalSeconds += int(delay.Seconds()) + + totalSeconds := time.Since(start) + + if totalSeconds >= maxWaitTimeSeconds { + return false, fmt.Errorf("Index \"%s\" not ready after %f seconds", ts.idxName, totalSeconds.Seconds()) } - } - fmt.Printf("Index \"%s\" not ready after %d seconds\n", ts.idxName, totalSeconds) - return false, nil + fmt.Printf("Index \"%s\" not ready yet, retrying... (%f/%f)\n", ts.idxName, totalSeconds.Seconds(), maxWaitTimeSeconds.Seconds()) + time.Sleep(delay) + } } func createVectorsForUpsert() []*Vector {