Skip to content

Commit

Permalink
All tests working for index_connection now
Browse files Browse the repository at this point in the history
  • Loading branch information
aulorbe committed Jul 22, 2024
1 parent 8016f9d commit 17e7f35
Showing 1 changed file with 51 additions and 32 deletions.
83 changes: 51 additions & 32 deletions pinecone/index_connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func (ts *IndexConnectionTestsIntegration) SetupSuite() {

maxRetries := 12
delay := 12 * time.Second
fmt.Printf("Attempting to populate host \"%s\" with vectors...\n", ts.host)
fmt.Printf("Attempting to upsert vectors into host \"%s\"...\n", ts.host)
for i := 0; i < maxRetries; i++ {
ready, err := getStatus(ts, ctx)
if err != nil {
Expand All @@ -160,17 +160,17 @@ func (ts *IndexConnectionTestsIntegration) SetupSuite() {
require.NoError(ts.T(), err)
ts.idxConnSourceTag = idxConnSourceTag

fmt.Printf("\n %s Setup suite completed successfully\n", ts.indexType)
fmt.Printf("\n %s set up suite completed successfully\n", ts.indexType)
}

func (ts *IndexConnectionTestsIntegration) TearDownSuite() {
// TODO: move index deletion to here
// TODO: move index deletion to here to avoid wasting resources
err := ts.idxConn.Close()
require.NoError(ts.T(), err)

err = ts.idxConnSourceTag.Close()
require.NoError(ts.T(), err)
fmt.Printf("\n %s Setup suite torn down successfully\n", ts.indexType)
fmt.Printf("\n %s setup suite torn down successfully\n", ts.indexType)
}

func (ts *IndexConnectionTestsIntegration) TestNewIndexConnection() {
Expand Down Expand Up @@ -396,62 +396,80 @@ func (ts *IndexConnectionTestsIntegration) TestMetadataAppliedToRequests() {

func (ts *IndexConnectionTestsIntegration) TestUpdateVectorValues() {
ctx := context.Background()
dims := int(ts.dimension)

podsConn := ts.idxConn
fmt.Printf("\nPodsConn... %+v", podsConn)
fmt.Printf("\nidxConn... %+v", ts.idxConn)

idxStats, _ := ts.idxConn.DescribeIndexStats(ctx)
fmt.Printf("\nIndex stats: %+v", idxStats)
fmt.Printf("\nHost: %s", ts.host)

err := podsConn.UpdateVector(ctx, &UpdateVectorRequest{
expectedVals := []float32{7.2, 7.2, 7.2, 7.2, 7.2}
err := ts.idxConn.UpdateVector(ctx, &UpdateVectorRequest{
Id: ts.vectorIds[0],
Values: generateFloat32Array(dims),
Values: expectedVals,
})
assert.NoError(ts.T(), err)

time.Sleep(5 * time.Second)

vector, err := ts.idxConn.FetchVectors(ctx, []string{ts.vectorIds[0]})
if err != nil {
ts.FailNow(fmt.Sprintf("Failed to fetch vector: %v", err))
}
actualVals := vector.Vectors[ts.vectorIds[0]].Values

assert.ElementsMatch(ts.T(), expectedVals, actualVals, "Values do not match")
}

func (ts *IndexConnectionTestsIntegration) TestUpdateVectorMetadata() {
ctx := context.Background()

metadataMap := map[string]interface{}{
"genre": "classical",
expectedMetadata := map[string]interface{}{
"genre": "death-metal",
}

metadataForUpdate, err := structpb.NewStruct(metadataMap)
expectedMetadataMap, err := structpb.NewStruct(expectedMetadata)

err = ts.idxConn.UpdateVector(ctx, &UpdateVectorRequest{
Id: ts.vectorIds[0],
Metadata: metadataForUpdate,
Metadata: expectedMetadataMap,
})
assert.NoError(ts.T(), err)

time.Sleep(5 * time.Second)

vector, err := ts.idxConn.FetchVectors(ctx, []string{ts.vectorIds[0]})
if err != nil {
ts.FailNow(fmt.Sprintf("Failed to fetch vector: %v", err))
}

expectedGenre := expectedMetadataMap.Fields["genre"].GetStringValue()
actualGenre := vector.Vectors[ts.vectorIds[0]].Metadata.Fields["genre"].GetStringValue()

assert.Equal(ts.T(), expectedGenre, actualGenre, "Metadata does not match")
}

func (ts *IndexConnectionTestsIntegration) TestUpdateVectorSparseValues() {
ctx := context.Background()

dims := int(ts.dimension)
generatedSparseIndices := generateUint32Array(dims)
generatedSparseValues := generateFloat32Array(dims)
indices := generateUint32Array(dims)
vals := generateFloat32Array(dims)
expectedSparseValues := SparseValues{
Indices: indices,
Values: vals,
}

err := ts.idxConn.UpdateVector(ctx, &UpdateVectorRequest{
Id: ts.vectorIds[0],
SparseValues: &SparseValues{
Indices: generatedSparseIndices,
Values: generatedSparseValues,
},
Id: ts.vectorIds[0],
SparseValues: &expectedSparseValues,
})
assert.NoError(ts.T(), err)

fmt.Printf("Vector ID is: %v\n", ts.vectorIds[0])
time.Sleep(5 * time.Second)

vector, err := ts.idxConn.FetchVectors(ctx, []string{ts.vectorIds[0]})
fmt.Printf("Ignore me %v", &vector)
//fmt.Printf("Generated sparse values: %v\n", generatedSparseValues)
//assert.Equal(ts.T(), vector.Vectors[ts.vectorIds[0]].SparseValues.Values, generatedValues)
if err != nil {
ts.FailNow(fmt.Sprintf("Failed to fetch vector: %v", err))
}
actualSparseValues := vector.Vectors[ts.vectorIds[0]].SparseValues.Values

assert.ElementsMatch(ts.T(), expectedSparseValues.Values, actualSparseValues, "Sparse values do not match")
}

// TODO: necessary?
func generateFloat32Array(n int) []float32 {
array := make([]float32, n)
for i := 0; i < n; i++ {
Expand All @@ -460,6 +478,7 @@ func generateFloat32Array(n int) []float32 {
return array
}

// TODO: necessary?
func generateUint32Array(n int) []uint32 {
array := make([]uint32, n)
for i := 0; i < n; i++ {
Expand Down

0 comments on commit 17e7f35

Please sign in to comment.