diff --git a/lowtime/__init__.py b/lowtime/__init__.py index 71b63b8..ee2edd1 100644 --- a/lowtime/__init__.py +++ b/lowtime/__init__.py @@ -14,4 +14,4 @@ """A library for solving the time-cost tradeoff problem.""" -__version__ = "0.1.0" +__version__ = "0.2.0" diff --git a/lowtime/cost_model.py b/lowtime/cost_model.py index 376f0f5..44a2b7d 100644 --- a/lowtime/cost_model.py +++ b/lowtime/cost_model.py @@ -197,6 +197,10 @@ def l2_error(coefficients: tuple[float, float, float]) -> float: if error == np.inf: continue + # The exponential model must be convex. + if opt_a < 0.0: + continue + # We have coefficients that somewhat fit the data. logger.info( "Initial guess %s fit with coefficients %s and error %f.", @@ -224,6 +228,6 @@ def l2_error(coefficients: tuple[float, float, float]) -> float: return best_coefficients @lru_cache - def __call__(self, quant_time: int) -> float: + def __call__(self, quant_time: int) -> float: # pyright: ignore """Predict execution cost given quantized time.""" return self.fn(quant_time, *self.coefficients) diff --git a/lowtime/graph_utils.py b/lowtime/graph_utils.py index 3a09ac8..72859ad 100644 --- a/lowtime/graph_utils.py +++ b/lowtime/graph_utils.py @@ -279,6 +279,9 @@ def aoa_to_critical_dag(aoa_dag: nx.DiGraph, attr_name: str = "op") -> nx.DiGrap - The graph has only one source node, annotated as "source_node" on the graph. - The graph has only one sink node, annotated as "sink_node" on the graph. """ + if not nx.is_directed_acyclic_graph(aoa_dag): + raise ValueError("The given graph is not a DAG.") + # Clear all earliest/latest start/end times. for _, _, edge_attr in aoa_dag.edges(data=True): operation: Operation = edge_attr[attr_name] diff --git a/lowtime/operation.py b/lowtime/operation.py index 7051afb..91cccc9 100644 --- a/lowtime/operation.py +++ b/lowtime/operation.py @@ -94,9 +94,12 @@ class CandidateExecutionOptions(Generic[KnobT]): Args: options: All candidate execution options of the operation. + noise_factor: A factor to multiply `real_time` and `cost` by to allow some slack. + """ options: list[ExecutionOption[KnobT]] + noise_factor: float = 1.0 _knob_cache: dict[int, KnobT] = field(init=False, repr=False, factory=dict) def __attrs_post_init__(self) -> None: @@ -105,8 +108,10 @@ def __attrs_post_init__(self) -> None: orig_options = sorted(self.options, key=lambda x: x.real_time, reverse=True) filtered_options: list[ExecutionOption[KnobT]] = [] for option in orig_options: + real_time = option.real_time * self.noise_factor + cost = option.cost * self.noise_factor if any( - other.real_time < option.real_time and other.cost < option.cost + other.real_time < real_time and other.cost < cost for other in orig_options ): continue @@ -215,7 +220,7 @@ def __attrs_post_init__(self) -> None: self.min_duration = min(quant_times) # By default, execute with the slowest speed. `assigned_knob` will - # automatically be set by `duration_setter`. + # automatically be set by `_knob_setter`. self.duration = self.max_duration def __str__(self) -> str: @@ -257,6 +262,6 @@ def __str__(self) -> str: """Return a readable string representation.""" return "DummyOperation()" - def get_cost(self, _: int | None = None) -> float: + def get_cost(self, _: int | None = None) -> float: # pyright: ignore """No cost for dummy operations.""" raise AttributeError("DummyOperation has no cost.") diff --git a/pyproject.toml b/pyproject.toml index b0b8afe..a8296f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,9 @@ include = ["lowtime*"] exclude = ["examples", "scripts", "stubs", "tests"] [tool.ruff] +line-length = 120 + +[tool.ruff.lint] select = [ "E", # pycodestyle error "F", # pyflakes @@ -69,12 +72,11 @@ ignore = [ "SIM115", # Context manager for opening files "E501", # Line too long ] -line-length = 120 -[tool.ruff.pydocstyle] +[tool.ruff.lint.pydocstyle] convention = "google" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "**/__init__.py" = ["F401", "F403"] [tool.pyright] diff --git a/scripts/lint.sh b/scripts/lint.sh index 8ebdebc..1aa50c6 100755 --- a/scripts/lint.sh +++ b/scripts/lint.sh @@ -8,5 +8,5 @@ else black --check lowtime fi -ruff lowtime +ruff check lowtime pyright lowtime