Skip to content

Commit

Permalink
Add client/index tests when including a source_tag
Browse files Browse the repository at this point in the history
  • Loading branch information
ssmith-pc committed Mar 27, 2024
1 parent b0ceea9 commit a8ce2fd
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 5 deletions.
27 changes: 27 additions & 0 deletions pinecone/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
type ClientTests struct {
suite.Suite
client Client
clientSourceTag Client
sourceTag string
podIndex string
serverlessIndex string
}
Expand All @@ -36,6 +38,13 @@ func (ts *ClientTests) SetupSuite() {
}
ts.client = *client

ts.sourceTag = "test_source_tag"
clientSourceTag, err := NewClient(NewClientParams{ApiKey: apiKey, SourceTag: ts.sourceTag})
if err != nil {
ts.FailNow(err.Error())
}
ts.clientSourceTag = *clientSourceTag

// this will clean up the project deleting all indexes and collections that are
// named a UUID. Generally not needed as all tests are cleaning up after themselves
// Left here as a convenience during active development.
Expand All @@ -49,6 +58,12 @@ func (ts *ClientTests) TestListIndexes() {
require.Greater(ts.T(), len(indexes), 0, "Expected at least one index to exist")
}

func (ts *ClientTests) TestListIndexesSourceTag() {
indexes, err := ts.clientSourceTag.ListIndexes(context.Background())
require.NoError(ts.T(), err)
require.Greater(ts.T(), len(indexes), 0, "Expected at least one index to exist")
}

func (ts *ClientTests) TestCreatePodIndex() {
name := uuid.New().String()

Expand Down Expand Up @@ -93,12 +108,24 @@ func (ts *ClientTests) TestDescribeServerlessIndex() {
require.Equal(ts.T(), ts.serverlessIndex, index.Name, "Index name does not match")
}

func (ts *ClientTests) TestDescribeServerlessIndexSourceTag() {
index, err := ts.clientSourceTag.DescribeIndex(context.Background(), ts.serverlessIndex)
require.NoError(ts.T(), err)
require.Equal(ts.T(), ts.serverlessIndex, index.Name, "Index name does not match")
}

func (ts *ClientTests) TestDescribePodIndex() {
index, err := ts.client.DescribeIndex(context.Background(), ts.podIndex)
require.NoError(ts.T(), err)
require.Equal(ts.T(), ts.podIndex, index.Name, "Index name does not match")
}

func (ts *ClientTests) TestDescribePodIndexSourceTag() {
index, err := ts.clientSourceTag.DescribeIndex(context.Background(), ts.podIndex)
require.NoError(ts.T(), err)
require.Equal(ts.T(), ts.podIndex, index.Name, "Index name does not match")
}

func (ts *ClientTests) TestListCollections() {
ctx := context.Background()

Expand Down
78 changes: 73 additions & 5 deletions pinecone/index_connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ import (

type IndexConnectionTests struct {
suite.Suite
host string
dimension int32
apiKey string
idxConn *IndexConnection
vectorIds []string
host string
dimension int32
apiKey string
idxConn *IndexConnection
sourceTag string
idxConnSourceTag *IndexConnection
vectorIds []string
}

// Runs the test suite with `go test`
Expand Down Expand Up @@ -63,6 +65,11 @@ func (ts *IndexConnectionTests) SetupSuite() {
assert.NoError(ts.T(), err)
ts.idxConn = idxConn

ts.sourceTag = "test_source_tag"
idxConnSourceTag, err := newIndexConnection(ts.apiKey, ts.host, namespace.String(), ts.sourceTag)
assert.NoError(ts.T(), err)
ts.idxConnSourceTag = idxConnSourceTag

ts.loadData()
}

Expand All @@ -80,6 +87,13 @@ func (ts *IndexConnectionTests) TestFetchVectors() {
assert.NotNil(ts.T(), res)
}

func (ts *IndexConnectionTests) TestFetchVectorsSourceTag() {
ctx := context.Background()
res, err := ts.idxConnSourceTag.FetchVectors(&ctx, ts.vectorIds)
assert.NoError(ts.T(), err)
assert.NotNil(ts.T(), res)
}

func (ts *IndexConnectionTests) TestQueryByVector() {
vec := make([]float32, ts.dimension)
for i := range vec {
Expand All @@ -97,6 +111,23 @@ func (ts *IndexConnectionTests) TestQueryByVector() {
assert.NotNil(ts.T(), res)
}

func (ts *IndexConnectionTests) TestQueryByVectorSourceTag() {
vec := make([]float32, ts.dimension)
for i := range vec {
vec[i] = 0.01
}

req := &QueryByVectorValuesRequest{
Vector: vec,
TopK: 5,
}

ctx := context.Background()
res, err := ts.idxConnSourceTag.QueryByVectorValues(&ctx, req)
assert.NoError(ts.T(), err)
assert.NotNil(ts.T(), res)
}

func (ts *IndexConnectionTests) TestQueryById() {
req := &QueryByVectorIdRequest{
VectorId: ts.vectorIds[0],
Expand All @@ -109,6 +140,18 @@ func (ts *IndexConnectionTests) TestQueryById() {
assert.NotNil(ts.T(), res)
}

func (ts *IndexConnectionTests) TestQueryByIdSourceTag() {
req := &QueryByVectorIdRequest{
VectorId: ts.vectorIds[0],
TopK: 5,
}

ctx := context.Background()
res, err := ts.idxConnSourceTag.QueryByVectorId(&ctx, req)
assert.NoError(ts.T(), err)
assert.NotNil(ts.T(), res)
}

func (ts *IndexConnectionTests) TestDeleteVectorsById() {
ctx := context.Background()
err := ts.idxConn.DeleteVectorsById(&ctx, ts.vectorIds)
Expand Down Expand Up @@ -182,6 +225,31 @@ func (ts *IndexConnectionTests) loadData() {
assert.NoError(ts.T(), err)
}

func (ts *IndexConnectionTests) loadDataSourceTag() {
vals := []float32{0.01, 0.02, 0.03, 0.04, 0.05}
vectors := make([]*Vector, len(vals))
ts.vectorIds = make([]string, len(vals))

for i, val := range vals {
vec := make([]float32, ts.dimension)
for i := range vec {
vec[i] = val
}

id := fmt.Sprintf("vec-%d", i+1)
ts.vectorIds[i] = id

vectors[i] = &Vector{
Id: id,
Values: vec,
}
}

ctx := context.Background()
_, err := ts.idxConnSourceTag.UpsertVectors(&ctx, vectors)
assert.NoError(ts.T(), err)
}

func (ts *IndexConnectionTests) truncateData() {
ctx := context.Background()
err := ts.idxConn.DeleteAllVectorsInNamespace(&ctx)
Expand Down

0 comments on commit a8ce2fd

Please sign in to comment.