Add Amazon Bedrock support #97

Merged · 6 commits · Nov 23, 2024
19 changes: 19 additions & 0 deletions examples/using_amazon_bedrock.py
@@ -0,0 +1,19 @@
from nano_graphrag import GraphRAG, QueryParam

graph_func = GraphRAG(
    working_dir="../bedrock_example",
    using_amazon_bedrock=True,
    best_model_id="us.anthropic.claude-3-sonnet-20240229-v1:0",
    cheap_model_id="us.anthropic.claude-3-haiku-20240307-v1:0",
)

with open("../tests/mock_data.txt") as f:
    graph_func.insert(f.read())

prompt = "What are the top themes in this story?"

# Perform a global graphrag search
print(graph_func.query(prompt, param=QueryParam(mode="global")))

# Perform a local graphrag search (arguably the better and more scalable option)
print(graph_func.query(prompt, param=QueryParam(mode="local")))
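For reference, setting `using_amazon_bedrock=True` is roughly what you would get by wiring the helpers added in this PR by hand. A minimal sketch, assuming the helpers stay importable from the internal `nano_graphrag._llm` module:

from nano_graphrag import GraphRAG
from nano_graphrag._llm import (
    amazon_bedrock_embedding,
    create_amazon_bedrock_complete_function,
)

# Hand-wired equivalent of using_amazon_bedrock=True (sketch).
graph_func = GraphRAG(
    working_dir="../bedrock_example",
    best_model_func=create_amazon_bedrock_complete_function(
        "us.anthropic.claude-3-sonnet-20240229-v1:0"
    ),
    cheap_model_func=create_amazon_bedrock_complete_function(
        "us.anthropic.claude-3-haiku-20240307-v1:0"
    ),
    embedding_func=amazon_bedrock_embedding,
)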
116 changes: 116 additions & 0 deletions nano_graphrag/_llm.py
@@ -1,5 +1,8 @@
import json
import os
import numpy as np
from typing import Optional, List, Any, Callable

import aioboto3
from openai import AsyncOpenAI, AsyncAzureOpenAI, APIConnectionError, RateLimitError

from tenacity import (
@@ -15,6 +18,7 @@

global_openai_async_client = None
global_azure_openai_async_client = None
global_amazon_bedrock_async_client = None


def get_openai_async_client_instance():
@@ -31,6 +35,13 @@ def get_azure_openai_async_client_instance():
    return global_azure_openai_async_client


def get_amazon_bedrock_async_client_instance():
    global global_amazon_bedrock_async_client
    if global_amazon_bedrock_async_client is None:
        global_amazon_bedrock_async_client = aioboto3.Session()
    return global_amazon_bedrock_async_client


@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -64,6 +75,82 @@ async def openai_complete_if_cache(
    return response.choices[0].message.content


@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def amazon_bedrock_complete_if_cache(
    model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    amazon_bedrock_async_client = get_amazon_bedrock_async_client_instance()
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages = []
    messages.extend(history_messages)
    messages.append({"role": "user", "content": [{"text": prompt}]})
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    inference_config = {
        "temperature": 0,
        "maxTokens": 4096 if "max_tokens" not in kwargs else kwargs["max_tokens"],
    }

    async with amazon_bedrock_async_client.client(
        "bedrock-runtime",
        region_name=os.getenv("AWS_REGION", "us-east-1")
    ) as bedrock_runtime:
        if system_prompt:
            response = await bedrock_runtime.converse(
                modelId=model, messages=messages, inferenceConfig=inference_config,
                system=[{"text": system_prompt}]
            )
        else:
            response = await bedrock_runtime.converse(
                modelId=model, messages=messages, inferenceConfig=inference_config,
            )

    if hashing_kv is not None:
        await hashing_kv.upsert(
            {args_hash: {"return": response["output"]["message"]["content"][0]["text"], "model": model}}
        )
        await hashing_kv.index_done_callback()
    return response["output"]["message"]["content"][0]["text"]


def create_amazon_bedrock_complete_function(model_id: str) -> Callable:
    """
    Factory function to dynamically create completion functions for Amazon Bedrock.

    Args:
        model_id (str): Amazon Bedrock model identifier (e.g., "us.anthropic.claude-3-sonnet-20240229-v1:0")

    Returns:
        Callable: Generated completion function
    """
    async def bedrock_complete(
        prompt: str,
        system_prompt: Optional[str] = None,
        history_messages: List[Any] = [],
        **kwargs
    ) -> str:
        return await amazon_bedrock_complete_if_cache(
            model_id,
            prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            **kwargs
        )

    # Set the function name for easier debugging
    bedrock_complete.__name__ = f"{model_id}_complete"

    return bedrock_complete


async def gpt_4o_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -88,6 +175,35 @@ async def gpt_4o_mini_complete(
    )


@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def amazon_bedrock_embedding(texts: list[str]) -> np.ndarray:
    amazon_bedrock_async_client = get_amazon_bedrock_async_client_instance()

    async with amazon_bedrock_async_client.client(
        "bedrock-runtime",
        region_name=os.getenv("AWS_REGION", "us-east-1")
    ) as bedrock_runtime:
        embeddings = []
        for text in texts:
            body = json.dumps(
                {
                    "inputText": text,
                    "dimensions": 1024,
                }
            )
            response = await bedrock_runtime.invoke_model(
                modelId="amazon.titan-embed-text-v2:0", body=body,
            )
            response_body = await response.get("body").read()
            embeddings.append(json.loads(response_body))
    return np.array([dp["embedding"] for dp in embeddings])


@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(5),
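The new `amazon_bedrock_embedding` helper can also be called on its own. A minimal usage sketch, assuming AWS credentials are already configured (e.g., via `aws configure`):

import asyncio

from nano_graphrag._llm import amazon_bedrock_embedding

async def main():
    # Titan Text Embeddings v2 with dimensions=1024, per the wrapper above,
    # so this should print (2, 1024).
    vectors = await amazon_bedrock_embedding(["hello", "world"])
    print(vectors.shape)

asyncio.run(main())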
7 changes: 5 additions & 2 deletions nano_graphrag/_op.py
@@ -293,6 +293,7 @@ async def extract_entities(
    knwoledge_graph_inst: BaseGraphStorage,
    entity_vdb: BaseVectorStorage,
    global_config: dict,
    using_amazon_bedrock: bool = False,
) -> Union[BaseGraphStorage, None]:
    use_llm_func: callable = global_config["best_model_func"]
    entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
@@ -320,12 +321,14 @@ async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
        content = chunk_dp["content"]
        hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)
        final_result = await use_llm_func(hint_prompt)
        if isinstance(final_result, list):
            final_result = final_result[0]["text"]

        history = pack_user_ass_to_openai_messages(hint_prompt, final_result, using_amazon_bedrock)
        for now_glean_index in range(entity_extract_max_gleaning):
            glean_result = await use_llm_func(continue_prompt, history_messages=history)

            history += pack_user_ass_to_openai_messages(continue_prompt, glean_result, using_amazon_bedrock)
            final_result += glean_result
            if now_glean_index == entity_extract_max_gleaning - 1:
                break
16 changes: 11 additions & 5 deletions nano_graphrag/_utils.py
@@ -162,11 +162,17 @@ def load_json(file_name):


# precise typing here is messy, so the annotations stay loose
def pack_user_ass_to_openai_messages(*args: str):
    roles = ["user", "assistant"]
    return [
        {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
    ]
def pack_user_ass_to_openai_messages(prompt: str, generated_content: str, using_amazon_bedrock: bool):
    if using_amazon_bedrock:
        return [
            {"role": "user", "content": [{"text": prompt}]},
            {"role": "assistant", "content": [{"text": generated_content}]},
        ]
    else:
        return [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": generated_content},
        ]


def is_float_regex(value):
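For clarity, the two message shapes the updated helper produces, with illustrative values:

# OpenAI-style: content is a plain string.
pack_user_ass_to_openai_messages("Q", "A", False)
# -> [{"role": "user", "content": "Q"},
#     {"role": "assistant", "content": "A"}]

# Bedrock Converse-style: content is a list of text blocks.
pack_user_ass_to_openai_messages("Q", "A", True)
# -> [{"role": "user", "content": [{"text": "Q"}]},
#     {"role": "assistant", "content": [{"text": "A"}]}]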
14 changes: 14 additions & 0 deletions nano_graphrag/graphrag.py
@@ -9,6 +9,8 @@


from ._llm import (
    amazon_bedrock_embedding,
    create_amazon_bedrock_complete_function,
    gpt_4o_complete,
    gpt_4o_mini_complete,
    openai_embedding,
@@ -107,6 +109,9 @@ class GraphRAG:

    # LLM
    using_azure_openai: bool = False
    using_amazon_bedrock: bool = False
    best_model_id: str = "us.anthropic.claude-3-sonnet-20240229-v1:0"
    cheap_model_id: str = "us.anthropic.claude-3-haiku-20240307-v1:0"
    best_model_func: callable = gpt_4o_complete
    best_model_max_token_size: int = 32768
    best_model_max_async: int = 16
@@ -145,6 +150,14 @@ def __post_init__(self):
"Switched the default openai funcs to Azure OpenAI if you didn't set any of it"
)

        if self.using_amazon_bedrock:
            self.best_model_func = create_amazon_bedrock_complete_function(self.best_model_id)
            self.cheap_model_func = create_amazon_bedrock_complete_function(self.cheap_model_id)
            self.embedding_func = amazon_bedrock_embedding
            logger.info(
                "Switched the default openai funcs to Amazon Bedrock"
            )

        if not os.path.exists(self.working_dir) and self.always_create_working_dir:
            logger.info(f"Creating working directory {self.working_dir}")
            os.makedirs(self.working_dir)
@@ -298,6 +311,7 @@ async def ainsert(self, string_or_strings):
            knwoledge_graph_inst=self.chunk_entity_relation_graph,
            entity_vdb=self.entities_vdb,
            global_config=asdict(self),
            using_amazon_bedrock=self.using_amazon_bedrock,
        )
        if maybe_new_kg is None:
            logger.warning("No new entities found")
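As a sanity check, the factory names each generated completion function after its model id, which makes the switch easy to verify. A small sketch (assumes AWS credentials are configured; expected output shown as a comment):

from nano_graphrag import GraphRAG

rag = GraphRAG(working_dir="./bedrock_check", using_amazon_bedrock=True)
# create_amazon_bedrock_complete_function renames the generated coroutine:
print(rag.best_model_func.__name__)
# expected: us.anthropic.claude-3-sonnet-20240229-v1:0_complete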
5 changes: 5 additions & 0 deletions readme.md
@@ -73,6 +73,9 @@ pip install nano-graphrag
> [!TIP]
> If you're using Azure OpenAI API, refer to the [.env.example](./.env.example.azure) to set your azure openai. Then pass `GraphRAG(...,using_azure_openai=True,...)` to enable.

> [!TIP]
> If you're using the Amazon Bedrock API, make sure your credentials are properly configured (e.g., via `aws configure`). Then enable it like this: `GraphRAG(...,using_amazon_bedrock=True, best_model_id="us.anthropic.claude-3-sonnet-20240229-v1:0", cheap_model_id="us.anthropic.claude-3-haiku-20240307-v1:0",...)`. See the [example script](./examples/using_amazon_bedrock.py).

> [!TIP]
>
> If you don't have any key, check out this [example](./examples/no_openai_key_at_all.py) that uses `transformers` and `ollama`. If you'd like to use another LLM or embedding model, check [Advances](#Advances).
@@ -167,9 +170,11 @@ Below are the components you can use:
| Type | What | Where |
| :-------------- | :----------------------------------------------------------: | :-----------------------------------------------: |
| LLM | OpenAI | Built-in |
| | Amazon Bedrock | Built-in |
| | DeepSeek | [examples](./examples) |
| | `ollama` | [examples](./examples) |
| Embedding | OpenAI | Built-in |
| | Amazon Bedrock | Built-in |
| | Sentence-transformers | [examples](./examples) |
| Vector DataBase | [`nano-vectordb`](https://github.com/gusye1234/nano-vectordb) | Built-in |
| | [`hnswlib`](https://github.com/nmslib/hnswlib) | Built-in, [examples](./examples) |
3 changes: 2 additions & 1 deletion requirements.txt
@@ -8,4 +8,5 @@ hnswlib
xxhash
tenacity
dspy-ai
neo4j
aioboto3