Skip to content
This repository has been archived by the owner on May 10, 2024. It is now read-only.

Commit

Permalink
add gorillaLM example
Browse files Browse the repository at this point in the history
  • Loading branch information
asaiacai committed Nov 20, 2023
1 parent e5216a9 commit 42d9a66
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 0 deletions.
28 changes: 28 additions & 0 deletions examples/gorilla_open_functions/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
## How to run the example

Install `llm-atc` and `openai` Python API

```bash
pip install llm-atc
pip install openai==0.28.1
```

Launch your self-hosted Gorilla Open Functions LLM.

```bash
# launch the LLM server
llm-atc serve --name gorilla-llm/gorilla-openfunctions-v1 -c testvllm --accelerator V100:1

# get the ip of the server
sky status --ip testvllm
```

Run the example which performs a function call for querying the weather. Edit the script to use the ip address of the server.

```bash
python test_openai_gorilla.py
```

## References

[Gorilla Open Function](https://gorilla.cs.berkeley.edu/blogs/4_open_functions.html)
75 changes: 75 additions & 0 deletions examples/gorilla_open_functions/test_openai_gorilla.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import json
import openai
import re

GPT_MODEL = "gorilla-openfunctions-v1"
openai.api_key = "EMPTY"
openai.api_base = "http://18.221.156.100:8000/v1"

def extract_function_arguments(response):
"""
Extracts the keyword arguments of a function call from a program given as a string.
:param response: The response from gorilla open functions LLM
:return: A list of dictionaries, each containing key-value pairs of arguments.
"""

# regex to match the name of the function call
pattern = r"([^\s(]+)\("
function_name = re.findall(pattern, response)[0]

# Define a regex pattern to match the function call
# This pattern matches the function name followed by an opening parenthesis,
# then captures everything until the matching closing parenthesis
pattern = rf"{re.escape(function_name)}\((.*?)\)"

# Find all matches in the program string
match = re.findall(pattern, response)[0]

arg_pattern = r"(\w+)\s*=\s*('[^']*'|\"[^\"]*\"|\w+)"
args = re.findall(arg_pattern, match)

# Convert the matches to a dictionary
arg_dict = {key.strip(): value.strip() for key, value in args}
return {
"name" : function_name,
"arguments" : arg_dict,
}

def get_gorilla_response(prompt="Call me an Uber ride type \"Plus\" in Berkeley at zipcode 94704 in 10 minutes", model="gorilla-openfunctions-v1", functions=[]):
def get_prompt(user_query, functions=[]):
if len(functions) == 0:
return f"USER: <<question>> {user_query}\nASSISTANT: "
functions_string = json.dumps(functions)
return f"USER: <<question>> {user_query} <<function>> {functions_string}\nASSISTANT: "
prompt = get_prompt(prompt, functions=functions)
try:
completion = openai.ChatCompletion.create(
model=model,
temperature=0.0,
messages=[{"role": "user", "content": prompt}],
)
return completion.choices[0].message.content
except Exception as e:
print(e, model, prompt)

functions = [
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
}
]

response = get_gorilla_response("What's the weather in Los Angeles in degrees Fahrenheit?", functions=functions)
print(extract_function_arguments(response))

0 comments on commit 42d9a66

Please sign in to comment.