Merge pull request #31 from alan-turing-institute/parallel-processing
Parallel processing of experiment
rchan26 authored Apr 26, 2024
2 parents c468d09 + 8036aad commit b11a33e
Showing 7 changed files with 104 additions and 18 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -31,7 +31,7 @@ Before running the script, ensure you have the following:
- Azure OpenAI
  - Need to set the `OPENAI_API_KEY` and `AZURE_OPENAI_API_ENDPOINT` environment variables. You can also set the `AZURE_OPENAI_API_VERSION` variable. It is also recommended to set `AZURE_OPENAI_MODEL_ID` as an environment variable to avoid passing in `model_name` each time if you use the same model consistently.
- Gemini
-  - Need to set the `GEMINI_PROJECT_ID` and `GEMINI_LOCATION` environment variables. It is also recommended to set `GEMINI_MODEL_ID` as an environment variable to avoid passing in `model_name` each time if you use the same model consistently.
+  - Need to set the `GEMINI_PROJECT_ID` and `GEMINI_LOCATION` environment variables. It is also recommended to set `GEMINI_MODEL_NAME` as an environment variable to avoid passing in `model_name` each time if you use the same model consistently.

### Installation

64 changes: 57 additions & 7 deletions src/batch_llm/experiment_processing.py
@@ -76,9 +76,29 @@ def __init__(
            self.output_folder, f"{self.creation_time}-log.txt"
        )

+        # grouped experiment prompts by model
+        self.grouped_experiment_prompts: dict[str, list[dict]] = (
+            self.group_prompts_by_model()
+        )

    def __str__(self) -> str:
        return self.file_name

+    def group_prompts_by_model(self) -> dict[str, list[dict]]:
+        # return self.grouped_experiment_prompts if it exists
+        if hasattr(self, "grouped_experiment_prompts"):
+            return self.grouped_experiment_prompts
+
+        grouped_dict = {}
+        for item in self.experiment_prompts:
+            model = item.get("model")
+            if model not in grouped_dict:
+                grouped_dict[model] = []
+
+            grouped_dict[model].append(item)
+
+        return grouped_dict
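
The method above builds one queue of prompt dicts per "model" key. A minimal standalone sketch of the same grouping behaviour, using dict.setdefault; the prompt dicts here are hypothetical, not from the repository:

prompts = [
    {"id": 0, "model": "gemini", "prompt": "What is 2 + 2?"},
    {"id": 1, "model": "azure-openai", "prompt": "Name a prime number."},
    {"id": 2, "model": "gemini", "prompt": "Define entropy."},
]

grouped: dict[str, list[dict]] = {}
for item in prompts:
    # setdefault creates the empty list the first time a model is seen
    grouped.setdefault(item.get("model"), []).append(item)

assert [p["id"] for p in grouped["gemini"]] == [0, 2]
assert [p["id"] for p in grouped["azure-openai"]] == [1]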


class ExperimentPipeline:
def __init__(
@@ -190,10 +210,34 @@ async def process_experiment(
        self.log_estimate(experiment=experiment)

        # run the experiment asynchronously
-        logging.info(f"Sending {experiment.number_queries} queries...")
-        await self.send_requests_retry(
-            experiment=experiment,
-        )
+        if self.settings.parallel:
+            logging.info(
+                f"Sending {experiment.number_queries} queries in parallel by grouping models..."
+            )
+            queries_per_model = {
+                model: len(prompts)
+                for model, prompts in experiment.grouped_experiment_prompts.items()
+            }
+            logging.info(f"Queries per model: {queries_per_model}")
+
+            tasks = [
+                asyncio.create_task(
+                    self.send_requests_retry(
+                        experiment=experiment, prompt_dicts=prompt_dicts, model=model
+                    )
+                )
+                for model, prompt_dicts in experiment.grouped_experiment_prompts.items()
+            ]
+            await tqdm_asyncio.gather(
+                *tasks, desc="Waiting for all models to complete", unit="model"
+            )
+        else:
+            logging.info(f"Sending {experiment.number_queries} queries...")
+            await self.send_requests_retry(
+                experiment=experiment,
+                prompt_dicts=experiment.experiment_prompts,
+                model=None,
+            )

        # calculate average processing time per query for the experiment
        end_time = time.time()
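
The parallel branch creates one asyncio task per model and waits on all of them together, so each model's queue proceeds independently. A self-contained sketch of that fan-out pattern, with an illustrative stand-in worker in place of send_requests_retry:

import asyncio

async def process_queue(model: str, prompts: list[dict]) -> None:
    # stand-in for send_requests_retry: work through one model's queue in order
    for prompt_dict in prompts:
        await asyncio.sleep(0.01)  # placeholder for an API call

async def main() -> None:
    grouped = {
        "gemini": [{"prompt": "a"}, {"prompt": "b"}],
        "azure-openai": [{"prompt": "c"}],
    }
    # one task per model, so queues for different models run concurrently
    tasks = [
        asyncio.create_task(process_queue(model, prompts))
        for model, prompts in grouped.items()
    ]
    await asyncio.gather(*tasks)

asyncio.run(main())
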
@@ -223,17 +267,19 @@ async def send_requests(
        experiment: Experiment,
        prompt_dicts: list[dict],
        attempt: int,
+        model: str | None = None,
    ) -> tuple[list[dict], list[dict | Exception]]:
        """
        Send requests to the API asynchronously.
        """
        request_interval = 60 / self.settings.max_queries
        tasks = []
+        for_model_string = f"for model {model} " if model is not None else ""

        for index, item in enumerate(
            tqdm(
                prompt_dicts,
-                desc=f"Sending {len(prompt_dicts)} queries",
+                desc=f"Sending {len(prompt_dicts)} queries {for_model_string}",
                unit="query",
            )
        ):
@@ -254,14 +300,16 @@ async def send_requests(

        # wait for all tasks to complete before returning
        responses = await tqdm_asyncio.gather(
-            *tasks, desc="Waiting for responses", unit="query"
+            *tasks, desc=f"Waiting for responses {for_model_string}", unit="query"
        )

        return prompt_dicts, responses
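
The pacing comes from request_interval = 60 / max_queries: the loop creates a task, then sleeps for the interval before creating the next, so at most max_queries requests start per minute while responses are awaited concurrently. A sketch of that throttling pattern under this assumption; fake_query stands in for a real API call:

import asyncio

async def fake_query(item: str) -> str:
    # stand-in for a real API call
    await asyncio.sleep(0.5)
    return f"response to {item}"

async def send_staggered(items: list[str], max_queries: int) -> list[str]:
    # space out task creation: at most max_queries tasks start per 60 seconds
    request_interval = 60 / max_queries
    tasks = []
    for item in items:
        tasks.append(asyncio.create_task(fake_query(item)))
        await asyncio.sleep(request_interval)
    # responses overlap; only task creation is staggered
    return await asyncio.gather(*tasks)

print(asyncio.run(send_staggered(["q1", "q2", "q3"], max_queries=60)))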

    async def send_requests_retry(
        self,
        experiment: Experiment,
+        prompt_dicts: list[dict],
+        model: str | None = None,
    ) -> None:
        """
        Send requests to the API asynchronously and retry failed queries
@@ -273,8 +321,9 @@ async def send_requests_retry(
        # send off the requests
        remaining_prompt_dicts, responses = await self.send_requests(
            experiment=experiment,
-            prompt_dicts=experiment.experiment_prompts,
+            prompt_dicts=prompt_dicts,
            attempt=attempt,
+            model=model,
        )

        while True:
@@ -300,6 +349,7 @@ async def send_requests_retry(
                    experiment=experiment,
                    prompt_dicts=remaining_prompt_dicts,
                    attempt=attempt,
+                    model=model,
                )
            else:
                # if there are no failed queries, break out of the loop
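
Across these hunks, send_requests_retry resends only the failed prompt dicts until none remain or the attempt limit is reached; the new model argument is simply threaded through to send_requests for the per-model progress bars. A simplified sketch of that retry loop; the failure simulation and bookkeeping are illustrative, not the repository's exact logic:

import asyncio
import random

async def flaky_query(item: str) -> str:
    # stand-in for an API call that sometimes fails
    if random.random() < 0.3:
        raise RuntimeError("transient error")
    return f"response to {item}"

async def send_with_retry(items: list[str], max_attempts: int = 3) -> dict[str, str]:
    results: dict[str, str] = {}
    remaining = items
    for attempt in range(1, max_attempts + 1):
        responses = await asyncio.gather(
            *(flaky_query(item) for item in remaining), return_exceptions=True
        )
        # keep only the failures for the next attempt
        failed = []
        for item, response in zip(remaining, responses):
            if isinstance(response, Exception):
                failed.append(item)
            else:
                results[item] = response
        if not failed:
            break
        remaining = failed
    return results

print(asyncio.run(send_with_retry(["q1", "q2", "q3"])))
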
2 changes: 1 addition & 1 deletion src/batch_llm/models/gemini/gemini_utils.py
@@ -98,7 +98,7 @@ def check_environment_variables() -> list[Exception]:
        if var not in os.environ:
            issues.append(ValueError(f"Environment variable {var} is not set"))

-    other_env_vars = ["GEMINI_MODEL_ID"]
+    other_env_vars = ["GEMINI_MODEL_NAME"]
    for var in other_env_vars:
        if var not in os.environ:
            issues.append(Warning(f"Environment variable {var} is not set"))
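
The rename keeps GEMINI_MODEL_NAME as a warning rather than an error when unset, since the variable only exists so callers can omit model_name. A sketch of how such a fallback is typically resolved; this helper is illustrative, not the repository's code:

import os

def resolve_model_name(model_name: str | None = None) -> str:
    # prefer an explicitly passed name, then fall back to the environment
    resolved = model_name or os.environ.get("GEMINI_MODEL_NAME")
    if resolved is None:
        raise ValueError("model_name was not passed and GEMINI_MODEL_NAME is not set")
    return resolved
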
17 changes: 12 additions & 5 deletions src/batch_llm/models/ollama/ollama.py
@@ -137,11 +137,18 @@ async def _async_query_string(self, prompt_dict: dict, index: int | str) -> dict
            prompt_dict["response"] = response_text
            return prompt_dict
        except ResponseError as err:
-            # if there's a response error due to a model not being downloaded,
-            # raise a NotImplementedError so that it doesn't get retried
-            raise NotImplementedError(
-                f"Model {model_name} is not downloaded: {type(err).__name__} - {err}"
-            )
+            if "try pulling it first" in str(err):
+                # if there's a response error due to a model not being downloaded,
+                # raise a NotImplementedError so that it doesn't get retried
+                raise NotImplementedError(
+                    f"Model {model_name} is not downloaded: {type(err).__name__} - {err}"
+                )
+            elif "invalid options" in str(err):
+                # if there's a response error due to invalid options, raise a ValueError
+                # so that it doesn't get retried
+                raise ValueError(
+                    f"Invalid options for model {model_name}: {type(err).__name__} - {err}"
+                )
        except Exception as err:
            error_as_string = f"{type(err).__name__} - {err}"
            log_message = log_error_response_query(
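
The new branch inspects the ResponseError message and re-raises non-retryable failures as exception types the retry loop will not resend. A minimal sketch of that dispatch, with ResponseError stubbed locally for illustration:

class ResponseError(Exception):
    # local stand-in for ollama's ResponseError
    pass

def reraise_if_not_retryable(err: ResponseError, model_name: str) -> None:
    if "try pulling it first" in str(err):
        # the model was never pulled, so retrying cannot succeed
        raise NotImplementedError(f"Model {model_name} is not downloaded: {err}")
    if "invalid options" in str(err):
        # malformed generation options, so retrying cannot succeed
        raise ValueError(f"Invalid options for model {model_name}: {err}")
    # any other ResponseError is left to the caller's retry logic
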
8 changes: 8 additions & 0 deletions src/batch_llm/scripts/run_pipeline.py
@@ -33,6 +33,13 @@ def main():
        type=int,
        default=5,
    )
+    parser.add_argument(
+        "--parallel",
+        "-p",
+        help="Run the pipeline in parallel",
+        action="store_true",
+        default=False,
+    )
    args = parser.parse_args()

    # initialise logging
@@ -47,6 +54,7 @@ def main():
        data_folder=args.data_folder,
        max_queries=args.max_queries,
        max_attempts=args.max_attempts,
+        parallel=args.parallel,
    )
    # log the settings that are set for the pipeline
    logging.info(settings)
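
With action="store_true", the flag needs no value and is False unless passed (the explicit default=False is redundant but harmless). A quick self-contained check of those semantics:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--parallel", "-p", action="store_true", default=False)

print(parser.parse_args([]).parallel)              # False
print(parser.parse_args(["--parallel"]).parallel)  # True
print(parser.parse_args(["-p"]).parallel)          # True
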
10 changes: 8 additions & 2 deletions src/batch_llm/settings.py
@@ -14,7 +14,11 @@ class Settings:
    # - max_attempts (maximum number of attempts when retrying)

    def __init__(
-        self, data_folder: str = "data", max_queries: int = 10, max_attempts: int = 3
+        self,
+        data_folder: str = "data",
+        max_queries: int = 10,
+        max_attempts: int = 3,
+        parallel: bool = False,
    ):
        self._data_folder = data_folder
        # check the data folder exists
@@ -23,11 +27,13 @@ def __init__(
        self.set_and_create_subfolders()
        self._max_queries = max_queries
        self._max_attempts = max_attempts
+        self.parallel = parallel

    def __str__(self) -> str:
        return (
            f"Settings: data_folder={self.data_folder}, "
-            f"max_queries={self.max_queries}, max_attempts={self.max_attempts}\n"
+            f"max_queries={self.max_queries}, max_attempts={self.max_attempts}, "
+            f"parallel={self.parallel}\n"
            f"Subfolders: input_folder={self.input_folder}, "
            f"output_folder={self.output_folder}, media_folder={self.media_folder}"
        )
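
Unlike the underscore-backed fields with property getters, parallel is stored as a plain public attribute, so it can be read and reassigned directly. A usage sketch, assuming the package is importable as batch_llm and the data folder exists (Settings checks for it on construction):

from batch_llm.settings import Settings

settings = Settings(data_folder="data", max_queries=10, max_attempts=3, parallel=True)
print(settings.parallel)  # True
settings.parallel = False  # plain attribute, so no property setter is needed
print(settings)
# Settings: data_folder=data, max_queries=10, max_attempts=3, parallel=False
# Subfolders: input_folder=data/input, output_folder=data/output, media_folder=data/media
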
19 changes: 17 additions & 2 deletions tests/core/test_settings.py
@@ -12,6 +12,7 @@ def test_settings_default_init(temporary_data_folders):
    assert settings.data_folder == "data"
    assert settings.max_queries == 10
    assert settings.max_attempts == 3
+    assert settings.parallel is False

    # check the subfolders
    assert settings.input_folder == "data/input"
@@ -25,12 +26,15 @@


def test_settings_custom_init(temporary_data_folders):
-    settings = Settings(data_folder="dummy_data", max_queries=20, max_attempts=5)
+    settings = Settings(
+        data_folder="dummy_data", max_queries=20, max_attempts=5, parallel=True
+    )

    # check the custom values
    assert settings.data_folder == "dummy_data"
    assert settings.max_queries == 20
    assert settings.max_attempts == 5
+    assert settings.parallel is True

    # check the subfolders
    assert settings.input_folder == "dummy_data/input"
@@ -48,7 +52,7 @@ def test_settings_str(temporary_data_folders):

    # when printing, it should show the settings and subfolders
    assert str(settings) == (
-        "Settings: data_folder=data, max_queries=10, max_attempts=3\n"
+        "Settings: data_folder=data, max_queries=10, max_attempts=3, parallel=False\n"
        "Subfolders: input_folder=data/input, output_folder=data/output, media_folder=data/media"
    )

@@ -256,3 +260,14 @@ def test_max_attempts_setter(temporary_data_folders):
    # set it to a different value
    settings.max_attempts = 5
    assert settings.max_attempts == 5
+
+
+def test_parallel_getter_and_setter(temporary_data_folders):
+    settings = Settings()
+
+    # check the default value
+    assert settings.parallel is False
+
+    # set it to a different value
+    settings.parallel = True
+    assert settings.parallel is True
