diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py
index f6c052ce..7e777997 100644
--- a/src/instructlab/sdg/generate_data.py
+++ b/src/instructlab/sdg/generate_data.py
@@ -22,8 +22,10 @@
     num_chars_from_tokens,
 )
 from jinja2 import Template
+from openai import OpenAI
 from rouge_score import rouge_scorer
 import click
+import httpx
 
 import instructlab.utils
 import tqdm
@@ -266,18 +268,13 @@ def get_instructions_from_model(
     request_idx,
     instruction_data_pool,
     prompt_template,
-    api_base,
-    api_key,
+    client,
     model_name,
     num_prompt_instructions,
     request_batch_size,
     temperature,
     top_p,
     output_file_discarded,
-    tls_insecure,
-    tls_client_cert,
-    tls_client_key,
-    tls_client_passwd,
 ):
     batch_inputs = []
     for _ in range(request_batch_size):
@@ -306,14 +303,9 @@
     request_start = time.time()
     try:
         results = utils.openai_completion(
-            api_base=api_base,
-            api_key=api_key,
+            client,
             prompts=batch_inputs,
             model_name=model_name,
-            tls_insecure=tls_insecure,
-            tls_client_cert=tls_client_cert,
-            tls_client_key=tls_client_key,
-            tls_client_passwd=tls_client_passwd,
             batch_size=request_batch_size,
             decoding_args=decoding_args,
         )
@@ -473,6 +465,15 @@
     tls_client_key: Optional[str] = None,
     tls_client_passwd: Optional[str] = None,
 ):
+    cert = tuple(
+        item for item in (tls_client_cert, tls_client_key, tls_client_passwd) if item
+    )
+    client = OpenAI(
+        base_url=api_base,
+        api_key=api_key,
+        http_client=httpx.Client(cert=cert, verify=not tls_insecure),
+    )
+
     seed_instruction_data = []
     machine_seed_instruction_data = []
     generate_start = time.time()
@@ -566,18 +567,13 @@
             request_idx,
             instruction_data_pool,
             prompt_template,
-            api_base,
-            api_key,
+            client,
             model_name,
             num_prompt_instructions,
             request_batch_size,
             temperature,
             top_p,
             output_file_discarded,
-            tls_insecure,
-            tls_client_cert,
-            tls_client_key,
-            tls_client_passwd,
         )
         total_discarded += discarded
         total = len(instruction_data)
diff --git a/src/instructlab/sdg/utils.py b/src/instructlab/sdg/utils.py
index 83423fe8..70968a77 100644
--- a/src/instructlab/sdg/utils.py
+++ b/src/instructlab/sdg/utils.py
@@ -13,10 +13,9 @@
 
 # Third Party
 # instructlab - TODO these need to go away, issue #6
-from instructlab.configuration import DEFAULT_API_KEY, DEFAULT_MODEL_OLD
+from instructlab.configuration import DEFAULT_MODEL_OLD
 from instructlab.utils import get_sysprompt
-from openai import OpenAI, OpenAIError
-import httpx
+from openai import OpenAIError
 
 StrOrOpenAIObject = Union[str, object]
@@ -40,11 +39,7 @@
 
 def openai_completion(
-    api_base,
-    tls_insecure,
-    tls_client_cert,
-    tls_client_key,
-    tls_client_passwd,
+    client,
     prompts: Union[str, Sequence[str], Sequence[dict[str, str]], dict[str, str]],
     decoding_args: OpenAIDecodingArguments,
     model_name="ggml-merlinite-7b-lab-Q4_K_M",
@@ -52,7 +47,6 @@
     batch_size=1,
     max_instances=sys.maxsize,
     max_batches=sys.maxsize,
     return_text=False,
-    api_key=DEFAULT_API_KEY,
     **decoding_kwargs,
 ) -> Union[
     Union[StrOrOpenAIObject],
     Sequence[StrOrOpenAIObject],
     Sequence[Sequence[StrOrOpenAIObject]],
 ]:
@@ -62,11 +56,6 @@
     """Decode with OpenAI API.
 
     Args:
-        api_base: Endpoint URL where model is hosted
-        tls_insecure: Disable TLS verification
-        tls_client_cert: Path to the TLS client certificate to use
-        tls_client_key: Path to the TLS client key to use
-        tls_client_passwd: TLS client certificate password
         prompts: A string or a list of strings to complete. If it is a chat model
             the strings should be formatted as explained here:
             https://github.com/openai/openai-python/blob/main/chatml.md.
@@ -78,7 +67,6 @@
         max_instances: Maximum number of prompts to decode.
         max_batches: Maximum number of batches to decode. This will be deprecated in the future.
         return_text: If True, return text instead of full completion object (e.g. includes logprob).
-        api_key: API key API key for API endpoint where model is hosted
         decoding_kwargs: Extra decoding arguments.
             Pass in `best_of` and `logit_bias` if needed.
     Returns:
@@ -116,22 +104,8 @@ def openai_completion(
         **decoding_kwargs,
     }
 
-    if not api_key:
-        # we need to explicitly set non-empty api-key, to ensure generate
-        # connects to our local server
-        api_key = "no_api_key"
-
     # do not pass a lower timeout to this client since generating a dataset takes some time
     # pylint: disable=R0801
-    orig_cert = (tls_client_cert, tls_client_key, tls_client_passwd)
-    cert = tuple(item for item in orig_cert if item)
-    verify = not tls_insecure
-    client = OpenAI(
-        base_url=api_base,
-        api_key=api_key,
-        http_client=httpx.Client(cert=cert, verify=verify),
-    )
-
     # ensure the model specified exists on the server. with backends like vllm, this is crucial.
     model_list = client.models.list().data
     model_ids = []