Skip to content

Commit

Permalink
Fixes to most ruff linting errors (#2)
Browse files Browse the repository at this point in the history
Co-authored-by: Lukas Turcani <[email protected]>
  • Loading branch information
jezsadler and lukasturcani authored Aug 20, 2024
1 parent 09dad47 commit 48643bb
Show file tree
Hide file tree
Showing 51 changed files with 1,587 additions and 1,331 deletions.
14 changes: 7 additions & 7 deletions docs/notebooks/data/build_sin_quadratic_csv.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from random import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

n_samples = 10000
w = 5

x = np.linspace(-2, 2, n_samples)
df = pd.DataFrame(x, columns=["x"])
df["y"] = (
rng = np.random.default_rng()
sin_quads = pd.DataFrame(x, columns=["x"])
sin_quads["y"] = (
np.sin(w * x)
+ x**2
+ np.array([np.random.uniform() * 0.1 for _ in range(n_samples)])
+ np.array([rng.uniform() * 0.1 for _ in range(n_samples)])
)

plt.plot(df["x"], df["y"])
plt.plot(sin_quads["x"], sin_quads["y"])
plt.show()

df.to_csv("sin_quadratic.csv")
sin_quads.to_csv("sin_quadratic.csv")
14 changes: 14 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,19 @@ ignore = [
"ANN401",
"COM812",
"ISC001",
"SLF001",
"ARG001",
"N803",
"N806",
# Remove these after issue https://github.com/cog-imperial/OMLT/issues/153 is fixed.
"D100",
"D101",
"D102",
"D103",
"D104",
"D105",
"D106",
"D107",
# TODO: Remove these eventually
"ANN001",
"ANN002",
Expand Down Expand Up @@ -106,6 +119,7 @@ convention = "google"
"INP001",
]
"docs/conf.py" = ["D100", "INP001"]
"src/omlt/neuralnet/layer.py" = ["N802"]

[tool.mypy]
show_error_codes = true
Expand Down
10 changes: 4 additions & 6 deletions src/omlt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
"""
OMLT
====
"""OMLT.
OMLT is a Python package for representing machine learning models
(neural networks and gradient-boosted trees) within the Pyomo optimization environment.
OMLT is a Python package for representing machine learning models (neural networks
and gradient-boosted trees) within the Pyomo optimization environment.
The package provides various optimization formulations for machine learning models
(such as full-space, reduced-space, and MILP) as well as an interface to import
sequential Keras and general ONNX models.
"""

from omlt._version import __version__
from omlt.block import OmltBlock
from omlt.block import OmltBlock # type: ignore[attr-defined]
from omlt.scaling import OffsetScaling

__all__ = [
Expand Down
31 changes: 21 additions & 10 deletions src/omlt/block.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""
"""OmltBlock.
The omlt.block module contains the implementation of the OmltBlock class. This
class is used in combination with a formulation object to construct the
necessary constraints and variables to represent ML models.
Expand All @@ -23,7 +24,6 @@ class is used in combination with a formulation object to construct the
pyo.assert_optimal_termination(status)
"""

import warnings

import pyomo.environ as pyo
from pyomo.core.base.block import _BlockData, declare_custom_block
Expand All @@ -32,13 +32,14 @@ class is used in combination with a formulation object to construct the
@declare_custom_block(name="OmltBlock")
class OmltBlockData(_BlockData):
def __init__(self, component):
super(OmltBlockData, self).__init__(component)
super().__init__(component)
self.__formulation = None
self.__input_indexes = None
self.__output_indexes = None

def _setup_inputs_outputs(self, *, input_indexes, output_indexes):
"""
"""Setup inputs and outputs.
This function should be called by the derived class to create the
inputs and outputs on the block
Expand All @@ -52,19 +53,15 @@ def _setup_inputs_outputs(self, *, input_indexes, output_indexes):
"""
self.__input_indexes = input_indexes
self.__output_indexes = output_indexes
if not input_indexes or not output_indexes:
# TODO: implement this check higher up in the class hierarchy to provide more contextual error msg
raise ValueError(
"OmltBlock must have at least one input and at least one output."
)

self.inputs_set = pyo.Set(initialize=input_indexes)
self.inputs = pyo.Var(self.inputs_set, initialize=0)
self.outputs_set = pyo.Set(initialize=output_indexes)
self.outputs = pyo.Var(self.outputs_set, initialize=0)

def build_formulation(self, formulation):
"""
"""Build formulation.
Call this method to construct the constraints (and possibly
intermediate variables) necessary for the particular neural network
formulation. The formulation object can be accessed later through the
Expand All @@ -75,6 +72,20 @@ def build_formulation(self, formulation):
formulation : instance of _PyomoFormulation
see, for example, FullSpaceNNFormulation
"""
if not formulation.input_indexes:
msg = (
"OmltBlock must have at least one input to build a formulation. "
f"{formulation} has no inputs."
)
raise ValueError(msg)

if not formulation.output_indexes:
msg = (
"OmltBlock must have at least one output to build a formulation. "
f"{formulation} has no outputs."
)
raise ValueError(msg)

self._setup_inputs_outputs(
input_indexes=list(formulation.input_indexes),
output_indexes=list(formulation.output_indexes),
Expand Down
36 changes: 23 additions & 13 deletions src/omlt/formulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@


class _PyomoFormulationInterface(abc.ABC):
"""
"""Pyomo Formulation Interface.
Base class interface for a Pyomo formulation object. This class
is largely internal, and developers of new formulations should derive from
_PyomoFormulation.
Expand All @@ -23,51 +24,60 @@ def _set_block(self, block):
@abc.abstractmethod
def block(self):
"""Return the block associated with this formulation."""
pass

@property
@abc.abstractmethod
def input_indexes(self):
"""Return the indices corresponding to the inputs of the
"""Input indexes.
Return the indices corresponding to the inputs of the
ML model. This is a list of entries (which may be tuples
for higher dimensional inputs).
"""
pass

@property
@abc.abstractmethod
def output_indexes(self):
"""Return the indices corresponding to the outputs of the
"""Output indexes.
Return the indices corresponding to the outputs of the
ML model. This is a list of entries (which may be tuples
for higher dimensional outputs).
"""
pass

@abc.abstractmethod
def _build_formulation(self):
"""This method is called by the OmltBlock object to build the
"""Build formulation.
This method is called by the OmltBlock object to build the
corresponding mathematical formulation of the model.
"""
pass


class _PyomoFormulation(_PyomoFormulationInterface):
"""
"""Pyomo Formulation.
This is a base class for different Pyomo formulations. To create a new
formulation, inherit from this class and implement the abstract methods and properties.
formulation, inherit from this class and implement the abstract methods
and properties.
"""

def __init__(self):
super(_PyomoFormulation, self).__init__()
self.__block = None

def _set_block(self, block):
self.__block = weakref.ref(block)

@property
def block(self):
"""The underlying block containing the constraints / variables for this formulation."""
return self.__block()
"""Block.
The underlying block containing the constraints / variables for this
formulation.
"""
if self.__block is not None:
return self.__block()
return None

Check warning on line 80 in src/omlt/formulation.py

View check run for this annotation

Codecov / codecov/patch

src/omlt/formulation.py#L80

Added line #L80 was not covered by tests


def scalar_or_tuple(x):
Expand Down
5 changes: 4 additions & 1 deletion src/omlt/gbt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
r"""
r"""Gradient-Boosted Trees formulation.
We use the following notation to describe the gradient-boosted trees formulation:
.. math::
Expand All @@ -25,3 +26,5 @@

from omlt.gbt.gbt_formulation import GBTBigMFormulation
from omlt.gbt.model import GradientBoostedTreeModel

__all__ = ["GBTBigMFormulation", "GradientBoostedTreeModel"]
Loading

0 comments on commit 48643bb

Please sign in to comment.