Skip to content

Commit

Permalink
Implement SearchByPks
Browse files Browse the repository at this point in the history
Signed-off-by: unfode <[email protected]>
  • Loading branch information
unfode committed Oct 18, 2023
1 parent 647cbab commit 13fce3c
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 64 deletions.
13 changes: 11 additions & 2 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,17 @@ type Client interface {
// Upsert column-based data of collection, returns id column values
Upsert(ctx context.Context, collName string, partitionName string, columns ...entity.Column) (entity.Column, error)
// Search with bool expression
Search(ctx context.Context, collName string, partitions []string,
expr string, outputFields []string, vectors []entity.Vector, vectorField string, metricType entity.MetricType, topK int, sp entity.SearchParam, opts ...SearchQueryOptionFunc) ([]SearchResult, error)
Search(
ctx context.Context, collName string, partitions []string, expr string, outputFields []string,
vectors []entity.Vector, vectorField string, metricType entity.MetricType, topK int,
sp entity.SearchParam, opts ...SearchQueryOptionFunc,
) ([]SearchResult, error)
// SearchByPks searches using the vectors corresponding to the provided primary keys
SearchByPks(
ctx context.Context, collName string, partitions []string, expr string, outputFields []string,
primaryKeys entity.Column, vectorField string, metricType entity.MetricType, topK int,
sp entity.SearchParam, opts ...SearchQueryOptionFunc,
) ([]SearchResult, error)
// QueryByPks query record by specified primary key(s).
QueryByPks(ctx context.Context, collectionName string, partitionNames []string, ids entity.Column, outputFields []string, opts ...SearchQueryOptionFunc) (ResultSet, error)
// Query performs query records with boolean expression.
Expand Down
2 changes: 1 addition & 1 deletion client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func TestGrpcClientNil(t *testing.T) {
mt := m.Type // type of function
if m.Name == "Close" || m.Name == "Connect" || // skip connect & close
m.Name == "UsingDatabase" || // skip use database
m.Name == "Search" || // type alias MetricType treated as string
m.Name == "Search" || m.Name == "SearchByPks" || // type alias MetricType treated as string
m.Name == "CalcDistance" ||
m.Name == "ManualCompaction" || // time.Duration hard to detect in reflect
m.Name == "Insert" || m.Name == "Upsert" { // complex methods with ...
Expand Down
200 changes: 139 additions & 61 deletions client/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"strings"

"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
Expand All @@ -35,45 +36,149 @@ const (
)

// Search with bool expression
func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []string,
expr string, outputFields []string, vectors []entity.Vector, vectorField string, metricType entity.MetricType, topK int, sp entity.SearchParam, opts ...SearchQueryOptionFunc) ([]SearchResult, error) {
func (c *GrpcClient) Search(
ctx context.Context, collName string, partitions []string, expr string, outputFields []string, vectors []entity.Vector,
vectorField string, metricType entity.MetricType, topK int, sp entity.SearchParam, opts ...SearchQueryOptionFunc,
) ([]SearchResult, error) {
if c.Service == nil {
return []SearchResult{}, ErrClientNotReady
}
var schema *entity.Schema
collInfo, ok := MetaCache.getCollectionInfo(collName)

_, ok := MetaCache.getCollectionInfo(collName)
if !ok {
coll, err := c.DescribeCollection(ctx, collName)
_, err := c.DescribeCollection(ctx, collName)
if err != nil {
return nil, err
}
schema = coll.Schema
} else {
schema = collInfo.Schema
}

option, err := makeSearchQueryOption(collName, opts...)
if err != nil {
return nil, err
}
// 2. Request milvus Service
req, err := prepareSearchRequest(collName, partitions, expr, outputFields, vectors, vectorField, metricType, topK, sp, option)

params := sp.Params()
bs, err := json.Marshal(params)
if err != nil {
return nil, err
}

sr := make([]SearchResult, 0, len(vectors))
searchParams := prepareSearchParamsForSearchRequest(
vectorField, metricType, topK, bs, option,
)

req := &milvuspb.SearchRequest{
DbName: "",
CollectionName: collName,
PartitionNames: partitions,
Dsl: expr,
PlaceholderGroup: vector2PlaceholderGroupBytes(vectors),
DslType: commonpb.DslType_BoolExprV1,
OutputFields: outputFields,
SearchParams: searchParams,
GuaranteeTimestamp: option.GuaranteeTimestamp,
Nq: int64(len(vectors)),
SearchByPrimaryKeys: false,
}

resp, err := c.Service.Search(ctx, req)
if err != nil {
return nil, err
}
if err := handleRespStatus(resp.GetStatus()); err != nil {
return nil, err
}

return processSearchResponse(resp, outputFields), nil
}

func (c *GrpcClient) SearchByPks(
ctx context.Context, collName string, partitions []string, expr string, outputFields []string,
primaryKeys entity.Column, vectorField string, metricType entity.MetricType, topK int,
sp entity.SearchParam, opts ...SearchQueryOptionFunc,
) ([]SearchResult, error) {
if c.Service == nil {
return []SearchResult{}, ErrClientNotReady
}

if primaryKeys.Len() == 0 {
return nil, errors.New("expected at least one primary key, but got zero")
}
if primaryKeys.Type() != entity.FieldTypeInt64 && primaryKeys.Type() != entity.FieldTypeVarChar {
return nil, errors.New("only int64 and varchar column can be primary key for now")
}

_, ok := MetaCache.getCollectionInfo(collName)
if !ok {
_, err := c.DescribeCollection(ctx, collName)
if err != nil {
return nil, err
}
}

option, err := makeSearchQueryOption(collName, opts...)
if err != nil {
return nil, err
}

params := sp.Params()
bs, err := json.Marshal(params)
if err != nil {
return nil, err
}

searchParams := prepareSearchParamsForSearchRequest(
vectorField, metricType, topK, bs, option,
)

req := &milvuspb.SearchRequest{
DbName: "",
CollectionName: collName,
PartitionNames: partitions,
Dsl: expr,
PlaceholderGroup: primaryKeysToPlaceholderGroupBytes(primaryKeys),
DslType: commonpb.DslType_BoolExprV1,
OutputFields: outputFields,
SearchParams: searchParams,
GuaranteeTimestamp: option.GuaranteeTimestamp,
Nq: int64(primaryKeys.Len()),
SearchByPrimaryKeys: true,
}

resp, err := c.Service.Search(ctx, req)
if err != nil {
return nil, err
}
if err := handleRespStatus(resp.GetStatus()); err != nil {
return nil, err
}
// 3. parse result into result
results := resp.GetResults()

return processSearchResponse(resp, outputFields), nil
}

func prepareSearchParamsForSearchRequest(
vectorField string, metricType entity.MetricType, topK int, bs []byte, opt *SearchQueryOption,
) []*commonpb.KeyValuePair {
searchParams := entity.MapKvPairs(map[string]string{
"anns_field": vectorField,
"topk": fmt.Sprintf("%d", topK),
"params": string(bs),
"metric_type": string(metricType),
"round_decimal": "-1",
ignoreGrowingKey: strconv.FormatBool(opt.IgnoreGrowing),
offsetKey: fmt.Sprintf("%d", opt.Offset),
})

return searchParams
}

func processSearchResponse(response *milvuspb.SearchResults, outputFields []string) []SearchResult {
results := response.GetResults()

sr := make([]SearchResult, 0, results.GetNumQueries())
offset := 0
fieldDataList := results.GetFieldsData()

for i := 0; i < int(results.GetNumQueries()); i++ {
rc := int(results.GetTopks()[i]) // result entry count for current query
entry := SearchResult{
Expand All @@ -85,14 +190,15 @@ func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []s
offset += rc
continue
}
entry.Fields, entry.Err = c.parseSearchResult(schema, outputFields, fieldDataList, i, offset, offset+rc)
entry.Fields, entry.Err = parseSearchResult(outputFields, fieldDataList, offset, offset+rc)
sr = append(sr, entry)
offset += rc
}
return sr, nil

return sr
}

func (c *GrpcClient) parseSearchResult(_ *entity.Schema, outputFields []string, fieldDataList []*schemapb.FieldData, _, from, to int) ([]entity.Column, error) {
func parseSearchResult(outputFields []string, fieldDataList []*schemapb.FieldData, from, to int) ([]entity.Column, error) {
// duplicated name will have only one column now
outputSet := make(map[string]struct{})
for _, output := range outputFields {
Expand Down Expand Up @@ -208,16 +314,12 @@ func (c *GrpcClient) Query(ctx context.Context, collectionName string, partition
return nil, ErrClientNotReady
}

var sch *entity.Schema
collInfo, ok := MetaCache.getCollectionInfo(collectionName)
_, ok := MetaCache.getCollectionInfo(collectionName)
if !ok {
coll, err := c.DescribeCollection(ctx, collectionName)
_, err := c.DescribeCollection(ctx, collectionName)
if err != nil {
return nil, err
}
sch = coll.Schema
} else {
sch = collInfo.Schema
}

option, err := makeSearchQueryOption(collectionName, opts...)
Expand Down Expand Up @@ -254,7 +356,7 @@ func (c *GrpcClient) Query(ctx context.Context, collectionName string, partition

fieldsData := resp.GetFieldsData()

columns, err := c.parseSearchResult(sch, outputFields, fieldsData, 0, 0, -1) //entity.FieldDataColumn(fieldData, 0, -1)
columns, err := parseSearchResult(outputFields, fieldsData, 0, -1) //entity.FieldDataColumn(fieldData, 0, -1)
if err != nil {
return nil, err
}
Expand All @@ -271,47 +373,23 @@ func getPKField(schema *entity.Schema) *entity.Field {
return nil
}

func getVectorField(schema *entity.Schema) *entity.Field {
for _, f := range schema.Fields {
if f.DataType == entity.FieldTypeFloatVector || f.DataType == entity.FieldTypeBinaryVector {
return f
}
}
return nil
}
func primaryKeysToPlaceholderGroupBytes(primaryKeys entity.Column) []byte {

func prepareSearchRequest(collName string, partitions []string,
expr string, outputFields []string, vectors []entity.Vector, vectorField string,
metricType entity.MetricType, topK int, sp entity.SearchParam, opt *SearchQueryOption) (*milvuspb.SearchRequest, error) {
params := sp.Params()
params[forTuningKey] = opt.ForTuning
bs, err := json.Marshal(params)
if err != nil {
return nil, err
}
queryExpr := PKs2Expr("", primaryKeys)
queryExprBytes := []byte(queryExpr)

searchParams := entity.MapKvPairs(map[string]string{
"anns_field": vectorField,
"topk": fmt.Sprintf("%d", topK),
"params": string(bs),
"metric_type": string(metricType),
"round_decimal": "-1",
ignoreGrowingKey: strconv.FormatBool(opt.IgnoreGrowing),
offsetKey: fmt.Sprintf("%d", opt.Offset),
})
req := &milvuspb.SearchRequest{
DbName: "",
CollectionName: collName,
PartitionNames: partitions,
Dsl: expr,
PlaceholderGroup: vector2PlaceholderGroupBytes(vectors),
DslType: commonpb.DslType_BoolExprV1,
OutputFields: outputFields,
SearchParams: searchParams,
GuaranteeTimestamp: opt.GuaranteeTimestamp,
Nq: int64(len(vectors)),
placeholderGroup := &commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{
{
Tag: "$0",
Type: commonpb.PlaceholderType_None,
Values: [][]byte{queryExprBytes},
},
},
}
return req, nil

bs, _ := proto.Marshal(placeholderGroup)
return bs
}

// GetPersistentSegmentInfo get persistent segment info
Expand Down

0 comments on commit 13fce3c

Please sign in to comment.