diff --git a/ice/agents/openai.py b/ice/agents/openai.py
index 38cb2220..7b56d822 100644
--- a/ice/agents/openai.py
+++ b/ice/agents/openai.py
@@ -153,10 +153,14 @@ def __init__(
         model: str = "gpt-3.5-turbo",
         temperature: float = 0.0,
         top_p: float = 1.0,
+        logprobs: bool = False,
+        top_logprobs: Optional[int] = None,
     ):
         self.model = model
         self.temperature = temperature
         self.top_p = top_p
+        self.logprobs = logprobs
+        self.top_logprobs = top_logprobs
 
     async def complete(
         self,
@@ -176,6 +180,20 @@ async def complete(
         self._print_markdown(completion)
         return completion
 
+    async def predict(
+        self, *, context: str, default: str = "", verbose: bool = False
+    ) -> dict[str, float]:
+        """Generate a probability distribution over the next token given some context."""
+        if verbose:
+            self._print_markdown(context)
+        response = await self._complete(
+            context, top_logprobs=5, logprobs=True, max_tokens=1
+        )
+        prediction = self._extract_prediction(response)
+        if verbose:
+            self._print_markdown(prediction)
+        return prediction
+
     async def classify(
         self,
         *,
@@ -184,9 +202,28 @@ async def classify(
         default: Optional[str] = None,
         verbose: bool = False,
     ) -> tuple[dict[str, float], Optional[str]]:
-        raise NotImplementedError(
-            "OpenAI ChatCompletion has no option to score a classification."
-        )
+        """Generate a classification from a list of choices given some context and a question."""
+        if verbose:
+            self._print_markdown(prompt)
+            self._print_markdown(choices)
+
+        choice_prefix = longest_common_prefix(choices).rstrip()
+        prompt_with_prefix = f"{prompt}{choice_prefix}"
+
+        if prompt_with_prefix.endswith(" "):
+            prompt_with_prefix = prompt_with_prefix[:-1]
+            default = " "
+        else:
+            default = ""
+
+        prediction = await self.predict(context=prompt_with_prefix, default=default)
+
+        rel_probs = self._compute_relative_probs(choices, choice_prefix, prediction)
+
+        if verbose:
+            self._print_markdown(rel_probs)
+
+        return rel_probs, None
 
     async def relevance(
         self,
@@ -200,13 +237,6 @@ async def relevance(
             "OpenAI ChatCompletion has no option to return a relevance score."
         )
 
-    async def predict(
-        self, *, context: str, default: str = "", verbose: bool = False
-    ) -> dict[str, float]:
-        raise NotImplementedError(
-            "OpenAI ChatCompletion does not support getting probabilities."
-        )
-
     async def _complete(self, prompt, **kwargs) -> dict:
         """Send a completion request to the OpenAI API with the given prompt and parameters."""
         kwargs.update(
@@ -215,9 +245,18 @@ async def _complete(self, prompt, **kwargs) -> dict:
                 "temperature": self.temperature,
                 "top_p": self.top_p,
                 "n": 1,
+                # Preserve per-call overrides (predict() passes logprobs=True):
+                "logprobs": kwargs.get("logprobs", self.logprobs),
+                "top_logprobs": kwargs.get("top_logprobs", self.top_logprobs),
             }
         )
-        messages = [{"role": "user", "content": prompt}]
+        messages = [
+            {
+                "role": "system",
+                "content": "You are a helpful assistant. Your answers follow instructions and remain grounded in the context.",
+            },
+            {"role": "user", "content": prompt},
+        ]
         response = await openai_chatcomplete(messages, **kwargs)
         if "choices" not in response:
             raise ValueError(f"No choices in response: {response}")
@@ -227,6 +266,38 @@ def _extract_completion(self, response: dict) -> str:
         """Extract the answer text from the completion response."""
         return response["choices"][0]["message"]["content"].strip()
 
+    def _extract_prediction(self, response: dict) -> dict[str, float]:
+        """Extract the prediction dictionary from the completion response."""
+        answer = response["choices"][0]["logprobs"]["content"][0]["top_logprobs"]
+        return {a["token"]: math.exp(a["logprob"]) for a in answer}
+
+    def _compute_relative_probs(
+        self, choices: tuple[str, ...], choice_prefix: str, prediction: dict[str, float]
+    ) -> dict[str, float]:
+        """Compute the relative probabilities of the choices based on the prediction."""
+
+        def lookup_prob(choice: str):
+            scores = 0.0
+            for token, prob in prediction.items():
+                if choice[len(choice_prefix) :].startswith(token):
+                    scores += prob
+            return scores
+
+        abs_probs = {choice: lookup_prob(choice) for choice in choices}
+        Z = sum(abs_probs.values())
+        if Z < 0.6:
+            log.warning(f"{1 - Z} of probability mass unaccounted for in classify")
+            log.warning(choice_prefix)
+            log.warning(str(prediction))
+            log.warning(str(abs_probs))
+
+        rel_probs = (
+            {choice: prob / Z for (choice, prob) in abs_probs.items()}
+            if Z != 0.0
+            else abs_probs
+        )
+        return rel_probs
+
     def _print_markdown(self, obj: Any):
         """Print the text with markdown formatting."""
         env().print(obj, format_markdown=True)
diff --git a/ice/apis/openai.py b/ice/apis/openai.py
index ac34ed92..5c23289e 100644
--- a/ice/apis/openai.py
+++ b/ice/apis/openai.py
@@ -176,8 +176,10 @@ async def openai_chatcomplete(
     stop: Optional[str] = "\n",
     top_p: float = 1,
     temperature: float = 0,
-    model: str = "gpt-3.5-turbo",
+    model: str = "gpt-3.5-turbo-16k",
     max_tokens: int = 256,
+    logprobs: bool = False,
+    top_logprobs: Optional[int] = None,
     logit_bias: Optional[Mapping[str, Union[int, float]]] = None,
     n: int = 1,
     cache_id: int = 0,  # for repeated non-deterministic sampling using caching
@@ -191,6 +193,8 @@
         "model": model,
         "max_tokens": max_tokens,
         "n": n,
+        "logprobs": logprobs,
+        "top_logprobs": top_logprobs,
     }
     if logit_bias:
         params["logit_bias"] = logit_bias  # type: ignore[assignment]
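For reviewers, here is a minimal, self-contained sketch of the scoring path that `classify` now takes: `_extract_prediction` converts the top logprobs of the single sampled token into probabilities via `math.exp`, and `_compute_relative_probs` credits each choice with every returned token that prefixes it, then renormalizes. The `top_logprobs` list below is hand-written stand-in data with illustrative values, not a real API response.

```python
import math

# Hand-written stand-in for a top_logprobs API payload (illustrative values).
top_logprobs = [
    {"token": "Yes", "logprob": -0.33},
    {"token": "No", "logprob": -1.71},
    {"token": "Y", "logprob": -3.00},
    {"token": "Maybe", "logprob": -3.51},
]
# Mirrors _extract_prediction: logprob -> probability.
prediction = {a["token"]: math.exp(a["logprob"]) for a in top_logprobs}

choices = ("Yes", "No")
choice_prefix = ""  # longest_common_prefix(choices) is empty for these choices


def lookup_prob(choice: str) -> float:
    # A choice absorbs the probability of every token that is a prefix of
    # its remaining text, so "Y" and "Yes" both count toward "Yes".
    return sum(
        prob
        for token, prob in prediction.items()
        if choice[len(choice_prefix):].startswith(token)
    )


abs_probs = {choice: lookup_prob(choice) for choice in choices}
Z = sum(abs_probs.values())
rel_probs = {choice: prob / Z for choice, prob in abs_probs.items()}
print(rel_probs)  # roughly {'Yes': 0.81, 'No': 0.19}
```

When `Z` falls below 0.6, more than 40% of the next-token probability mass landed on tokens matching no choice, which is exactly what the `log.warning` branch in `_compute_relative_probs` flags.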
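And a hypothetical end-to-end call. The agent class name is outside the hunks shown, so `OpenAIChatAgent` is an assumption here, and the prompt text is illustrative; `classify` takes keyword-only arguments per the new signature.

```python
import asyncio

from ice.agents.openai import OpenAIChatAgent  # class name assumed; not shown in the hunks


async def main() -> None:
    agent = OpenAIChatAgent(model="gpt-3.5-turbo-16k")
    # The shared leading space in the choices is handled by
    # longest_common_prefix(...).rstrip() inside classify.
    probs, _ = await agent.classify(
        prompt="Q: Is the Pacific the largest ocean?\nA:",
        choices=(" Yes", " No"),
    )
    print(probs)  # e.g. {' Yes': 0.97, ' No': 0.03}


asyncio.run(main())
```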