Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compatible with oobabooga text-generation-web-ui #1019

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
278 changes: 278 additions & 0 deletions garak/generators/oobabooga.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,278 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""REST API generator interface

Generic Module for REST API connections
"""


import json
import logging
from typing import List, Union
import requests

import backoff
import jsonpath_ng
from jsonpath_ng.exceptions import JsonPathParserError

from garak import _config
from garak.exception import APIKeyMissingError, RateLimitHit
from garak.generators.base import Generator


class RestGenerator(Generator):
"""Generic API interface for REST models

See reference docs for details (https://reference.garak.ai/en/latest/garak.generators.rest.html)
"""

DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
"headers": {},
"method": "post",
"ratelimit_codes": [429],
"response_json": False,
"response_json_field": None,
"req_template": "$INPUT",
"request_timeout": 50,
}

ENV_VAR = "REST_API_KEY"
generator_family_name = "REST"

_supported_params = (
"api_key",
"name",
"uri",
"key_env_var",
"req_template",
"req_template_json",
"context_len",
"max_tokens",
"method",
"headers",
"response_json",
"response_json_field",
"req_template_json_object",
"request_timeout",
"ratelimit_codes",
"temperature",
"top_k",
)

def __init__(self, uri=None, config_root=_config):
self.uri = uri
self.name = uri
self.seed = _config.run.seed
self.supports_multiple_generations = False # not implemented yet
self.escape_function = self._json_escape
self.retry_5xx = True
self.key_env_var = self.ENV_VAR if hasattr(self, "ENV_VAR") else None

# load configuration since super.__init__ has not been called
self._load_config(config_root)

if (
hasattr(self, "req_template_json_object")
and self.req_template_json_object is not None
):
self.req_template = json.dumps(self.req_template_json_object)

if self.response_json:
if self.response_json_field is None:
raise ValueError(
"RestGenerator response_json is True but response_json_field isn't set"
)
if not isinstance(self.response_json_field, str):
raise ValueError("response_json_field must be a string")
if self.response_json_field == "":
raise ValueError(
"RestGenerator response_json is True but response_json_field is an empty string. If the root object is the target object, use a JSONPath."
)

if self.name is None:
self.name = self.uri

if self.uri is None:
raise ValueError(
"No REST endpoint URI definition found in either constructor param, JSON, or --model_name. Please specify one."
)

self.fullname = f"{self.generator_family_name} {self.name}"

self.method = self.method.lower()
if self.method not in (
"get",
"post",
"put",
"patch",
"options",
"delete",
"head",
):
logging.info(
"RestGenerator HTTP method %s not supported, defaulting to 'post'",
self.method,
)
self.method = "post"
self.http_function = getattr(requests, self.method)

# validate jsonpath
if self.response_json and self.response_json_field:
try:
self.json_expr = jsonpath_ng.parse(self.response_json_field)
except JsonPathParserError as e:
logging.CRITICAL(
"Couldn't parse response_json_field %s", self.response_json_field
)
raise e

super().__init__(self.name, config_root=config_root)

def _get_nested_value(self, obj, path):
"""Helper method to get nested dictionary values using a path string."""
try:
current = obj
parts = path.replace("][", "].").replace("[", ".").replace("]", "").split(".")
for part in parts:
if part.isdigit():
current = current[int(part)]
else:
current = current[part]
return current
except (KeyError, IndexError, TypeError) as e:
logging.error(f"Failed to access path '{path}' in object: {e}")
return None

def _validate_env_var(self):
key_match = "$KEY"
header_requires_key = False
for _k, v in self.headers.items():
if key_match in v:
header_requires_key = True
if "$KEY" in self.req_template or header_requires_key:
return super()._validate_env_var()

def _json_escape(self, text: str) -> str:
"""JSON escape a string"""
# trim first & last "
return json.dumps(text)[1:-1]

def _populate_template(
self, template: str, text: str, json_escape_key: bool = False
) -> str:
"""Replace template placeholders with values

Interesting values are:
* $KEY - the API key set as an object variable
* $INPUT - the prompt text

$KEY is only set if the relevant environment variable is set; the
default variable name is REST_API_KEY but this can be overridden.
"""
output = template
if "$KEY" in template:
if self.api_key is None:
raise APIKeyMissingError(
f"Template requires an API key but {self.key_env_var} env var isn't set"
)
if json_escape_key:
output = output.replace("$KEY", self.escape_function(self.api_key))
else:
output = output.replace("$KEY", self.api_key)
return output.replace("$INPUT", self.escape_function(text))

# we'll overload IOError as the rate limit exception
@backoff.on_exception(backoff.fibo, RateLimitHit, max_value=70)
def _call_model(
self, prompt: str, generations_this_call: int = 1
) -> List[Union[str, None]]:
"""Individual call to get a rest from the REST API

:param prompt: the input to be placed into the request template and sent to the endpoint
:type prompt: str
"""

request_data = self._populate_template(self.req_template, prompt)

request_headers = dict(self.headers)
for k, v in self.headers.items():
request_headers[k] = self._populate_template(v, prompt)

# the prompt should not be sent via data when using a GET request. Prompt should be
# serialized as parameters, in general a method could be created to add
# the prompt data to a request via params or data based on the action verb
data_kw = "params" if self.http_function == requests.get else "data"
req_kArgs = {
data_kw: request_data,
"headers": request_headers,
"timeout": self.request_timeout,
}
resp = self.http_function(self.uri, **req_kArgs)
if resp.status_code in self.ratelimit_codes:
raise RateLimitHit(f"Rate limited: {resp.status_code} - {resp.reason}, uri: {self.uri}")

elif str(resp.status_code)[0] == "3":
raise NotImplementedError(
f"REST URI redirection: {resp.status_code} - {resp.reason}, uri: {self.uri}"
)

elif str(resp.status_code)[0] == "4":
raise ConnectionError(
f"REST URI client error: {resp.status_code} - {resp.reason}, uri: {self.uri}"
)

elif str(resp.status_code)[0] == "5":
error_msg = f"REST URI server error: {resp.status_code} - {resp.reason}, uri: {self.uri}"
if self.retry_5xx:
raise IOError(error_msg)
else:
raise ConnectionError(error_msg)

if not self.response_json:
return [str(resp.text)]

response_object = json.loads(resp.content)

# if response_json_field starts with a $, treat is as a JSONPath
assert (
self.response_json
), "response_json must be True at this point; if False, we should have returned already"
assert isinstance(
self.response_json_field, str
), "response_json_field must be a string"
assert (
len(self.response_json_field) > 0
), "response_json_field needs to be complete if response_json is true; ValueError should have been raised in constructor"
if self.response_json_field[0] != "$":
if isinstance(response_object, list):
response = [item[self.response_json_field] for item in response_object]
else:
value = self._get_nested_value(response_object, self.response_json_field)
if value is not None:
response = [value]
else:
logging.error(f"Could not find field '{self.response_json_field}' in response: {response_object}")
return [None]
else:
field_path_expr = jsonpath_ng.parse(self.response_json_field)
responses = field_path_expr.find(response_object)
if len(responses) == 1:
response_value = responses[0].value
if isinstance(response_value, str):
response = [response_value]
elif isinstance(response_value, list):
response = response_value
elif len(responses) > 1:
response = [r.value for r in responses]
else:
logging.error(
"RestGenerator JSONPath in response_json_field yielded nothing. Response content: %s"
% repr(resp.content)
)
return [None]

return response


DEFAULT_CLASS = "RestGenerator"
Loading