Skip to content

Commit

Permalink
Update num_examples and lida dataframe
Browse files Browse the repository at this point in the history
  • Loading branch information
ultmaster committed Jan 11, 2024
1 parent 36b9c16 commit 9f4b4d4
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 9 deletions.
12 changes: 9 additions & 3 deletions coml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,21 +108,24 @@ class CoMLAgent:
tokens in the prompt.
num_examples: The number of examples to show in the prompt. It can be a
number between 0 and 1, interpreted as the percentage of examples to show.
It can also be an integer, interpreted as the number of examples to show.
message_style: Can be ``chatgpt`` in which system messages are shown, or
``gemini`` in which only human and ai messages are shown.
chain_of_thought: Whether to use chain of thought (COT) in the prompt.
context_order: The order of the context in the prompt. Defaults to ``vcr``.
``v`` for variable descriptions, ``c`` for codes, ``r`` for request.
ensemble: Perform ``ensemble`` number of LLM calls and ensemble the results.
ensemble_shuffle: Shuffle the examples in the prompt before ensemble.
example_ranking: A model that ranks the examples. If provided, the examples
will be ranked by the model before selecting the examples.
"""

def __init__(
self,
llm: BaseChatModel,
prompt_version: Literal["v1", "v2"] = "v2",
prompt_validation: Callable[[list[BaseMessage]], bool] | None = None,
num_examples: float = 1.0,
num_examples: float | int = 1.0,
message_style: Literal["chatgpt", "gemini"] = "chatgpt",
chain_of_thought: bool = False,
context_order: Literal[
Expand Down Expand Up @@ -262,8 +265,11 @@ def _select_examples(self, query: str, fewshots: list[_Type]) -> list[_Type]:
similarities.sort(key=lambda x: x[0])
fewshots = [shot for _, shot in similarities]

num_shots = max(int(len(fewshots) * self.num_examples), 1)
return fewshots[:num_shots]
if isinstance(self.num_examples, int):
return fewshots[:self.num_examples]
else:
num_shots = max(int(len(fewshots) * self.num_examples), 1)
return fewshots[:num_shots]

def generate_code(
self,
Expand Down
94 changes: 88 additions & 6 deletions coml/prompt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
import json
import re
import types
import warnings
from pathlib import Path
from typing import Any, TypedDict, cast
from typing import Any, TypedDict, Literal, cast
from typing_extensions import NotRequired

import pandas as pd


class GenerateContextIncomplete(TypedDict):
variables: dict[str, str]
Expand Down Expand Up @@ -39,6 +42,76 @@ class FixContext(TypedDict):
interactions: list[InteractionIncomplete | Interaction]


def lida_dataframe_describe(df: pd.DataFrame, n_samples: int) -> list[dict]:
    """Describe every column of a DataFrame in the format used by LIDA.

    Args:
        df: The DataFrame whose columns are summarized.
        n_samples: Upper bound on the number of distinct non-null values
            sampled from each column. The bound is applied per column, so a
            column with few distinct values does not reduce the number of
            samples taken from the columns that follow it.

    Returns:
        One ``{"column": name, "properties": {...}}`` dict per column, in
        column order. ``properties`` always holds ``dtype``, ``samples`` and
        ``num_unique_values``; number columns also get ``std``/``min``/``max``
        and date columns get ``min``/``max``.
    """

    def check_type(dtype, value):
        """Cast ``value`` to a builtin number so it is JSON serializable."""
        # NOTE: for integer columns this also truncates ``std`` to an int;
        # kept as-is to stay compatible with the LIDA output format.
        if "float" in str(dtype):
            return float(value)
        elif "int" in str(dtype):
            return int(value)
        else:
            return value

    properties_list = []
    for column in df.columns:
        series = df[column]
        dtype = series.dtype
        properties = {}
        if dtype in [int, float, complex]:
            properties["dtype"] = "number"
            properties["std"] = check_type(dtype, series.std())
            properties["min"] = check_type(dtype, series.min())
            properties["max"] = check_type(dtype, series.max())
        elif dtype == bool:
            properties["dtype"] = "boolean"
        elif dtype == object:
            # Check if the string column can be cast to a valid datetime.
            try:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    pd.to_datetime(series, errors="raise")
                properties["dtype"] = "date"
            except (ValueError, TypeError):
                # TypeError is raised for values that are not date-like at
                # all; treat them like unparsable strings instead of crashing.
                # Check if the string column has a limited number of values.
                if series.nunique() / len(series) < 0.5:
                    properties["dtype"] = "category"
                else:
                    properties["dtype"] = "string"
        elif isinstance(dtype, pd.CategoricalDtype):
            # isinstance check replaces the deprecated is_categorical_dtype.
            properties["dtype"] = "category"
        elif pd.api.types.is_datetime64_any_dtype(series):
            properties["dtype"] = "date"
        else:
            properties["dtype"] = str(dtype)

        # Add min/max if the column was classified as a date.
        if properties["dtype"] == "date":
            try:
                properties["min"] = series.min()
                properties["max"] = series.max()
            except TypeError:
                # Mixed values cannot be compared directly; compare them as
                # datetimes, with unparsable entries coerced to NaT.
                cast_date_col = pd.to_datetime(series, errors="coerce")
                properties["min"] = cast_date_col.min()
                properties["max"] = cast_date_col.max()

        # Add additional properties to the output dictionary.
        nunique = series.nunique()
        if "samples" not in properties:
            non_null_values = series[series.notnull()].unique()
            # Use a per-column count: overwriting ``n_samples`` here would
            # shrink the sample budget for every subsequent column.
            sample_count = min(n_samples, len(non_null_values))
            properties["samples"] = (
                pd.Series(non_null_values)
                .sample(sample_count, random_state=42)
                .tolist()
            )
        properties["num_unique_values"] = nunique
        properties_list.append({"column": column, "properties": properties})

    return properties_list


# Keyword arguments for rendering a DataFrame preview (column/row/width limits).
# Presumably the default for ``pandas_description_config`` in
# ``describe_variable`` — TODO confirm against the function body.
PANDAS_DESCRIPTION_CONFIG: Any = dict(max_cols=10, max_colwidth=20, max_rows=10)
# Presumably the default for ``maximum_list_items`` in ``describe_variable``
# — TODO confirm.
MAXIMUM_LIST_ITEMS = 30

Expand All @@ -47,6 +120,7 @@ def describe_variable(
value: Any,
pandas_description_config: Any | None = None,
maximum_list_items: int | None = None,
dataframe_format: Literal["coml", "lida"] = "coml",
) -> str:
import numpy
import pandas
Expand All @@ -59,11 +133,19 @@ def describe_variable(
if isinstance(value, numpy.ndarray):
return "numpy.ndarray(shape={}, dtype={})".format(value.shape, value.dtype)
elif isinstance(value, pandas.DataFrame):
return "pandas.DataFrame(shape={}, columns={})\n{}".format(
value.shape,
describe_variable(value.columns.tolist()),
add_indent(value.to_string(**pandas_description_config).rstrip()),
)
if dataframe_format == "coml":
return "pandas.DataFrame(shape={}, columns={})\n{}".format(
value.shape,
describe_variable(value.columns.tolist()),
add_indent(value.to_string(**pandas_description_config).rstrip()),
)
elif dataframe_format == "lida":
return "pandas.DataFrame(shape={}, columns={})".format(
value.shape,
lida_dataframe_describe(
value, n_samples=pandas_description_config.get("max_rows", 10)
),
)
elif isinstance(value, pandas.Series):
return "pandas.Series(shape={})".format(value.shape)
elif isinstance(value, list):
Expand Down

0 comments on commit 9f4b4d4

Please sign in to comment.