Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

知识库检索 API更新 #720

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions go/appbuilder/knowledge_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,12 @@ func (t *KnowledgeBase) DescribeChunks(req DescribeChunksRequest) (DescribeChunk
}

func (t *KnowledgeBase) QueryKnowledgeBase(req QueryKnowledgeBaseRequest) (QueryKnowledgeBaseResponse, error) {
// 检查 RankScoreThreshold 是否为 nil,如果是,则设置默认值
if req.RankScoreThreshold == nil {
defaultThreshold := 0.4
req.RankScoreThreshold = &defaultThreshold
}

request := http.Request{}
header := t.sdkConfig.AuthHeaderV2()
serviceURL, err := t.sdkConfig.ServiceURLV2("/knowledgebases/query")
Expand Down
64 changes: 44 additions & 20 deletions go/appbuilder/knowledge_base_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@

package appbuilder

// QueryType names the retrieval strategy requested from a knowledge-base
// query.
type QueryType string

// Retrieval strategies that can be sent in the "type" field of a query.
const (
	// Fulltext selects full-text retrieval.
	Fulltext = QueryType("fulltext")
	// Semantic selects semantic retrieval.
	Semantic = QueryType("semantic")
	// Hybrid selects hybrid retrieval.
	Hybrid = QueryType("hybrid")
)

const (
ContentTypeRawText = "raw_text"
ContentTypeQA = "qa"
Expand Down Expand Up @@ -277,6 +285,18 @@ type ElasticSearchRetrieveConfig struct {
Top int `json:"top"`
}

// VectorDBRetrieveConfig is the JSON shape of a vector-database retrieval
// node (presumably one entry of a query pipeline; confirm against the API
// reference).
type VectorDBRetrieveConfig struct {
// Name is the caller-chosen name of this node.
Name string `json:"name"`
// Type is the node type tag (presumably "vector_db"; confirm against the API reference).
Type string `json:"type"`
// Threshold is the score threshold for this node. It is always serialized, even when zero.
Threshold float64 `json:"threshold"`
// Top is the number of results to recall. It is always serialized, even when zero.
Top int `json:"top"`
}

// SmallToBigConfig is the JSON shape of a small-to-big node (presumably one
// entry of a query pipeline; confirm against the API reference).
type SmallToBigConfig struct {
// Name is the caller-chosen name of this node.
Name string `json:"name"`
// Type is the node type tag (presumably "small_to_big"; confirm against the API reference).
Type string `json:"type"`
}

type RankingConfig struct {
Name string `json:"name"`
Type string `json:"type"`
Expand All @@ -291,13 +311,14 @@ type QueryPipelineConfig struct {
}

type QueryKnowledgeBaseRequest struct {
Query string `json:"query"`
KnowledgebaseIDs []string `json:"knowledgebase_ids"`
Type *string `json:"type,omitempty"`
Top int `json:"top,omitempty"`
Skip int `json:"skip,omitempty"`
MetadataFileters MetadataFilters `json:"metadata_fileters,omitempty"`
PipelineConfig QueryPipelineConfig `json:"pipeline_config,omitempty"`
Query string `json:"query"`
KnowledgebaseIDs []string `json:"knowledgebase_ids"`
Type *QueryType `json:"type,omitempty"`
Top int `json:"top,omitempty"`
Skip int `json:"skip,omitempty"`
RankScoreThreshold *float64 `json:"rank_score_threshold,omitempty"`
MetadataFileters MetadataFilters `json:"metadata_fileters,omitempty"`
PipelineConfig QueryPipelineConfig `json:"pipeline_config,omitempty"`
}

type RowLine struct {
Expand All @@ -314,19 +335,22 @@ type ChunkLocation struct {
}

// Chunk is the JSON representation of a knowledge-base chunk.
type Chunk struct {
	// ChunkID is the identifier of this chunk.
	ChunkID string `json:"chunk_id"`
	// KnowledgebaseID is the identifier of the associated knowledge base.
	KnowledgebaseID string `json:"knowledgebase_id"`
	// DocumentID is the identifier of the associated document.
	DocumentID string `json:"document_id"`
	// DocumentName is the name of the associated document.
	DocumentName string `json:"document_name"`
	// Meta carries free-form metadata decoded from the "meta" key.
	Meta map[string]any `json:"meta"`
	// Type is the chunk type as reported by the service.
	Type string `json:"type"`
	// Content is the chunk content.
	Content string `json:"content"`
	// CreateTime is the creation time as a string; the format is not
	// visible here.
	CreateTime string `json:"create_time"`
	// UpdateTime is the last update time as a string; the format is not
	// visible here.
	UpdateTime string `json:"update_time"`
	// RetrievalScore is decoded from the "retrieval_score" key.
	RetrievalScore float64 `json:"retrieval_score"`
	// RankScore is decoded from the "rank_score" key.
	RankScore float64 `json:"rank_score"`
	// Locations is decoded from the "locations" key.
	Locations ChunkLocation `json:"locations"`
	// Children holds nested chunks decoded from the "children" key.
	Children []Chunk `json:"children"`
	// NeighbourChunks holds chunks decoded from the "neighbour_chunks" key.
	NeighbourChunks []Chunk `json:"neighbour_chunks"`
	// OriginalChunkId is decoded from the "original_chunk_id" key.
	OriginalChunkId string `json:"original_chunk_id"`
	// OriginalChunkOffset is decoded from the "original_chunk_offset" key.
	OriginalChunkOffset int `json:"original_chunk_offset"`
}

type QueryKnowledgeBaseResponse struct {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,9 @@ public ChunksDescribeResponse describeChunks(String documentId, String marker, I

public QueryKnowledgeBaseResponse queryKnowledgeBase(QueryKnowledgeBaseRequest request)
throws IOException, AppBuilderServerException {
if (request.getRank_score_threshold() == null) {
request.setRank_score_threshold(0.4f);
}
String url = AppBuilderConfig.QUERY_KNOWLEDGEBASE_URL;

String jsonBody = JsonUtils.serialize(request);
Expand All @@ -734,12 +737,16 @@ public QueryKnowledgeBaseResponse queryKnowledgeBase(QueryKnowledgeBaseRequest r
return respBody;
}

public QueryKnowledgeBaseResponse queryKnowledgeBase(String query, String type, Integer top, Integer skip,
public QueryKnowledgeBaseResponse queryKnowledgeBase(String query, String type, Float rank_score_threshold, Integer top, Integer skip,
String[] knowledgebaseIDs, QueryKnowledgeBaseRequest.MetadataFilters filters,
QueryKnowledgeBaseRequest.QueryPipelineConfig pipelineConfig)
throws IOException, AppBuilderServerException {
if (rank_score_threshold == null) {
rank_score_threshold = 0.4f;
}

String url = AppBuilderConfig.QUERY_KNOWLEDGEBASE_URL;
QueryKnowledgeBaseRequest request = new QueryKnowledgeBaseRequest(query, type, top, skip, knowledgebaseIDs, filters, pipelineConfig);
QueryKnowledgeBaseRequest request = new QueryKnowledgeBaseRequest(query, type, rank_score_threshold,top, skip, knowledgebaseIDs, filters, pipelineConfig);
String jsonBody = JsonUtils.serialize(request);
ClassicHttpRequest postRequest = httpClient.createPostRequestV2(url,
new StringEntity(jsonBody, StandardCharsets.UTF_8));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@
public class QueryKnowledgeBaseRequest {
private String query;
private String type;
private Float rank_score_threshold;
private Integer top;
private Integer skip;
private String[] knowledgebase_ids;
private MetadataFilters metadata_filters;
private QueryPipelineConfig pipeline_config;

public QueryKnowledgeBaseRequest(String query, String type, Integer top, Integer skip,
public QueryKnowledgeBaseRequest(String query, String type, Float rank_score_threshold, Integer top, Integer skip,
String[] knowledgebase_ids, MetadataFilters metadata_filters,
QueryPipelineConfig pipeline_config) {
this.query = query;
this.type = type;
this.rank_score_threshold = rank_score_threshold;
this.top = top;
this.skip = skip;
this.knowledgebase_ids = knowledgebase_ids;
Expand All @@ -39,6 +41,14 @@ public void setType(String type) {
this.type = type;
}

/**
 * Returns the rank score threshold carried by this request.
 *
 * @return the threshold, or {@code null} if none has been set
 */
public Float getRank_score_threshold() {
return rank_score_threshold;
}

/**
 * Sets the rank score threshold carried by this request.
 *
 * @param rank_score_threshold the new threshold; stored as given, without validation
 */
public void setRank_score_threshold(Float rank_score_threshold) {
this.rank_score_threshold = rank_score_threshold;
}

/**
 * Returns the {@code top} value of this request.
 *
 * @return the stored value, which may be {@code null}
 */
public Integer getTop() {
return top;
}
Expand Down Expand Up @@ -217,6 +227,46 @@ public void setTop(Integer top) {
}
}

/**
 * Plain data holder for a vector-database retrieval node (presumably one
 * entry of a query pipeline; confirm against the API reference).
 * All fields are boxed, so each may be {@code null} when unset.
 */
public static class VectorDBRetrieveConfig {
// Caller-chosen name of this node.
private String name;
// Node type tag.
private String type;
// Score threshold for this node.
private Double threshold;
// Number of results to recall.
private Integer top;

/** Returns the node name. */
public String getName() {
return name;
}

/** Sets the node name. */
public void setName(String name) {
this.name = name;
}

/** Returns the node type tag. */
public String getType() {
return type;
}

/** Sets the node type tag. */
public void setType(String type) {
this.type = type;
}

/** Returns the score threshold, or {@code null} if unset. */
public Double getThreshold() {
return threshold;
}

/** Sets the score threshold; the value is stored without validation. */
public void setThreshold(Double threshold) {
this.threshold = threshold;
}

/** Returns the recall count, or {@code null} if unset. */
public Integer getTop() {
return top;
}

/** Sets the recall count; the value is stored without validation. */
public void setTop(Integer top) {
this.top = top;
}
}


public static class RankingConfig {
private String name;
private String type;
Expand Down Expand Up @@ -265,6 +315,28 @@ public void setTop(Integer top) {
}
}

/**
 * Plain data holder for a small-to-big node (presumably one entry of a
 * query pipeline; confirm against the API reference).
 */
public static class SmallToBigConfig {
// Caller-chosen name of this node.
private String name;
// Node type tag.
private String type;

/** Returns the node name. */
public String getName() {
return name;
}

/** Sets the node name. */
public void setName(String name) {
this.name = name;
}

/** Returns the node type tag. */
public String getType() {
return type;
}

/** Sets the node type tag. */
public void setType(String type) {
this.type = type;
}

}

public static class QueryPipelineConfig {
private String id;
private List<Object> pipeline;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ public static class Chunk {
private float rank_score;
private ChunkLocation locations;
private List<Chunk> children;
private List<Chunk> neighbour_chunks;
private String original_chunk_id;
private Integer original_chunk_offset;

/** Returns the identifier of this chunk. */
public String getChunk_id() { return chunk_id; }

Expand Down Expand Up @@ -96,6 +99,18 @@ public static class Chunk {
/** Returns the nested child chunks; may be {@code null} if never set. */
public List<Chunk> getChildren() { return children; }

/** Replaces the nested child chunks; the list is stored by reference, not copied. */
public void setChildren(List<Chunk> children) { this.children = children; }

/** Returns the neighbour chunks; may be {@code null} if never set. */
public List<Chunk> getNeighbour_chunks() { return neighbour_chunks; }

/** Replaces the neighbour chunks; the list is stored by reference, not copied. */
public void setNeighbour_chunks(List<Chunk> neighbour_chunks) { this.neighbour_chunks = neighbour_chunks; }

/** Returns the original chunk identifier; may be {@code null} if never set. */
public String getOriginal_chunk_id() { return original_chunk_id; }

/** Sets the original chunk identifier. */
public void setOriginal_chunk_id(String original_chunk_id) { this.original_chunk_id = original_chunk_id; }

/** Returns the original chunk offset; may be {@code null} if never set. */
public Integer getOriginal_chunk_offset() { return original_chunk_offset; }

/** Sets the original chunk offset. */
public void setOriginal_chunk_offset(Integer original_chunk_offset) { this.original_chunk_offset = original_chunk_offset; }
}

public static class ChunkLocation {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ public void testQueryKnowledgeBaseV2() throws IOException, AppBuilderServerExcep
Files.readAllBytes(Paths.get("src/test/java/com/baidubce/appbuilder/files/query_knowledgebase.json")));
QueryKnowledgeBaseRequest request = gson.fromJson(requestJson, QueryKnowledgeBaseRequest.class);
QueryKnowledgeBaseResponse response = knowledgebase.queryKnowledgeBase(request.getQuery(),
request.getType(), request.getTop(), request.getSkip(),
request.getType(), request.getRank_score_threshold(), request.getTop(), request.getSkip(),
request.getKnowledgebase_ids(), request.getMetadata_filters(), request.getPipeline_config());
assertNotNull(response.getChunks().get(0).getChunk_id());
}
Expand Down
29 changes: 25 additions & 4 deletions python/core/console/knowledge_base/data_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import annotations
from datetime import datetime
from pydantic import BaseModel, Field
from enum import Enum
from typing import Union, Optional, List


Expand Down Expand Up @@ -323,13 +324,28 @@ class PreRankingConfig(BaseModel):
None, description="得分归一化参数,不建议修改,默认50"
)

class QueryType(str, Enum):
    """Retrieval strategy accepted by a knowledge-base query.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    FULLTEXT = "fulltext"  # full-text retrieval
    SEMANTIC = "semantic"  # semantic retrieval
    HYBRID = "hybrid"  # hybrid retrieval

class ElasticSearchRetrieveConfig(BaseModel):
    """Elasticsearch full-text retrieval node of a query pipeline.

    Use this configuration when the hosting resource is a shared resource
    or a BES resource.
    """

    name: str = Field(..., description="配置名称")
    type: str = Field(None, description="elastic_search标志,该节点为es全文检索")
    threshold: float = Field(None, description="得分阈值,默认0.1")
    top: int = Field(None, description="召回数量,默认400")

class VectorDBRetrieveConfig(BaseModel):
    """Vector-database retrieval node of a query pipeline.

    ``threshold`` is constrained to [0, 1] and ``top`` to [0, 800] by the
    field validators below.
    """

    name: str = Field(..., description="该节点的自定义名称。")
    type: str = Field("vector_db", description="该节点的类型,默认为vector_db。")
    threshold: Optional[float] = Field(0.1, description="得分阈值。取值范围:[0, 1]", ge=0.0, le=1.0)
    top: Optional[int] = Field(400, description="召回数量。取值范围:[0, 800]", ge=0, le=800)
    pre_ranking: Optional[PreRankingConfig] = Field(None, description="粗排配置")

class SmallToBigConfig(BaseModel):
    """Small-to-big node of a query pipeline; ``type`` defaults to ``"small_to_big"``."""

    name: str = Field(..., description="配置名称")
    type: str = Field("small_to_big", description="small_to_big标志,该节点为small_to_big节点")


class RankingConfig(BaseModel):
name: str = Field(..., description="配置名称")
Expand All @@ -341,24 +357,29 @@ class RankingConfig(BaseModel):
model_name: str = Field(None, description="ranking模型名(当前仅一种,暂不生效)")
top: int = Field(None, description="取切片top进行排序,默认20,最大400")


class QueryPipelineConfig(BaseModel):
    """Retrieval pipeline selection for a knowledge-base query.

    Either reference an already configured pipeline through ``id`` or
    describe a new one through ``pipeline``.
    """

    id: str = Field(
        None, description="配置唯一标识,如果用这个id,则引用已经配置好的QueryPipeline"
    )
    # NOTE(review): the node models share the ``name``/``type`` fields, so which
    # Union member pydantic picks for a plain dict may depend on member order.
    # Confirm when passing dicts instead of model instances.
    pipeline: list[
        Union[
            ElasticSearchRetrieveConfig,
            RankingConfig,
            VectorDBRetrieveConfig,
            SmallToBigConfig,
        ]
    ] = Field(
        None, description="配置的Pipeline,如果没有用id,可以用这个对象指定一个新的配置"
    )


class QueryKnowledgeBaseRequest(BaseModel):
    """Request body of a knowledge-base query.

    ``query`` and ``knowledgebase_ids`` are required. ``rank_score_threshold``
    defaults to 0.4 and is validated to lie in [0, 1].
    """

    query: str = Field(..., description="检索query")
    type: Optional[QueryType] = Field(None, description="检索策略的枚举, fulltext:全文检索, semantic:语义检索, hybrid:混合检索")
    top: int = Field(None, description="返回结果数量")
    skip: int = Field(
        None,
        description="跳过多少条记录, 通过top和skip可以实现类似分页的效果,比如top 10 skip 0,取第一页的10个,top 10 skip 10,取第二页的10个",
    )
    # Only takes effect when pipeline_config contains a ranking node (per the
    # field description).
    rank_score_threshold: float = Field(
        0.4,
        description="重排序匹配分阈值,只有rank_score大于等于该分值的切片重排序时才会被筛选出来。当且仅当,pipeline_config中配置了ranking节点时,该过滤条件生效。取值范围: [0, 1]。",
        ge=0.0,
        le=1.0,
    )
    knowledgebase_ids: list[str] = Field(..., description="知识库ID列表")
    metadata_filters: MetadataFilters = Field(None, description="元数据过滤条件")
    pipeline_config: QueryPipelineConfig = Field(None, description="检索配置")
Expand Down
6 changes: 4 additions & 2 deletions python/core/console/knowledge_base/knowledge_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,10 +909,11 @@ def query_knowledge_base(
self,
query: str,
knowledgebase_ids: list[str],
type: str = None,
type: Optional[data_class.QueryType] = None,
metadata_filters: data_class.MetadataFilter = None,
pipeline_config: data_class.QueryPipelineConfig = None,
top: int = None,
rank_score_threshold: Optional[float] = 0.4,
top: int = 6,
skip: int = None,
) -> data_class.QueryKnowledgeBaseResponse:
"""
Expand All @@ -934,6 +935,7 @@ def query_knowledge_base(
type=type,
metadata_filters=metadata_filters,
pipeline_config=pipeline_config,
rank_score_threshold=rank_score_threshold,
top=top,
skip=skip,
)
Expand Down
Loading