Skip to content

Commit

Permalink
Handle batch case
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Dec 3, 2024
1 parent c86421c commit da279ac
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 19 deletions.
91 changes: 88 additions & 3 deletions src/optimagic/optimization/history.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
from dataclasses import dataclass
from functools import partial
from typing import Any, Literal
from typing import Any, Callable, Iterable, Literal

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -192,6 +192,16 @@ def flat_param_names(self) -> list[str]:
def _get_time(
self, cost_model: CostModel | Literal["wall_time"]
) -> NDArray[np.float64]:
"""Return the cumulative time measure.
Args:
cost_model: The cost model that is used to calculate the time measure. If
"wall_time", the wall time is returned.
Returns:
np.ndarray: The time measure.
"""
if not isinstance(cost_model, CostModel) and cost_model != "wall_time":
raise ValueError("cost_model must be a CostModel or 'wall_time'.")

Expand All @@ -207,11 +217,31 @@ def _get_time(
fun_and_jac_time = self._get_time_per_task(
task=EvalTask.FUN_AND_JAC, cost_factor=cost_model.fun_and_jac
)
return fun_time + jac_time + fun_and_jac_time

time = fun_time + jac_time + fun_and_jac_time
batch_time = _batch_apply(
data=time,
batch_ids=self.batches,
func=cost_model.aggregate_batch_time,
)
return np.cumsum(batch_time)

def _get_time_per_task(
self, task: EvalTask, cost_factor: float | None
) -> NDArray[np.float64]:
"""Return the time measure per task.
Args:
task: The task for which the time is calculated.
cost_factor: The cost factor used to calculate the time. If None, the time
is the difference between the start and stop time, otherwise the time
is given by the cost factor.
Returns:
np.ndarray: The time per task. For entries where the task is not the
requested task, the time is 0.
"""
dummy_task = np.array([1 if t == task else 0 for t in self.task])
if cost_factor is None:
factor: float | NDArray[np.float64] = np.array(
Expand All @@ -220,7 +250,7 @@ def _get_time_per_task(
else:
factor = cost_factor

return np.cumsum(factor * dummy_task)
return factor * dummy_task

@property
def start_time(self) -> list[float]:
Expand Down Expand Up @@ -351,3 +381,58 @@ def _task_as_categorical(task: list[EvalTask]) -> pd.Categorical:
return pd.Categorical(
[t.value for t in task], categories=[t.value for t in EvalTask]
)


def _batch_apply(
data: NDArray[np.float64],
batch_ids: list[int],
func: Callable[[Iterable[float]], float],
) -> NDArray[np.float64]:
"""Apply a reduction operator on batches of data.
Args:
data: 1d array with data.
batch_ids: A list whose length is equal to the size of data. Values need to be
sorted and can be repeated.
func: A reduction function that takes an iterable of floats as input (e.g., a
numpy array or a list) and returns a scalar.
Returns:
The transformed data. Has the same length as data. For each batch, the result of
the reduction operation is stored at the first index of that batch, and all
other values of that batch are set to zero.
"""
batch_start = _get_batch_start(batch_ids)
batch_stop = [*batch_start, len(data)][1:]

batch_result = []
for batch, (start, stop) in zip(
batch_ids, zip(batch_start, batch_stop, strict=False), strict=False
):
try:
batch_data = data[start:stop]
reduced = func(batch_data)
batch_result.append(reduced)
except Exception as e:
msg = (
f"Calling function {func.__name__} on batch {batch} of the History "
f"History raised an Exception. Please verify that {func.__name__} is "
"properly defined."
)
raise ValueError(msg) from e

out = np.zeros_like(data)
out[batch_start] = batch_result
return out


def _get_batch_start(batch_ids: list[int]) -> list[int]:
"""Get start indices of batch.
This function assumes that batch_ids non-empty and sorted.
"""
ids_arr = np.array(batch_ids, dtype=np.int64)
indices = np.where(ids_arr[:-1] != ids_arr[1:])[0] + 1
return np.insert(indices, 0, 0).tolist()
4 changes: 2 additions & 2 deletions src/optimagic/timing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Callable
from typing import Callable, Iterable


@dataclass(frozen=True)
Expand All @@ -8,7 +8,7 @@ class CostModel:
jac: float | None
fun_and_jac: float | None
label: str
aggregate_batch_time: Callable[[list[float]], float]
aggregate_batch_time: Callable[[Iterable[float]], float]


evaluation_time = CostModel(
Expand Down
85 changes: 71 additions & 14 deletions tests/optimagic/optimization/test_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from optimagic.optimization.history import (
History,
HistoryEntry,
_batch_apply,
_calculate_monotone_sequence,
_get_batch_start,
_get_flat_param_names,
_get_flat_params,
_is_1d_array,
Expand Down Expand Up @@ -143,8 +145,8 @@ def params():


@pytest.fixture
def history(params):
data = {
def history_data(params):
return {
"fun": [10, None, 9, None, 2, 5],
"task": [
EvalTask.FUN,
Expand All @@ -157,9 +159,19 @@ def history(params):
"start_time": [0, 2, 5, 7, 10, 12],
"stop_time": [1, 4, 6, 9, 11, 14],
"params": params,
"batches": [0, 0, 1, 1, 2, 2],
"batches": [0, 1, 2, 3, 4, 5],
}


@pytest.fixture
def history(history_data):
return History(direction=Direction.MINIMIZE, **history_data)


@pytest.fixture
def history_with_batch_data(history_data):
data = history_data.copy()
data["batches"] = [0, 0, 1, 1, 2, 2]
return History(direction=Direction.MINIMIZE, **data)


Expand Down Expand Up @@ -211,9 +223,8 @@ def test_history_fun_data_with_fun_evaluations_cost_model_and_monotone(history):
assert_frame_equal(got, exp, check_dtype=False, check_categorical=False)


@pytest.mark.xfail(reason="Must be fixed!")
def test_history_fun_data_with_fun_batches_cost_model(history):
got = history.fun_data(
def test_history_fun_data_with_fun_batches_cost_model(history_with_batch_data):
got = history_with_batch_data.fun_data(
cost_model=om.timing.fun_batches,
monotone=False,
)
Expand Down Expand Up @@ -328,23 +339,23 @@ def test_flat_param_names(history):

def test_get_time_per_task_fun(history):
got = history._get_time_per_task(EvalTask.FUN, cost_factor=1)
exp = np.array([1, 1, 2, 2, 3, 3])
exp = np.array([1, 0, 1, 0, 1, 0])
assert_array_equal(got, exp)


def test_get_time_per_task_jac(history):
got = history._get_time_per_task(EvalTask.JAC, cost_factor=1)
exp = np.array([0, 1, 1, 2, 2, 2])
def test_get_time_per_task_jac_cost_factor_none(history):
got = history._get_time_per_task(EvalTask.JAC, cost_factor=None)
exp = np.array([0, 2, 0, 2, 0, 0])
assert_array_equal(got, exp)


def test_get_time_per_task_fun_and_jac(history):
got = history._get_time_per_task(EvalTask.FUN_AND_JAC, cost_factor=1)
exp = np.array([0, 0, 0, 0, 0, 1])
got = history._get_time_per_task(EvalTask.FUN_AND_JAC, cost_factor=-0.5)
exp = np.array([0, 0, 0, 0, 0, -0.5])
assert_array_equal(got, exp)


def test_get_time_cost_model(history):
def test_get_time_custom_cost_model(history):
cost_model = om.timing.CostModel(
fun=0.5, jac=1, fun_and_jac=2, label="test", aggregate_batch_time=sum
)
Expand All @@ -362,6 +373,30 @@ def test_get_time_cost_model(history):
assert_array_equal(got, exp)


def test_get_time_fun_evaluations(history):
got = history._get_time(cost_model=om.timing.fun_evaluations)
exp = np.array([1, 1, 2, 2, 3, 4])
assert_array_equal(got, exp)


def test_get_time_fun_batches(history):
got = history._get_time(cost_model=om.timing.fun_batches)
exp = np.array([1, 1, 2, 2, 3, 4])
assert_array_equal(got, exp)


def test_get_time_fun_batches_with_batch_data(history_with_batch_data):
got = history_with_batch_data._get_time(cost_model=om.timing.fun_batches)
exp = np.array([1, 1, 2, 2, 3, 3])
assert_array_equal(got, exp)


def test_get_time_evaluation_time(history):
got = history._get_time(cost_model=om.timing.evaluation_time)
exp = np.array([1, 3, 4, 6, 7, 9])
assert_array_equal(got, exp)


def test_get_time_wall_time(history):
got = history._get_time(cost_model="wall_time")
exp = np.array([1, 4, 6, 9, 11, 14])
Expand All @@ -381,7 +416,7 @@ def test_stop_time_property(history):


def test_batches_property(history):
assert history.batches == [0, 0, 1, 1, 2, 2]
assert history.batches == [0, 1, 2, 3, 4, 5]


# Tasks
Expand Down Expand Up @@ -466,3 +501,25 @@ def test_task_as_categorical():
got = _task_as_categorical(task)
assert got.tolist() == ["fun", "jac", "fun_and_jac"]
assert isinstance(got.dtype, pd.CategoricalDtype)


def test_get_batch_start():
batches = [0, 0, 1, 1, 1, 2, 2, 3]
got = _get_batch_start(batches)
assert got == [0, 2, 5, 7]


def test_batch_apply_sum():
data = np.array([0, 1, 2, 3, 4])
batch_ids = [0, 0, 1, 1, 2]
exp = np.array([1, 0, 5, 0, 4])
got = _batch_apply(data, batch_ids, sum)
assert_array_equal(exp, got)


def test_batch_apply_max():
data = np.array([0, 1, 2, 3, 4])
batch_ids = [0, 0, 1, 1, 2]
exp = np.array([1, 0, 3, 0, 4])
got = _batch_apply(data, batch_ids, max)
assert_array_equal(exp, got)

0 comments on commit da279ac

Please sign in to comment.