Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallel processing of experiment #31

Merged
merged 5 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Before running the script, ensure you have the following:
- Azure OpenAI
  - Need to set `OPENAI_API_KEY`, `AZURE_OPENAI_API_ENDPOINT` environment variables. You can also set the `AZURE_OPENAI_API_VERSION` variable. It is also recommended to set the `AZURE_OPENAI_MODEL_ID` environment variable to avoid having to pass in the `model_name` each time if you use the same model consistently.
- Gemini
- Need to set `GEMINI_PROJECT_ID`, and `GEMINI_LOCATION` environment variables. Also recommended to set the `GEMINI_MODEL_ID` in the environment variable to either avoid passing in the `model_name` each time if using the same one consistently.
  - Need to set `GEMINI_PROJECT_ID`, and `GEMINI_LOCATION` environment variables. It is also recommended to set the `GEMINI_MODEL_NAME` environment variable to avoid having to pass in the `model_name` each time if you use the same model consistently.

### Installation

Expand Down
64 changes: 57 additions & 7 deletions src/batch_llm/experiment_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,29 @@ def __init__(
self.output_folder, f"{self.creation_time}-log.txt"
)

# grouped experiment prompts by model
self.grouped_experiment_prompts: dict[str, list[dict]] = (
self.group_prompts_by_model()
)

def __str__(self) -> str:
    # an experiment is identified by the source file it was created from
    return self.file_name

def group_prompts_by_model(self) -> dict[str, list[dict]]:
    """
    Group the experiment prompts by their "model" key.

    Returns
    -------
    dict[str, list[dict]]
        Mapping from model name to the list of prompt dicts for that
        model, preserving the original prompt order within each group.
        Prompts without a "model" key are grouped under None.
    """
    # return the cached grouping if it was already computed (set in __init__)
    if hasattr(self, "grouped_experiment_prompts"):
        return self.grouped_experiment_prompts

    grouped_dict: dict[str, list[dict]] = {}
    for item in self.experiment_prompts:
        model = item.get("model")
        # BUG FIX: the original seeded a new group with [item] and then
        # appended item again, duplicating the first prompt of every model
        grouped_dict.setdefault(model, []).append(item)

    return grouped_dict


class ExperimentPipeline:
def __init__(
Expand Down Expand Up @@ -190,10 +210,34 @@ async def process_experiment(
self.log_estimate(experiment=experiment)

# run the experiment asynchronously
logging.info(f"Sending {experiment.number_queries} queries...")
await self.send_requests_retry(
experiment=experiment,
)
if self.settings.parallel:
logging.info(
f"Sending {experiment.number_queries} queries in parallel by grouping models..."
)
queries_per_model = {
model: len(prompts)
for model, prompts in experiment.grouped_experiment_prompts.items()
}
logging.info(f"Queries per model: {queries_per_model}")

tasks = [
asyncio.create_task(
self.send_requests_retry(
experiment=experiment, prompt_dicts=prompt_dicts, model=model
)
)
for model, prompt_dicts in experiment.grouped_experiment_prompts.items()
]
await tqdm_asyncio.gather(
*tasks, desc="Waiting for all models to complete", unit="model"
)
else:
logging.info(f"Sending {experiment.number_queries} queries...")
await self.send_requests_retry(
experiment=experiment,
prompt_dicts=experiment.experiment_prompts,
model=None,
)

# calculate average processing time per query for the experiment
end_time = time.time()
Expand Down Expand Up @@ -223,17 +267,19 @@ async def send_requests(
experiment: Experiment,
prompt_dicts: list[dict],
attempt: int,
model: str | None = None,
) -> tuple[list[dict], list[dict | Exception]]:
"""
Send requests to the API asynchronously.
"""
request_interval = 60 / self.settings.max_queries
tasks = []
for_model_string = f"for model {model} " if model is not None else ""

for index, item in enumerate(
tqdm(
prompt_dicts,
desc=f"Sending {len(prompt_dicts)} queries",
desc=f"Sending {len(prompt_dicts)} queries {for_model_string}",
unit="query",
)
):
Expand All @@ -254,14 +300,16 @@ async def send_requests(

# wait for all tasks to complete before returning
responses = await tqdm_asyncio.gather(
*tasks, desc="Waiting for responses", unit="query"
*tasks, desc=f"Waiting for responses {for_model_string}", unit="query"
)

return prompt_dicts, responses

async def send_requests_retry(
self,
experiment: Experiment,
prompt_dicts: list[dict],
model: str | None = None,
) -> None:
"""
Send requests to the API asynchronously and retry failed queries
Expand All @@ -273,8 +321,9 @@ async def send_requests_retry(
# send off the requests
remaining_prompt_dicts, responses = await self.send_requests(
experiment=experiment,
prompt_dicts=experiment.experiment_prompts,
prompt_dicts=prompt_dicts,
attempt=attempt,
model=model,
)

while True:
Expand All @@ -300,6 +349,7 @@ async def send_requests_retry(
experiment=experiment,
prompt_dicts=remaining_prompt_dicts,
attempt=attempt,
model=model,
)
else:
# if there are no failed queries, break out of the loop
Expand Down
2 changes: 1 addition & 1 deletion src/batch_llm/models/gemini/gemini_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def check_environment_variables() -> list[Exception]:
if var not in os.environ:
issues.append(ValueError(f"Environment variable {var} is not set"))

other_env_vars = ["GEMINI_MODEL_ID"]
other_env_vars = ["GEMINI_MODEL_NAME"]
for var in other_env_vars:
if var not in os.environ:
issues.append(Warning(f"Environment variable {var} is not set"))
Expand Down
17 changes: 12 additions & 5 deletions src/batch_llm/models/ollama/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,18 @@ async def _async_query_string(self, prompt_dict: dict, index: int | str) -> dict
prompt_dict["response"] = response_text
return prompt_dict
except ResponseError as err:
# if there's a response error due to a model not being downloaded,
# raise a NotImplementedError so that it doesn't get retried
raise NotImplementedError(
f"Model {model_name} is not downloaded: {type(err).__name__} - {err}"
)
if "try pulling it first" in str(err):
# if there's a response error due to a model not being downloaded,
# raise a NotImplementedError so that it doesn't get retried
raise NotImplementedError(
f"Model {model_name} is not downloaded: {type(err).__name__} - {err}"
)
elif "invalid options" in str(err):
# if there's a response error due to invalid options, raise a ValueError
# so that it doesn't get retried
raise ValueError(
f"Invalid options for model {model_name}: {type(err).__name__} - {err}"
)
except Exception as err:
error_as_string = f"{type(err).__name__} - {err}"
log_message = log_error_response_query(
Expand Down
8 changes: 8 additions & 0 deletions src/batch_llm/scripts/run_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ def main():
type=int,
default=5,
)
parser.add_argument(
"--parallel",
"-p",
help="Run the pipeline in parallel",
action="store_true",
default=False,
)
args = parser.parse_args()

# initialise logging
Expand All @@ -47,6 +54,7 @@ def main():
data_folder=args.data_folder,
max_queries=args.max_queries,
max_attempts=args.max_attempts,
parallel=args.parallel,
)
# log the settings that are set for the pipeline
logging.info(settings)
Expand Down
10 changes: 8 additions & 2 deletions src/batch_llm/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ class Settings:
# - max_attempts (maximum number of attempts when retrying)

def __init__(
self, data_folder: str = "data", max_queries: int = 10, max_attempts: int = 3
self,
data_folder: str = "data",
max_queries: int = 10,
max_attempts: int = 3,
parallel: bool = False,
):
self._data_folder = data_folder
# check the data folder exists
Expand All @@ -23,11 +27,13 @@ def __init__(
self.set_and_create_subfolders()
self._max_queries = max_queries
self._max_attempts = max_attempts
self.parallel = parallel

def __str__(self) -> str:
return (
f"Settings: data_folder={self.data_folder}, "
f"max_queries={self.max_queries}, max_attempts={self.max_attempts}\n"
f"max_queries={self.max_queries}, max_attempts={self.max_attempts}, "
f"parallel={self.parallel}\n"
f"Subfolders: input_folder={self.input_folder}, "
f"output_folder={self.output_folder}, media_folder={self.media_folder}"
)
Expand Down
19 changes: 17 additions & 2 deletions tests/core/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def test_settings_default_init(temporary_data_folders):
assert settings.data_folder == "data"
assert settings.max_queries == 10
assert settings.max_attempts == 3
assert settings.parallel is False

# check the subfolders
assert settings.input_folder == "data/input"
Expand All @@ -25,12 +26,15 @@ def test_settings_default_init(temporary_data_folders):


def test_settings_custom_init(temporary_data_folders):
settings = Settings(data_folder="dummy_data", max_queries=20, max_attempts=5)
settings = Settings(
data_folder="dummy_data", max_queries=20, max_attempts=5, parallel=True
)

# check the custom values
assert settings.data_folder == "dummy_data"
assert settings.max_queries == 20
assert settings.max_attempts == 5
assert settings.parallel is True

# check the subfolders
assert settings.input_folder == "dummy_data/input"
Expand All @@ -48,7 +52,7 @@ def test_settings_str(temporary_data_folders):

# when printing, it should show the settings and subfolders
assert str(settings) == (
"Settings: data_folder=data, max_queries=10, max_attempts=3\n"
"Settings: data_folder=data, max_queries=10, max_attempts=3, parallel=False\n"
"Subfolders: input_folder=data/input, output_folder=data/output, media_folder=data/media"
)

Expand Down Expand Up @@ -256,3 +260,14 @@ def test_max_attempts_setter(temporary_data_folders):
# set it to a different value
settings.max_attempts = 5
assert settings.max_attempts == 5


def test_parallel_getter_and_setter(temporary_data_folders):
    # a freshly constructed Settings defaults to sequential processing
    settings = Settings()
    assert settings.parallel is False

    # the flag is a plain attribute and can be flipped after construction
    settings.parallel = True
    assert settings.parallel is True
Loading