From 33249a0cb3ba8a49ca80e102ec6e1ab3c8726d65 Mon Sep 17 00:00:00 2001 From: Kazuki Motohashi Date: Thu, 14 Nov 2024 04:09:53 +0000 Subject: [PATCH 1/6] Add Amazon Bedrock support --- nano_graphrag/_llm.py | 124 ++++++++++++++++++++++++++++++++++++++ nano_graphrag/_op.py | 7 ++- nano_graphrag/_utils.py | 16 +++-- nano_graphrag/graphrag.py | 14 +++++ requirements.txt | 3 +- 5 files changed, 156 insertions(+), 8 deletions(-) diff --git a/nano_graphrag/_llm.py b/nano_graphrag/_llm.py index f658234..bd4c33b 100644 --- a/nano_graphrag/_llm.py +++ b/nano_graphrag/_llm.py @@ -1,5 +1,7 @@ +import json import numpy as np +import aioboto3 from openai import AsyncOpenAI, AsyncAzureOpenAI, APIConnectionError, RateLimitError from tenacity import ( @@ -15,6 +17,7 @@ global_openai_async_client = None global_azure_openai_async_client = None +global_amazon_bedrock_async_client = None def get_openai_async_client_instance(): @@ -31,6 +34,16 @@ def get_azure_openai_async_client_instance(): return global_azure_openai_async_client +def get_amazon_bedrock_async_client_instance(): + global global_amazon_bedrock_async_client + if global_amazon_bedrock_async_client is None: + #global_amazon_bedrock_async_client = aioboto3.Session().client( + # service_name="bedrock-runtime", region_name=os.getenv("AWS_REGION", "us-east-1") + #) + global_amazon_bedrock_async_client = aioboto3.Session() + return global_amazon_bedrock_async_client + + @retry( stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10), @@ -64,6 +77,88 @@ async def openai_complete_if_cache( return response.choices[0].message.content +@retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((RateLimitError, APIConnectionError)), +) +async def amazon_bedrock_complete_if_cache( + model, prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + amazon_bedrock_async_client = get_amazon_bedrock_async_client_instance() + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) + messages = [] + messages.extend(history_messages) + messages.append({"role": "user", "content": [{"text": prompt}]}) + if hashing_kv is not None: + args_hash = compute_args_hash(model, messages) + if_cache_return = await hashing_kv.get_by_id(args_hash) + if if_cache_return is not None: + return if_cache_return["return"] + + inference_config = { + "temperature": 0, + "maxTokens": 4096 if "max_tokens" not in kwargs else kwargs["max_tokens"], + } + + async with amazon_bedrock_async_client.client( + "bedrock-runtime", + region_name=os.getenv("AWS_REGION", "us-east-1") + ) as bedrock_runtime: + if system_prompt: + response = await bedrock_runtime.converse( + modelId=model, messages=messages, inferenceConfig=inference_config, + system=[{"text": system_prompt}] + ) + else: + response = await bedrock_runtime.converse( + modelId=model, messages=messages, inferenceConfig=inference_config, + ) + + if hashing_kv is not None: + await hashing_kv.upsert( + {args_hash: {"return": response["output"]["message"]["content"][0]["text"], "model": model}} + ) + await hashing_kv.index_done_callback() + return response["output"]["message"]["content"][0]["text"] + + +async def claude_3_5_haiku_complete( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + return await amazon_bedrock_complete_if_cache( + "us.anthropic.claude-3-5-haiku-20241022-v1:0", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + ) + + +async def claude_3_haiku_complete( + prompt, 
system_prompt=None, history_messages=[], **kwargs +) -> str: + return await amazon_bedrock_complete_if_cache( + "us.anthropic.claude-3-haiku-20240307-v1:0", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + ) + + +async def claude_3_sonnet_complete( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + return await amazon_bedrock_complete_if_cache( + "us.anthropic.claude-3-sonnet-20240229-v1:0", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + ) + + async def gpt_4o_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: @@ -88,6 +183,35 @@ async def gpt_4o_mini_complete( ) +@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192) +@retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((RateLimitError, APIConnectionError)), +) +async def amazon_bedrock_embedding(texts: list[str]) -> np.ndarray: + amazon_bedrock_async_client = get_amazon_bedrock_async_client_instance() + + async with amazon_bedrock_async_client.client( + "bedrock-runtime", + region_name=os.getenv("AWS_REGION", "us-east-1") + ) as bedrock_runtime: + embeddings = [] + for text in texts: + body = json.dumps( + { + "inputText": text, + "dimensions": 1024, + } + ) + response = await bedrock_runtime.invoke_model( + modelId="amazon.titan-embed-text-v2:0", body=body, + ) + response_body = await response.get("body").read() + embeddings.append(json.loads(response_body)) + return np.array([dp["embedding"] for dp in embeddings]) + + @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) @retry( stop=stop_after_attempt(5), diff --git a/nano_graphrag/_op.py b/nano_graphrag/_op.py index 4ee9ed5..5d738fa 100644 --- a/nano_graphrag/_op.py +++ b/nano_graphrag/_op.py @@ -293,6 +293,7 @@ async def extract_entities( knwoledge_graph_inst: BaseGraphStorage, entity_vdb: BaseVectorStorage, global_config: dict, + using_amazon_bedrock: bool=False, ) -> Union[BaseGraphStorage, None]: use_llm_func: callable = global_config["best_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] @@ -320,12 +321,14 @@ async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): content = chunk_dp["content"] hint_prompt = entity_extract_prompt.format(**context_base, input_text=content) final_result = await use_llm_func(hint_prompt) + if isinstance(final_result, list): + final_result = final_result[0]["text"] - history = pack_user_ass_to_openai_messages(hint_prompt, final_result) + history = pack_user_ass_to_openai_messages(hint_prompt, final_result, using_amazon_bedrock) for now_glean_index in range(entity_extract_max_gleaning): glean_result = await use_llm_func(continue_prompt, history_messages=history) - history += pack_user_ass_to_openai_messages(continue_prompt, glean_result) + history += pack_user_ass_to_openai_messages(continue_prompt, glean_result, using_amazon_bedrock) final_result += glean_result if now_glean_index == entity_extract_max_gleaning - 1: break diff --git a/nano_graphrag/_utils.py b/nano_graphrag/_utils.py index ae772eb..8f76227 100644 --- a/nano_graphrag/_utils.py +++ b/nano_graphrag/_utils.py @@ -162,11 +162,17 @@ def load_json(file_name): # it's dirty to type, so it's a good way to have fun -def pack_user_ass_to_openai_messages(*args: str): - roles = ["user", "assistant"] - return [ - {"role": roles[i % 2], "content": content} for i, content in enumerate(args) - ] +def 
pack_user_ass_to_openai_messages(prompt: str, generated_content: str, using_amazon_bedrock: bool):
+    if using_amazon_bedrock:
+        return [
+            {"role": "user", "content": [{"text": prompt}]},
+            {"role": "assistant", "content": [{"text": generated_content}]},
+        ]
+    else:
+        return [
+            {"role": "user", "content": prompt},
+            {"role": "assistant", "content": generated_content},
+        ]
 
 
 def is_float_regex(value):
diff --git a/nano_graphrag/graphrag.py b/nano_graphrag/graphrag.py
index 2c9e1be..b8a5c65 100644
--- a/nano_graphrag/graphrag.py
+++ b/nano_graphrag/graphrag.py
@@ -9,6 +9,10 @@
 from ._llm import (
+    amazon_bedrock_embedding,
+    claude_3_5_haiku_complete,
+    claude_3_sonnet_complete,
+    claude_3_haiku_complete,
     gpt_4o_complete,
     gpt_4o_mini_complete,
     openai_embedding,
@@ -107,6 +111,7 @@ class GraphRAG:
     # LLM
     using_azure_openai: bool = False
+    using_amazon_bedrock: bool = False
     best_model_func: callable = gpt_4o_complete
     best_model_max_token_size: int = 32768
     best_model_max_async: int = 16
@@ -145,6 +150,14 @@ def __post_init__(self):
                 "Switched the default openai funcs to Azure OpenAI if you didn't set any of it"
             )
 
+        if self.using_amazon_bedrock:
+            self.best_model_func = claude_3_sonnet_complete
+            self.cheap_model_func = claude_3_haiku_complete
+            self.embedding_func = amazon_bedrock_embedding
+            logger.info(
+                "Switched the default openai funcs to Amazon Bedrock"
+            )
+
         if not os.path.exists(self.working_dir) and self.always_create_working_dir:
             logger.info(f"Creating working directory {self.working_dir}")
             os.makedirs(self.working_dir)
@@ -298,6 +311,7 @@ async def ainsert(self, string_or_strings):
             knwoledge_graph_inst=self.chunk_entity_relation_graph,
             entity_vdb=self.entities_vdb,
             global_config=asdict(self),
+            using_amazon_bedrock=self.using_amazon_bedrock,
         )
         if maybe_new_kg is None:
             logger.warning("No new entities found")
diff --git a/requirements.txt b/requirements.txt
index be0e993..7d26a49 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,4 +8,5 @@ hnswlib
 xxhash
 tenacity
 dspy-ai
-neo4j
\ No newline at end of file
+neo4j
+aioboto3
\ No newline at end of file

From 85ba67615989a8ce6376da364951a885a5338e55 Mon Sep 17 00:00:00 2001
From: Kazuki Motohashi
Date: Thu, 14 Nov 2024 07:16:03 +0000
Subject: [PATCH 2/6] Add sample script to test Amazon Bedrock integration

---
 examples/using_amazon_bedrock.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)
 create mode 100644 examples/using_amazon_bedrock.py

diff --git a/examples/using_amazon_bedrock.py b/examples/using_amazon_bedrock.py
new file mode 100644
index 0000000..f63050c
--- /dev/null
+++ b/examples/using_amazon_bedrock.py
@@ -0,0 +1,14 @@
+from nano_graphrag import GraphRAG, QueryParam
+
+graph_func = GraphRAG(working_dir="./bedrock_example", using_amazon_bedrock=True)
+
+with open("../tests/mock_data.txt") as f:
+    graph_func.insert(f.read())
+
+prompt = "What are the top themes in this story?"
+
+# Perform global graphrag search
+print(graph_func.query(prompt, param=QueryParam(mode="global")))
+
+# Perform local graphrag search (which I think is the better and more scalable one)
+print(graph_func.query(prompt, param=QueryParam(mode="local")))

From d04a34d7d6bb0d7e2aa2af54d3ea31c733a561db Mon Sep 17 00:00:00 2001
From: Kazuki Motohashi
Date: Wed, 20 Nov 2024 05:00:06 +0000
Subject: [PATCH 3/6] Add the latest Claude 3.5 Sonnet v1 & v2 models

---
 nano_graphrag/_llm.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/nano_graphrag/_llm.py b/nano_graphrag/_llm.py
index bd4c33b..61458cc 100644
--- a/nano_graphrag/_llm.py
+++ b/nano_graphrag/_llm.py
@@ -135,6 +135,30 @@ async def claude_3_5_haiku_complete(
     )
 
 
+async def claude_3_5_sonnet_complete(
+    prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+    return await amazon_bedrock_complete_if_cache(
+        "us.anthropic.claude-3-5-sonnet-20240620-v1:0",
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
+
+
+async def claude_3_5_sonnet_v2_complete(
+    prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+    return await amazon_bedrock_complete_if_cache(
+        "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
+
+
 async def claude_3_haiku_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:

From 9d71f941e37557b53e0bf02f2beef796017ad83b Mon Sep 17 00:00:00 2001
From: Kazuki Motohashi
Date: Wed, 20 Nov 2024 05:28:11 +0000
Subject: [PATCH 4/6] Add a factory function for Bedrock completion instead of
 creating one for each model

---
 examples/using_amazon_bedrock.py | 7 ++-
 nano_graphrag/_llm.py | 87 +++++++++++---------------------
 nano_graphrag/graphrag.py | 10 ++--
 3 files changed, 40 insertions(+), 64 deletions(-)

diff --git a/examples/using_amazon_bedrock.py b/examples/using_amazon_bedrock.py
index f63050c..c8aeac4 100644
--- a/examples/using_amazon_bedrock.py
+++ b/examples/using_amazon_bedrock.py
@@ -1,6 +1,11 @@
 from nano_graphrag import GraphRAG, QueryParam
 
-graph_func = GraphRAG(working_dir="./bedrock_example", using_amazon_bedrock=True)
+graph_func = GraphRAG(
+    working_dir="../bedrock_example",
+    using_amazon_bedrock=True,
+    best_model_id="us.anthropic.claude-3-sonnet-20240229-v1:0",
+    cheap_model_id="us.anthropic.claude-3-haiku-20240307-v1:0",
+)
 
 with open("../tests/mock_data.txt") as f:
     graph_func.insert(f.read())
diff --git a/nano_graphrag/_llm.py b/nano_graphrag/_llm.py
index 61458cc..eb35c66 100644
--- a/nano_graphrag/_llm.py
+++ b/nano_graphrag/_llm.py
@@ -1,5 +1,6 @@
 import json
 import numpy as np
+from typing import Optional, List, Any, Callable
 
 import aioboto3
 from openai import AsyncOpenAI, AsyncAzureOpenAI, APIConnectionError, RateLimitError
@@ -123,64 +124,34 @@ async def amazon_bedrock_complete_if_cache(
     return response["output"]["message"]["content"][0]["text"]
 
 
-async def claude_3_5_haiku_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
-) -> str:
-    return await amazon_bedrock_complete_if_cache(
-        "us.anthropic.claude-3-5-haiku-20241022-v1:0",
-        prompt,
-        system_prompt=system_prompt,
-        history_messages=history_messages,
-        **kwargs,
-    )
-
-
-async def claude_3_5_sonnet_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
-) -> str:
-    return await amazon_bedrock_complete_if_cache(
-        "us.anthropic.claude-3-5-sonnet-20240620-v1:0",
-        prompt,
-        system_prompt=system_prompt,
-        history_messages=history_messages,
-        **kwargs,
-    )
-
-
-async def claude_3_5_sonnet_v2_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
-) -> str:
-    return await amazon_bedrock_complete_if_cache(
-        "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
-        prompt,
-        system_prompt=system_prompt,
-        history_messages=history_messages,
-        **kwargs,
-    )
-
-
-async def claude_3_haiku_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
-) -> str:
-    return await amazon_bedrock_complete_if_cache(
-        "us.anthropic.claude-3-haiku-20240307-v1:0",
-        prompt,
-        system_prompt=system_prompt,
-        history_messages=history_messages,
-        **kwargs,
-    )
-
-
-async def claude_3_sonnet_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
-) -> str:
-    return await amazon_bedrock_complete_if_cache(
-        "us.anthropic.claude-3-sonnet-20240229-v1:0",
-        prompt,
-        system_prompt=system_prompt,
-        history_messages=history_messages,
-        **kwargs,
-    )
+def create_amazon_bedrock_complete_function(model_id: str) -> Callable:
+    """
+    Factory function to dynamically create completion functions for Amazon Bedrock
+
+    Args:
+        model_id (str): Amazon Bedrock model identifier (e.g., "us.anthropic.claude-3-sonnet-20240229-v1:0")
+
+    Returns:
+        Callable: Generated completion function
+    """
+    async def bedrock_complete(
+        prompt: str,
+        system_prompt: Optional[str] = None,
+        history_messages: List[Any] = [],
+        **kwargs
+    ) -> str:
+        return await amazon_bedrock_complete_if_cache(
+            model_id,
+            prompt,
+            system_prompt=system_prompt,
+            history_messages=history_messages,
+            **kwargs
+        )
+
+    # Set function name for easier debugging
+    bedrock_complete.__name__ = f"{model_id}_complete"
+
+    return bedrock_complete
 
 
 async def gpt_4o_complete(
diff --git a/nano_graphrag/graphrag.py b/nano_graphrag/graphrag.py
index b8a5c65..e60fcb0 100644
--- a/nano_graphrag/graphrag.py
+++ b/nano_graphrag/graphrag.py
@@ -10,9 +10,7 @@
 from ._llm import (
     amazon_bedrock_embedding,
-    claude_3_5_haiku_complete,
-    claude_3_sonnet_complete,
-    claude_3_haiku_complete,
+    create_amazon_bedrock_complete_function,
     gpt_4o_complete,
     gpt_4o_mini_complete,
     openai_embedding,
@@ -112,6 +110,8 @@ class GraphRAG:
     # LLM
     using_azure_openai: bool = False
     using_amazon_bedrock: bool = False
+    best_model_id: str = "us.anthropic.claude-3-sonnet-20240229-v1:0"
+    cheap_model_id: str = "us.anthropic.claude-3-haiku-20240307-v1:0"
     best_model_func: callable = gpt_4o_complete
     best_model_max_token_size: int = 32768
     best_model_max_async: int = 16
@@ -151,8 +151,8 @@ def __post_init__(self):
             )
 
         if self.using_amazon_bedrock:
-            self.best_model_func = claude_3_sonnet_complete
-            self.cheap_model_func = claude_3_haiku_complete
+            self.best_model_func = create_amazon_bedrock_complete_function(self.best_model_id)
+            self.cheap_model_func = create_amazon_bedrock_complete_function(self.cheap_model_id)
             self.embedding_func = amazon_bedrock_embedding
             logger.info(
                 "Switched the default openai funcs to Amazon Bedrock"
             )

From 2b6cdaa1640db6edddda085b3da415976ca2ebc2 Mon Sep 17 00:00:00 2001
From: Kazuki Motohashi
Date: Wed, 20 Nov 2024 05:48:44 +0000
Subject: [PATCH 5/6] Update README.md to explain the Bedrock option

---
 readme.md | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/readme.md b/readme.md
index 112c4a0..ba6a673 100644
--- a/readme.md
+++ b/readme.md
@@ -73,6 +73,9 @@ pip install nano-graphrag
 > [!TIP]
 > If you're using Azure OpenAI API, refer to the [.env.example](./.env.example.azure) to set your azure openai. 
Then pass `GraphRAG(...,using_azure_openai=True,...)` to enable. +> [!TIP] +> If you're using Amazon Bedrock API, please ensure your credentials are properly set through commands like `aws configure``. Then enable it by configuring like this: `GraphRAG(...,using_amazon_bedrock=True, best_model_id="us.anthropic.claude-3-sonnet-20240229-v1:0", cheap_model_id="us.anthropic.claude-3-haiku-20240307-v1:0",...)`. Refer to an [example script](./examples/using_amazon_bedrock.py). + > [!TIP] > > If you don't have any key, check out this [example](./examples/no_openai_key_at_all.py) that using `transformers` and `ollama` . If you like to use another LLM or Embedding Model, check [Advances](#Advances). @@ -167,9 +170,11 @@ Below are the components you can use: | Type | What | Where | | :-------------- | :----------------------------------------------------------: | :-----------------------------------------------: | | LLM | OpenAI | Built-in | +| | Amazon Bedrock | Built-in | | | DeepSeek | [examples](./examples) | | | `ollama` | [examples](./examples) | | Embedding | OpenAI | Built-in | +| | Amazon Bedrock | Built-in | | | Sentence-transformers | [examples](./examples) | | Vector DataBase | [`nano-vectordb`](https://github.com/gusye1234/nano-vectordb) | Built-in | | | [`hnswlib`](https://github.com/nmslib/hnswlib) | Built-in, [examples](./examples) | From e278f5dcd98733db9e993f4c9ff56fd8e4665b52 Mon Sep 17 00:00:00 2001 From: Kazuki Motohashi Date: Wed, 20 Nov 2024 05:56:41 +0000 Subject: [PATCH 6/6] clean up --- nano_graphrag/_llm.py | 3 --- readme.md | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/nano_graphrag/_llm.py b/nano_graphrag/_llm.py index eb35c66..974c339 100644 --- a/nano_graphrag/_llm.py +++ b/nano_graphrag/_llm.py @@ -38,9 +38,6 @@ def get_azure_openai_async_client_instance(): def get_amazon_bedrock_async_client_instance(): global global_amazon_bedrock_async_client if global_amazon_bedrock_async_client is None: - #global_amazon_bedrock_async_client = aioboto3.Session().client( - # service_name="bedrock-runtime", region_name=os.getenv("AWS_REGION", "us-east-1") - #) global_amazon_bedrock_async_client = aioboto3.Session() return global_amazon_bedrock_async_client diff --git a/readme.md b/readme.md index ba6a673..2fc470c 100644 --- a/readme.md +++ b/readme.md @@ -74,7 +74,7 @@ pip install nano-graphrag > If you're using Azure OpenAI API, refer to the [.env.example](./.env.example.azure) to set your azure openai. Then pass `GraphRAG(...,using_azure_openai=True,...)` to enable. > [!TIP] -> If you're using Amazon Bedrock API, please ensure your credentials are properly set through commands like `aws configure``. Then enable it by configuring like this: `GraphRAG(...,using_amazon_bedrock=True, best_model_id="us.anthropic.claude-3-sonnet-20240229-v1:0", cheap_model_id="us.anthropic.claude-3-haiku-20240307-v1:0",...)`. Refer to an [example script](./examples/using_amazon_bedrock.py). +> If you're using Amazon Bedrock API, please ensure your credentials are properly set through commands like `aws configure`. Then enable it by configuring like this: `GraphRAG(...,using_amazon_bedrock=True, best_model_id="us.anthropic.claude-3-sonnet-20240229-v1:0", cheap_model_id="us.anthropic.claude-3-haiku-20240307-v1:0",...)`. Refer to an [example script](./examples/using_amazon_bedrock.py). > [!TIP] >
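
For reference, below is a minimal end-to-end sketch of the API this series converges on. It is an illustrative sketch rather than part of the patches: it assumes AWS credentials are already configured (e.g. via `aws configure`), that the Bedrock model IDs shown (all taken from the patches above) are enabled in your account and region, and that a `tests/mock_data.txt` fixture exists.

    import asyncio

    from nano_graphrag import GraphRAG, QueryParam
    from nano_graphrag._llm import create_amazon_bedrock_complete_function

    # With using_amazon_bedrock=True, GraphRAG builds its completion functions
    # from the two model IDs and switches embeddings to amazon_bedrock_embedding.
    rag = GraphRAG(
        working_dir="./bedrock_example",
        using_amazon_bedrock=True,
        best_model_id="us.anthropic.claude-3-5-sonnet-20241022-v2:0",
        cheap_model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0",
    )

    with open("tests/mock_data.txt") as f:
        rag.insert(f.read())

    print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))

    # The factory from PATCH 4 also works standalone: it returns an async
    # completion function bound to a single Bedrock model ID.
    haiku_complete = create_amazon_bedrock_complete_function(
        "us.anthropic.claude-3-haiku-20240307-v1:0"
    )
    print(asyncio.run(haiku_complete("Summarize GraphRAG in one sentence.")))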