Skip to content

Commit

Permalink
Schema changes and improvements (#24)
Browse files Browse the repository at this point in the history
* add language info in session results

* rename Evaluation to Latest Results

* Alibaba: better error handling and update model codes

* make it easier to debug

* update deps
  • Loading branch information
semio authored Sep 11, 2023
1 parent 504e332 commit e9da232
Show file tree
Hide file tree
Showing 9 changed files with 1,059 additions and 987 deletions.
2 changes: 2 additions & 0 deletions automation-api/lib/ai_eval_spreadsheet/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class Config:

class EvalResult(BaseModel):
question_id: Optional[str] = Field(None, title="Question ID")
language: Optional[str] = Field(None, title="Language")
prompt_variation_id: Optional[str] = Field(None, title="Prompt variation ID")
model_configuration_id: Optional[str] = Field(None, title="Model Configuration ID")
last_evaluation_datetime: Optional[datetime] = Field(None, title="Last Evaluation")
Expand All @@ -133,6 +134,7 @@ class SessionResult(BaseModel):
model_configuration_id: Optional[str] = Field(None, title="Model Configuration ID")
survey_id: Optional[str] = Field(None, title="Survey ID")
question_id: Optional[str] = Field(None, title="Question ID")
language: Optional[str] = Field(None, title="Language")
question_number: Optional[int] = Field(None, title="Question No.")
output: Optional[str] = Field(None, title="Response Text")
grade: Optional[str] = Field(None, title="Grade")
Expand Down
2 changes: 1 addition & 1 deletion automation-api/lib/ai_eval_spreadsheet/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class AiEvalData:
"prompt_variations": "Prompt variations",
"gen_ai_models": "Models",
"gen_ai_model_configs": "Model configurations",
"evaluation_results": "Evaluations",
"evaluation_results": "Latest Results",
"session_results": "Sessions",
}

Expand Down
16 changes: 11 additions & 5 deletions automation-api/lib/llms/alibaba.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,16 @@ def response_is_ok(response):
return False


def return_last_message(retry_state):
    """tenacity ``retry_error_callback``: convert the final failed attempt
    into a dashscope-``Generation``-shaped error payload instead of letting
    tenacity raise ``RetryError``.

    Returns a dict mimicking the API response structure so callers can read
    ``result["output"]["text"]`` uniformly on both success and failure.
    """
    outcome = retry_state.outcome
    # The retry policy includes retry_if_exception_type(), so the last
    # attempt may have raised; in that case outcome.result() would re-raise
    # and defeat this fallback. Report the exception text instead.
    if outcome.failed:
        return {"output": {"text": f"Error: {outcome.exception()}"}}
    # Non-OK API response (rejected by response_is_ok): surface its
    # status code and message.
    last_val = outcome.result()
    return {"output": {"text": f"Error: {last_val.code}: {last_val.message}"}}


@retry(
    retry=(retry_if_exception_type() | retry_if_not_result(response_is_ok)),
    stop=stop_after_attempt(3),
    retry_error_callback=return_last_message,
)
def get_reply(**kwargs):
    """Call dashscope's Generation API, retrying up to 3 times.

    Retries on any exception or on a response rejected by
    ``response_is_ok``; after the attempts are exhausted,
    ``return_last_message`` produces the fallback payload.
    """
    response = Generation.call(**kwargs)
    return response
Expand Down Expand Up @@ -74,15 +81,15 @@ def validate_environment(cls, values: Dict) -> Dict: # noqa: N805
def _call(
self,
prompt: str,
history: Optional[List[Dict]] = None,
messages: Optional[List[Dict]] = None,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
if stop is not None:
raise ValueError("stop kwargs are not permitted.")

if history is None:
history = []
if messages is None:
messages = []

if self.seed is None:
# FIXME: Alibaba's API support uint64
Expand All @@ -99,13 +106,12 @@ def _call(
)(
model=self.model_name,
prompt=prompt,
history=history,
messages=messages,
top_p=self.top_p,
top_k=self.top_k,
seed=seed,
enable_search=self.enable_search,
)

return result["output"]["text"]

@property
Expand Down
5 changes: 4 additions & 1 deletion automation-api/lib/llms/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
stop_after_attempt,
)

from lib.app_singleton import app_logger as logger
from lib.config import read_config
from lib.llms.iflytek import SparkClient

Expand Down Expand Up @@ -88,7 +89,9 @@ def _call(
) -> str:
if stop is not None:
raise ValueError("stop kwargs are not permitted.")
return self.generate_text_with_retry(prompt)
output = self.generate_text_with_retry(prompt)
logger.debug(f"Spark: {output}")
return output

@property
def _identifying_params(self) -> Mapping[str, Any]:
Expand Down
4 changes: 2 additions & 2 deletions automation-api/lib/llms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@ def get_iflytek_model(**kwargs: Any) -> Spark:
)


def get_alibaba_model(**kwargs: Any) -> Alibaba:
def get_alibaba_model(model_name, **kwargs: Any) -> Alibaba:
    """Build an :class:`Alibaba` LLM wrapper for the given model name.

    The DashScope API key is taken from the project configuration; any
    extra keyword arguments are forwarded to the ``Alibaba`` constructor.
    """
    settings: Dict[str, str] = read_config()
    return Alibaba(
        model_name=model_name,
        dashscope_api_key=settings["DASHSCOPE_API_KEY"],
        **kwargs,
    )


def run_model(
Expand Down
14 changes: 10 additions & 4 deletions automation-api/lib/pilot/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ def run_evaluation(
logger.warning(
f"({model_config_id}, {prompt_var_id}, {survey_id}) has been evaluated."
)

session_df = pd.DataFrame.from_records(session_result)
session_df = SessionResultsDf.validate(session_df)
# write result to tmp file.
session_df.to_csv(out_file_path, index=False)
logger.info(f"session saved to {out_file_path}")

return session_df

Expand Down Expand Up @@ -143,7 +143,10 @@ def run_evaluation(
survey_id = get_survey_hash(questions)
survey = (survey_id, questions)

eval_llm = get_model("gpt-3.5-turbo", "OpenAI", {"temperature": 0})
# FIXME: add support to set eval llm and parameters.
eval_llm = get_model(
"gpt-3.5-turbo", "OpenAI", {"temperature": 0, "request_timeout": 120}
)

search_space = list(product(model_configs, prompt_variants))

Expand All @@ -156,8 +159,11 @@ def run_evaluation(
out_dir=args.tmp_dir,
)

with Pool(args.jobs) as p:
session_dfs = p.map(threaded_func, search_space)
if args.jobs == 1:
session_dfs = [threaded_func(v) for v in search_space]
else:
with Pool(args.jobs) as p:
session_dfs = p.map(threaded_func, search_space)

try:
session_df = pd.concat(session_dfs)
Expand Down
3 changes: 2 additions & 1 deletion automation-api/lib/pilot/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def get_model(model_id, vendor, model_conf):
elif vendor == "iFlyTek":
return get_iflytek_model(**model_conf)
elif vendor == "Alibaba":
return get_alibaba_model(**model_conf)
return get_alibaba_model(model_id, **model_conf)
else:
raise NotImplementedError(f"{model_id} from {vendor} is not supported yet.")

Expand Down Expand Up @@ -301,6 +301,7 @@ def run_survey(
"model_configuration_id": model_config_id,
"prompt_variation_id": prompt_id,
"question_id": question[0].question_id,
"language": question[0].language,
"question_number": i + 1,
}
question_data = create_question_data_for_test(question_tmpl, question)
Expand Down
Loading

0 comments on commit e9da232

Please sign in to comment.