Feature/add web search (#1662)
* add web search to rag agent
emrgnt-cmplxty authored Dec 5, 2024
1 parent 01659ea commit fb2a221
Showing 12 changed files with 253 additions and 28 deletions.
1 change: 1 addition & 0 deletions py/core/__init__.py
@@ -73,6 +73,7 @@
"Prompt",
# Search abstractions
"AggregateSearchResult",
"WebSearchResponse",
"GraphSearchResult",
"ChunkSearchSettings",
"GraphSearchSettings",
58 changes: 50 additions & 8 deletions py/core/agent/rag.py
@@ -9,6 +9,7 @@
AggregateSearchResult,
GraphSearchSettings,
SearchSettings,
WebSearchResponse,
)
from core.base.agent import AgentConfig, Tool
from core.base.providers import CompletionProvider
@@ -30,31 +31,72 @@ def _register_tools(self):
if not self.config.tool_names:
return
for tool_name in self.config.tool_names:
if tool_name == "search":
self._tools.append(self.search_tool())
if tool_name == "local_search":
self._tools.append(self.local_search())
elif tool_name == "web_search":
self._tools.append(self.web_search())
else:
raise ValueError(f"Unsupported tool name: {tool_name}")

def search_tool(self) -> Tool:
def web_search(self) -> Tool:
return Tool(
name="search",
description="Search for information using the R2R framework",
results_function=self.search,
name="web_search",
description="Search for information on the web.",
results_function=self._web_search,
llm_format_function=RAGAgentMixin.format_search_results_for_llm,
stream_function=RAGAgentMixin.format_search_results_for_stream,
parameters={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to search the local vector database with.",
"description": "The query to search Google with.",
},
},
"required": ["query"],
},
)

async def search(
async def _web_search(
self,
query: str,
search_settings: SearchSettings,
*args,
**kwargs,
    ) -> AggregateSearchResult:
from .serper import SerperClient

serper_client = SerperClient()
# TODO - make async!
# TODO - Move to search pipeline, make configurable.
raw_results = serper_client.get_raw(query)
web_response = WebSearchResponse.from_serper_results(raw_results)
return AggregateSearchResult(
chunk_search_results=None,
graph_search_results=None,
web_search_results=web_response.organic_results, # TODO - How do we feel about throwing away so much info?
)

def local_search(self) -> Tool:
return Tool(
name="local_search",
description="Search your local knowledgebase using the R2R AI system",
results_function=self._local_search,
llm_format_function=RAGAgentMixin.format_search_results_for_llm,
stream_function=RAGAgentMixin.format_search_results_for_stream,
parameters={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to search the local knowledgebase with.",
},
},
"required": ["query"],
},
)

async def _local_search(
self,
query: str,
search_settings: SearchSettings,
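
The inline TODOs flag that the Serper call is synchronous. A minimal sketch of one way to make it non-blocking, assuming Python 3.9+ for asyncio.to_thread and the same imports as rag.py above (illustrative, not part of this commit):

import asyncio

async def _web_search(
    self,
    query: str,
    search_settings: SearchSettings,
    *args,
    **kwargs,
) -> AggregateSearchResult:
    from .serper import SerperClient

    serper_client = SerperClient()
    # http.client blocks; run the request in a worker thread so the
    # event loop stays free while the search is in flight.
    raw_results = await asyncio.to_thread(serper_client.get_raw, query)
    web_response = WebSearchResponse.from_serper_results(raw_results)
    return AggregateSearchResult(
        chunk_search_results=None,
        graph_search_results=None,
        web_search_results=web_response.organic_results,
    )
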
104 changes: 104 additions & 0 deletions py/core/agent/serper.py
@@ -0,0 +1,104 @@
# TODO - relocate to a dedicated module
import http.client
import json
import os


# TODO - Move process json to dedicated data processing module
def process_json(json_object, indent=0):
"""
Recursively traverses the JSON object (dicts and lists) to create an unstructured text blob.
"""
text_blob = ""
if isinstance(json_object, dict):
for key, value in json_object.items():
padding = " " * indent
if isinstance(value, (dict, list)):
text_blob += (
f"{padding}{key}:\n{process_json(value, indent + 1)}"
)
else:
text_blob += f"{padding}{key}: {value}\n"
elif isinstance(json_object, list):
for index, item in enumerate(json_object):
padding = " " * indent
if isinstance(item, (dict, list)):
text_blob += f"{padding}Item {index + 1}:\n{process_json(item, indent + 1)}"
else:
text_blob += f"{padding}Item {index + 1}: {item}\n"
return text_blob
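
# For example (illustrative), process_json({"a": 1, "b": [2, 3]}) yields:
#   a: 1
#   b:
#    Item 1: 2
#    Item 2: 3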


# TODO - Introduce abstract "Integration" ABC.
class SerperClient:
def __init__(self, api_base: str = "google.serper.dev") -> None:
api_key = os.getenv("SERPER_API_KEY")
if not api_key:
raise ValueError(
"Please set the `SERPER_API_KEY` environment variable to use `SerperClient`."
)

self.api_base = api_base
self.headers = {
"X-API-KEY": api_key,
"Content-Type": "application/json",
}

@staticmethod
def _extract_results(result_data: dict) -> list:
formatted_results = []

for key, value in result_data.items():
# Skip searchParameters as it's not a result entry
if key == "searchParameters":
continue

# Handle 'answerBox' as a single item
if key == "answerBox":
value["type"] = key # Add the type key to the dictionary
formatted_results.append(value)
# Handle lists of results
elif isinstance(value, list):
for item in value:
item["type"] = key # Add the type key to the dictionary
formatted_results.append(item)
# Handle 'peopleAlsoAsk' and potentially other single item formats
elif isinstance(value, dict):
value["type"] = key # Add the type key to the dictionary
formatted_results.append(value)

return formatted_results
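
    # Illustrative shape (hypothetical values): a Serper payload such as
    #   {"searchParameters": {...},
    #    "organic": [{"title": "R2R", "link": "https://example.com"}],
    #    "answerBox": {"answer": "..."}}
    # flattens to
    #   [{"title": "R2R", "link": "https://example.com", "type": "organic"},
    #    {"answer": "...", "type": "answerBox"}]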

# TODO - Add explicit typing for the return value
def get_raw(self, query: str, limit: int = 10) -> list:
connection = http.client.HTTPSConnection(self.api_base)
        payload = json.dumps({"q": query, "num": limit})  # Serper's result-count parameter is "num"
connection.request("POST", "/search", payload, self.headers)
response = connection.getresponse()
data = response.read()
json_data = json.loads(data.decode("utf-8"))
return SerperClient._extract_results(json_data)

@staticmethod
def construct_context(results: list) -> str:
# Organize results by type
organized_results = {}
for result in results:
result_type = result.metadata.pop(
"type", "Unknown"
) # Pop the type and use as key
if result_type not in organized_results:
organized_results[result_type] = [result.metadata]
else:
organized_results[result_type].append(result.metadata)

context = ""
# Iterate over each result type
for result_type, items in organized_results.items():
context += f"# {result_type} Results:\n"
for index, item in enumerate(items, start=1):
# Process each item under the current type
context += f"Item {index}:\n"
context += process_json(item) + "\n"

return context
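
Taken together, a minimal usage sketch for the new client (illustrative only: the query is a placeholder, the import path mirrors the file location above, and SERPER_API_KEY must be set in the environment):

from core.agent.serper import SerperClient

client = SerperClient()
# get_raw returns flat dicts, each tagged with the Serper section it came from.
results = client.get_raw("retrieval augmented generation", limit=5)
for item in results:
    print(item["type"], item.get("title", ""))
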
1 change: 1 addition & 0 deletions py/core/base/__init__.py
@@ -46,6 +46,7 @@
"Prompt",
# Search abstractions
"AggregateSearchResult",
"WebSearchResponse",
"GraphSearchResult",
"GraphSearchSettings",
"ChunkSearchSettings",
2 changes: 2 additions & 0 deletions py/core/base/abstractions/__init__.py
@@ -63,6 +63,7 @@
KGRelationshipResult,
KGSearchResultType,
SearchSettings,
WebSearchResponse,
)
from shared.abstractions.user import Token, TokenData, User
from shared.abstractions.vector import (
@@ -120,6 +121,7 @@
# Prompt abstractions
"Prompt",
# Search abstractions
"WebSearchResponse",
"AggregateSearchResult",
"GraphSearchResult",
"KGSearchResultType",
2 changes: 1 addition & 1 deletion py/core/configs/full_local_llm.toml
@@ -1,6 +1,6 @@
[agent]
system_instruction_name = "rag_agent"
tool_names = ["search"]
tool_names = ["local_search"]

[agent.generation_config]
model = "ollama/llama3.1"
2 changes: 1 addition & 1 deletion py/core/configs/local_llm.toml
@@ -1,6 +1,6 @@
[agent]
system_instruction_name = "rag_agent"
tool_names = ["search"]
tool_names = ["local_search"]

[agent.generation_config]
model = "ollama/llama3.1"
7 changes: 1 addition & 6 deletions py/core/main/api/v3/retrieval_router.py
@@ -7,12 +7,7 @@
from fastapi import Body, Depends
from fastapi.responses import StreamingResponse

from core.base import (
GenerationConfig,
Message,
R2RException,
SearchSettings,
)
from core.base import GenerationConfig, Message, R2RException, SearchSettings
from core.base.api.models import (
WrappedAgentResponse,
WrappedCompletionResponse,
10 changes: 6 additions & 4 deletions py/core/pipes/retrieval/search_rag_pipe.py
@@ -7,7 +7,7 @@
AsyncState,
CompletionProvider,
DatabaseProvider,
KGSearchResultType
KGSearchResultType,
)
from core.base.abstractions import GenerationConfig, RAGCompletion

@@ -111,11 +111,13 @@ async def _collect_context(
# context += f"Results:\n"
if search_result.result_type == KGSearchResultType.ENTITY:
context += f"[{it}]: Entity Name - {search_result.content.name}\n\nDescription - {search_result.content.description}\n\n"
elif search_result.result_type == KGSearchResultType.RELATIONSHIP:
elif (
search_result.result_type
== KGSearchResultType.RELATIONSHIP
):
context += f"[{it}]: Relationship - {search_result.content.subject} - {search_result.content.predicate} - {search_result.content.object}\n\n"
else:
context += f"[{it}]: Community Name - {search_result.content.name}\n\nDescription - {search_result.content.summary}\n\n"

context += f"[{it}]: Community Name - {search_result.content.name}\n\nDescription - {search_result.content.summary}\n\n"

it += 1
total_results = (
3 changes: 2 additions & 1 deletion py/r2r.toml
@@ -4,7 +4,8 @@

[agent]
system_instruction_name = "rag_agent"
tool_names = ["search"]
# tool_names = ["local_search", "web_search"] # uncomment to enable web search
tool_names = ["local_search"]

[agent.generation_config]
model = "openai/gpt-4o"
61 changes: 58 additions & 3 deletions py/shared/abstractions/search.py
@@ -146,17 +146,67 @@ class Config:
}


class WebSearchResult(R2RSerializable):
title: str
link: str
snippet: str
position: int
type: str = "organic"
date: Optional[str] = None
sitelinks: Optional[list[dict]] = None


class RelatedSearchResult(R2RSerializable):
query: str
type: str = "related"


class PeopleAlsoAskResult(R2RSerializable):
question: str
snippet: str
link: str
title: str
type: str = "peopleAlsoAsk"


class WebSearchResponse(R2RSerializable):
organic_results: list[WebSearchResult] = []
related_searches: list[RelatedSearchResult] = []
people_also_ask: list[PeopleAlsoAskResult] = []

@classmethod
def from_serper_results(cls, results: list[dict]) -> "WebSearchResponse":
organic = []
related = []
paa = []

for result in results:
if result["type"] == "organic":
organic.append(WebSearchResult(**result))
elif result["type"] == "relatedSearches":
related.append(RelatedSearchResult(**result))
elif result["type"] == "peopleAlsoAsk":
paa.append(PeopleAlsoAskResult(**result))

return cls(
organic_results=organic,
related_searches=related,
people_also_ask=paa,
)


class AggregateSearchResult(R2RSerializable):
"""Result of an aggregate search operation."""

chunk_search_results: Optional[list[ChunkSearchResult]]
graph_search_results: Optional[list[GraphSearchResult]] = None
web_search_results: Optional[list[WebSearchResult]] = None

def __str__(self) -> str:
return f"AggregateSearchResult(chunk_search_results={self.chunk_search_results}, graph_search_results={self.graph_search_results})"
return f"AggregateSearchResult(chunk_search_results={self.chunk_search_results}, graph_search_results={self.graph_search_results}, web_search_results={self.web_search_results})"

def __repr__(self) -> str:
return f"AggregateSearchResult(chunk_search_results={self.chunk_search_results}, graph_search_results={self.graph_search_results})"
return f"AggregateSearchResult(chunk_search_results={self.chunk_search_results}, graph_search_results={self.graph_search_results}, web_search_results={self.web_search_results})"

def as_dict(self) -> dict:
return {
@@ -165,7 +215,12 @@ def as_dict(self) -> dict:
if self.chunk_search_results
else []
),
"graph_search_results": self.graph_search_results or None,
"graph_search_results": [
result.to_dict() for result in self.graph_search_results
],
"web_search_results": [
result.to_dict() for result in self.web_search_results
],
}
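
A short sketch of how the flattened Serper results feed the new response model (values are hypothetical):

results = [
    {"type": "organic", "title": "R2R docs", "link": "https://example.com",
     "snippet": "R2R is a RAG framework.", "position": 1},
    {"type": "relatedSearches", "query": "rag agents"},
    {"type": "peopleAlsoAsk", "question": "What is RAG?",
     "snippet": "Retrieval augmented generation combines search with LLMs.",
     "link": "https://example.com", "title": "RAG explained"},
]
response = WebSearchResponse.from_serper_results(results)
assert len(response.organic_results) == 1
assert response.people_also_ask[0].question == "What is RAG?"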

