
Commit

aws#4725: Change model deployment to JumpStart
HubGab-Git committed Sep 29, 2024
1 parent faf8648 commit 4ba750e
Showing 1 changed file with 72 additions and 158 deletions.
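At its core, the commit replaces the manual container and model URI retrieval (`image_uris.retrieve`, `model_uris.retrieve`, `Model(...)`) with the higher-level `JumpStartModel` API. A minimal sketch of the new deployment pattern, with the model ID, version, and instance type taken from the diff below (illustrative only, not a drop-in script):

```python
# Sketch of the JumpStart-based deployment this commit switches to.
# Model ID, version, and instance type mirror the diff below and may
# differ across SageMaker JumpStart releases.
from sagemaker.jumpstart.model import JumpStartModel
from sagemaker.utils import name_from_base

model = JumpStartModel(
    model_id="huggingface-text2text-flan-t5-xxl",
    model_version="2.*",
)
predictor = model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.12xlarge",
    endpoint_name=name_from_base("jumpstart-example-raglc-flan-t5-xxl"),
)
print(predictor.endpoint_name)
```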
@@ -27,7 +27,7 @@
"\n",
"Many use cases such as building a chatbot require text (text2text) generation models like **[BloomZ 7B1](https://huggingface.co/bigscience/bloomz-7b1)**, **[Flan T5 XXL](https://huggingface.co/google/flan-t5-xxl)**, and **[Flan T5 UL2](https://huggingface.co/google/flan-ul2)** to respond to user questions with insightful answers. The **BloomZ 7B1**, **Flan T5 XXL**, and **Flan T5 UL2** models have picked up a lot of general knowledge in training, but we often need to ingest and use a large library of more specific information.\n",
"\n",
"In this notebook we will demonstrate how to use **BloomZ 7B1**, **Flan T5 XXL**, and **Flan T5 UL2** to answer questions using a library of documents as a reference, by using document embeddings and retrieval. The embeddings are generated from **GPT-J-6B** embedding model. \n",
"In this notebook we will demonstrate how to use **BloomZ 7B1**, **Flan T5 XXL**, and **Flan T5 UL2** to answer questions using a library of documents as a reference, by using document embeddings and retrieval. The embeddings are generated from **MiniLM-L6-v2** embedding model. \n",
"\n",
"**This notebook serves a template such that you can easily replace the example dataset by your own to build a custom question and asnwering application.**"
]
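For orientation, the retrieval-augmented flow the notebook builds can be sketched in a few lines. The helper below is schematic: `index` and `llm` stand in for the FAISS vector store and the SageMaker-hosted LLM that later cells construct, and the prompt template matches the one used in the diff.

```python
# Schematic RAG flow (placeholder objects, not part of the notebook code):
# retrieve the top-k relevant chunks, stuff them into the prompt as context,
# and let the text2text LLM answer.
def answer_with_context(question, index, llm, k=3):
    docs = index.similarity_search(question, k=k)        # top-k relevant chunks
    context = "\n".join(d.page_content for d in docs)    # concatenate as context
    prompt = f"Answer based on context:\n\n{context}\n\n{question}"
    return llm(prompt)                                    # query the LLM endpoint
```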
@@ -45,7 +45,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
@@ -57,9 +56,8 @@
"outputs": [],
"source": [
"!pip install --upgrade sagemaker --quiet\n",
"!pip install ipywidgets==7.0.0 --quiet\n",
"!pip install langchain==0.0.148 --quiet\n",
"!pip install faiss-cpu --quiet"
"!pip install faiss-cpu --quiet\n",
"!pip install langchain --quiet"
]
},
{
@@ -70,59 +68,18 @@
},
"outputs": [],
"source": [
"import time\n",
"import sagemaker, boto3, json\n",
"from sagemaker.session import Session\n",
"from sagemaker.model import Model\n",
"from sagemaker import image_uris, model_uris, script_uris, hyperparameters\n",
"from sagemaker.predictor import Predictor\n",
"from sagemaker import Session\n",
"from sagemaker.utils import name_from_base\n",
"from typing import Any, Dict, List, Optional\n",
"from langchain.embeddings import SagemakerEndpointEmbeddings\n",
"from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
"from sagemaker.jumpstart.model import JumpStartModel\n",
"\n",
"sagemaker_session = Session()\n",
"aws_role = sagemaker_session.get_caller_identity_arn()\n",
"aws_region = boto3.Session().region_name\n",
"sess = sagemaker.Session()\n",
"model_version = \"1.*\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def query_endpoint_with_json_payload(encoded_json, endpoint_name, content_type=\"application/json\"):\n",
" client = boto3.client(\"runtime.sagemaker\")\n",
" response = client.invoke_endpoint(\n",
" EndpointName=endpoint_name, ContentType=content_type, Body=encoded_json\n",
" )\n",
" return response\n",
"\n",
"\n",
"def parse_response_model_flan_t5(query_response):\n",
" model_predictions = json.loads(query_response[\"Body\"].read())\n",
" generated_text = model_predictions[\"generated_texts\"]\n",
" return generated_text\n",
"\n",
"\n",
"def parse_response_multiple_texts_bloomz(query_response):\n",
" generated_text = []\n",
" model_predictions = json.loads(query_response[\"Body\"].read())\n",
" for x in model_predictions[0]:\n",
" generated_text.append(x[\"generated_text\"])\n",
" return generated_text"
"sagemaker_session = Session()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Deploy SageMaker endpoint(s) for large language models and GPT-J 6B embedding model. Please uncomment the entries as below if you want to deploy multiple LLM models to compare their performance."
"Deploy SageMaker endpoint(s) for large language models and MiniLM-L6-v2 embedding model. Please uncomment the entries as below if you want to deploy multiple LLM models to compare their performance."
]
},
{
@@ -135,30 +92,21 @@
"source": [
"_MODEL_CONFIG_ = {\n",
" \"huggingface-text2text-flan-t5-xxl\": {\n",
" \"model_version\": \"2.*\",\n",
" \"instance type\": \"ml.g5.12xlarge\",\n",
" \"env\": {\"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\", \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"},\n",
" \"parse_function\": parse_response_model_flan_t5,\n",
" \"prompt\": \"\"\"Answer based on context:\\n\\n{context}\\n\\n{question}\"\"\",\n",
" },\n",
" \"huggingface-textembedding-gpt-j-6b\": {\n",
" \"huggingface-textembedding-all-MiniLM-L6-v2\": {\n",
" \"model_version\": \"1.*\",\n",
" \"instance type\": \"ml.g5.24xlarge\",\n",
" \"env\": {\"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\", \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"},\n",
" },\n",
" # \"huggingface-textgeneration1-bloomz-7b1-fp16\": {\n",
" # \"instance type\": \"ml.g5.12xlarge\",\n",
" # \"env\": {},\n",
" # \"parse_function\": parse_response_multiple_texts_bloomz,\n",
" # \"prompt\": \"\"\"question: \\\"{question}\"\\\\n\\nContext: \\\"{context}\"\\\\n\\nAnswer:\"\"\",\n",
" # \"huggingface-textembedding-all-MiniLM-L6-v2\": {\n",
" # \"model_version\": \"3.*\",\n",
" # \"instance type\": \"ml.g5.12xlarge\"\n",
" # },\n",
" # \"huggingface-text2text-flan-ul2-bf16\": {\n",
" # \"instance type\": \"ml.g5.24xlarge\",\n",
" # \"env\": {\n",
" # \"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\",\n",
" # \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"\n",
" # },\n",
" # \"parse_function\": parse_response_model_flan_t5,\n",
" # \"prompt\": \"\"\"Answer based on context:\\n\\n{context}\\n\\n{question}\"\"\",\n",
" # },\n",
" # \"model_version\": \"2.*\",\n",
" # \"instance type\": \"ml.g5.24xlarge\"\n",
" # }\n",
"}"
]
},
@@ -168,41 +116,27 @@
"metadata": {},
"outputs": [],
"source": [
"newline, bold, unbold = \"\\n\", \"\\033[1m\", \"\\033[0m\"\n",
"\n",
"for model_id in _MODEL_CONFIG_:\n",
" endpoint_name = name_from_base(f\"jumpstart-example-raglc-{model_id}\")\n",
" inference_instance_type = _MODEL_CONFIG_[model_id][\"instance type\"]\n",
" model_version = _MODEL_CONFIG_[model_id][\"model_version\"]\n",
"\n",
" # Retrieve the inference container uri. This is the base HuggingFace container image for the default model above.\n",
" deploy_image_uri = image_uris.retrieve(\n",
" region=None,\n",
" framework=None, # automatically inferred from model_id\n",
" image_scope=\"inference\",\n",
" model_id=model_id,\n",
" model_version=model_version,\n",
" instance_type=inference_instance_type,\n",
" )\n",
" # Retrieve the model uri.\n",
" model_uri = model_uris.retrieve(\n",
" model_id=model_id, model_version=model_version, model_scope=\"inference\"\n",
" )\n",
" model_inference = Model(\n",
" image_uri=deploy_image_uri,\n",
" model_data=model_uri,\n",
" role=aws_role,\n",
" predictor_cls=Predictor,\n",
" name=endpoint_name,\n",
" env=_MODEL_CONFIG_[model_id][\"env\"],\n",
" )\n",
" model_predictor_inference = model_inference.deploy(\n",
" initial_instance_count=1,\n",
" instance_type=inference_instance_type,\n",
" predictor_cls=Predictor,\n",
" endpoint_name=endpoint_name,\n",
" )\n",
" print(f\"{bold}Model {model_id} has been deployed successfully.{unbold}{newline}\")\n",
" _MODEL_CONFIG_[model_id][\"endpoint_name\"] = endpoint_name"
" print(f\"Deploying {model_id}...\")\n",
"\n",
" model = JumpStartModel(model_id=model_id, model_version=model_version)\n",
"\n",
" try:\n",
" predictor = model.deploy(\n",
" initial_instance_count=1,\n",
" instance_type=inference_instance_type,\n",
" endpoint_name=name_from_base(f\"jumpstart-example-raglc-{model_id}\"),\n",
" )\n",
" print(f\"Deployed endpoint: {predictor.endpoint_name}\")\n",
" _MODEL_CONFIG_[model_id][\"predictor\"] = predictor\n",
" except Exception as e:\n",
" print(f\"Error deploying {model_id}: {str(e)}\")\n",
"\n",
"print(\"Deployment process completed.\")"
]
},
{
@@ -229,26 +163,14 @@
"metadata": {},
"outputs": [],
"source": [
"payload = {\n",
" \"text_inputs\": question,\n",
" \"max_length\": 100,\n",
" \"num_return_sequences\": 1,\n",
" \"top_k\": 50,\n",
" \"top_p\": 0.95,\n",
" \"do_sample\": True,\n",
"}\n",
"\n",
"list_of_LLMs = list(_MODEL_CONFIG_.keys())\n",
"list_of_LLMs.remove(\"huggingface-textembedding-gpt-j-6b\") # remove the embedding model\n",
"\n",
"list_of_LLMs = [model for model in list_of_LLMs if \"textembedding\" not in model]\n",
"\n",
"for model_id in list_of_LLMs:\n",
" endpoint_name = _MODEL_CONFIG_[model_id][\"endpoint_name\"]\n",
" query_response = query_endpoint_with_json_payload(\n",
" json.dumps(payload).encode(\"utf-8\"), endpoint_name=endpoint_name\n",
" )\n",
" generated_texts = _MODEL_CONFIG_[model_id][\"parse_function\"](query_response)\n",
" print(f\"For model: {model_id}, the generated output is: {generated_texts[0]}\\n\")"
" predictor = _MODEL_CONFIG_[model_id][\"predictor\"]\n",
" response = predictor.predict({\"inputs\": question})\n",
" print(f\"For model: {model_id}, the generated output is:\\n\")\n",
" print(f\"{response[0]['generated_text']}\\n\")"
]
},
{
@@ -283,31 +205,13 @@
"metadata": {},
"outputs": [],
"source": [
"parameters = {\n",
" \"max_length\": 200,\n",
" \"num_return_sequences\": 1,\n",
" \"top_k\": 250,\n",
" \"top_p\": 0.95,\n",
" \"do_sample\": False,\n",
" \"temperature\": 1,\n",
"}\n",
"prompt = f\"Answer based on context:\\n\\n{context}\\n\\n{question}\"\n",
"\n",
"for model_id in list_of_LLMs:\n",
" endpoint_name = _MODEL_CONFIG_[model_id][\"endpoint_name\"]\n",
"\n",
" prompt = _MODEL_CONFIG_[model_id][\"prompt\"]\n",
"\n",
" text_input = prompt.replace(\"{context}\", context)\n",
" text_input = text_input.replace(\"{question}\", question)\n",
" payload = {\"text_inputs\": text_input, **parameters}\n",
"\n",
" query_response = query_endpoint_with_json_payload(\n",
" json.dumps(payload).encode(\"utf-8\"), endpoint_name=endpoint_name\n",
" )\n",
" generated_texts = _MODEL_CONFIG_[model_id][\"parse_function\"](query_response)\n",
" print(\n",
" f\"{bold}For model: {model_id}, the generated output is: {generated_texts[0]}{unbold}{newline}\"\n",
" )"
" predictor = _MODEL_CONFIG_[model_id][\"predictor\"]\n",
" response = predictor.predict({\"inputs\": prompt})\n",
" print(f\"For model: {model_id}, the generated output is:\\n\")\n",
" print(f\"{response[0]['generated_text']}\\n\")"
]
},
{
@@ -330,7 +234,7 @@
"\n",
"To achieve that, we will do following.\n",
"\n",
"1. **Generate embedings for each of document in the knowledge library with SageMaker GPT-J-6B embedding model.**\n",
"1. **Generate embedings for each of document in the knowledge library with SageMaker MiniLM-L6-v2 embedding model.**\n",
"2. **Identify top K most relevant documents based on user query.**\n",
" - 2.1 **For a query of your interest, generate the embedding of the query using the same embedding model.**\n",
" - 2.2 **Search for the indexes of the top K most relevant documents in the embedding space using an in-memory Faiss search.**\n",
Expand Down Expand Up @@ -365,6 +269,11 @@
"outputs": [],
"source": [
"from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler\n",
"from langchain.embeddings import SagemakerEndpointEmbeddings\n",
"from typing import List\n",
"import boto3\n",
"\n",
"aws_region = boto3.Session().region_name\n",
"\n",
"\n",
"class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):\n",
@@ -405,9 +314,12 @@
"\n",
"\n",
"content_handler = ContentHandler()\n",
"endpoint_name = _MODEL_CONFIG_[\"huggingface-textembedding-all-MiniLM-L6-v2\"][\n",
" \"predictor\"\n",
"].endpoint_name\n",
"\n",
"embeddings = SagemakerEndpointEmbeddingsJumpStart(\n",
" endpoint_name=_MODEL_CONFIG_[\"huggingface-textembedding-gpt-j-6b\"][\"endpoint_name\"],\n",
" endpoint_name=endpoint_name,\n",
" region_name=aws_region,\n",
" content_handler=content_handler,\n",
")"
@@ -428,33 +340,34 @@
"source": [
"from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint\n",
"\n",
"parameters = {\n",
" \"max_length\": 200,\n",
" \"num_return_sequences\": 1,\n",
" \"top_k\": 250,\n",
" \"top_p\": 0.95,\n",
" \"do_sample\": False,\n",
" \"temperature\": 1,\n",
"}\n",
"\n",
"\n",
"class ContentHandler(LLMContentHandler):\n",
" content_type = \"application/json\"\n",
" accepts = \"application/json\"\n",
"\n",
" def transform_input(self, prompt: str, model_kwargs={}) -> bytes:\n",
" input_str = json.dumps({\"text_inputs\": prompt, **model_kwargs})\n",
" input_str = json.dumps({\"inputs\": prompt, **model_kwargs})\n",
" return input_str.encode(\"utf-8\")\n",
"\n",
" def transform_output(self, output: bytes) -> str:\n",
" response_json = json.loads(output.read().decode(\"utf-8\"))\n",
" return response_json[\"generated_texts\"][0]\n",
" return response_json[0][\"generated_text\"]\n",
"\n",
"\n",
"content_handler = ContentHandler()\n",
"endpoint_name = _MODEL_CONFIG_[\"huggingface-text2text-flan-t5-xxl\"][\"predictor\"].endpoint_name\n",
"\n",
"parameters = {\n",
" \"max_length\": 200,\n",
" \"num_return_sequences\": 1,\n",
" \"top_k\": 250,\n",
" \"top_p\": 0.95,\n",
" \"do_sample\": False,\n",
" \"temperature\": 1,\n",
"}\n",
"\n",
"sm_llm = SagemakerEndpoint(\n",
" endpoint_name=_MODEL_CONFIG_[\"huggingface-text2text-flan-t5-xxl\"][\"endpoint_name\"],\n",
" endpoint_name=endpoint_name,\n",
" region_name=aws_region,\n",
" model_kwargs=parameters,\n",
" content_handler=content_handler,\n",
@@ -568,7 +481,8 @@
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain import PromptTemplate\n",
"from langchain.chains.question_answering import load_qa_chain\n",
"from langchain.document_loaders.csv_loader import CSVLoader"
"from langchain.document_loaders.csv_loader import CSVLoader\n",
"import json"
]
},
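These imports are typically followed by loading the knowledge library and splitting it into chunks before indexing. A minimal sketch; the CSV path and chunk size here are illustrative assumptions, not values from this diff:

```python
# Hedged sketch: load the document library from a CSV file and chunk it.
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.text_splitter import CharacterTextSplitter

loader = CSVLoader(file_path="rag_data/processed_data.csv")   # assumed path
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0)
docs = text_splitter.split_documents(documents)
print(f"{len(documents)} rows -> {len(docs)} chunks")
```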
{
@@ -670,7 +584,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Firstly, we **generate embedings for each of document in the knowledge library with SageMaker GPT-J-6B embedding model.**"
"Firstly, we **generate embedings for each of document in the knowledge library with SageMaker MiniLM-L6-v2 embedding model.**"
]
},
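The index itself is usually built by running every chunk through the embedding endpoint and storing the vectors in FAISS. A hedged sketch, assuming `docs` (the chunked documents) and `embeddings` (the SageMaker endpoint wrapper defined earlier) are in scope; the query string is only an example:

```python
# Hedged sketch: embed every chunk via the SageMaker endpoint, build an
# in-memory FAISS index, then retrieve the chunks most relevant to a query.
from langchain.vectorstores import FAISS

vectorstore = FAISS.from_documents(docs, embeddings)
relevant_docs = vectorstore.similarity_search(
    "Which instances can I use with managed spot training in SageMaker?", k=3
)
for d in relevant_docs:
    print(d.page_content[:120])
```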
{
@@ -1384,9 +1298,9 @@
],
"instance_type": "ml.t3.medium",
"kernelspec": {
"display_name": "Python 3 (Data Science 2.0)",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/sagemaker-data-science-38"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -1398,7 +1312,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.11.9"
}
},
"nbformat": 4,
