From ec2d7879033a9f534a06b5e3cb9da2f64fe784fe Mon Sep 17 00:00:00 2001 From: Robin Kupper Date: Sun, 28 May 2023 10:21:52 +0200 Subject: [PATCH] Added more type annotations --- donk/costs/cost_function_symbolic.py | 4 +-- donk/costs/losses.py | 13 ++++----- donk/costs/quadratic_costs.py | 14 ++++++---- donk/datalogging/shelve.py | 2 +- donk/dynamics/prior/prior.py | 4 ++- donk/dynamics/prior/prior_gmm.py | 2 +- donk/traj_opt/lqg.py | 20 ++++++++------ donk/utils/batched.py | 4 +-- donk/visualization/costs.py | 2 +- donk/visualization/states.py | 4 ++- setup.cfg | 2 +- tests/data/__init__.py | 5 ---- tests/data/data.py | 34 ----------------------- tests/utils.py | 41 ++++++++++++++++++++++++++-- 14 files changed, 78 insertions(+), 73 deletions(-) delete mode 100644 tests/data/__init__.py delete mode 100644 tests/data/data.py diff --git a/donk/costs/cost_function_symbolic.py b/donk/costs/cost_function_symbolic.py index ef34070..8fb3342 100644 --- a/donk/costs/cost_function_symbolic.py +++ b/donk/costs/cost_function_symbolic.py @@ -8,7 +8,7 @@ from donk.costs.quadratic_costs import QuadraticCosts -def _vectorize_cost_function(fun): +def _vectorize_cost_function(fun) -> Callable[[np.ndarray, np.ndarray], np.ndarray]: return np.vectorize(lambda X, U: np.array(fun(X, U)), signature="(v,x),(t,u)->(v)") @@ -113,7 +113,7 @@ def __init__( """Initialize this `MultipartSymbolicCostFunction`.""" from sympy import lambdify, symbols - def cost_fun(X, U): + def cost_fun(X: np.ndarray, U: np.ndarray) -> np.ndarray: """Sum up individual parts.""" cost = cost_funs[0](X, U) for fn in cost_funs[1:]: diff --git a/donk/costs/losses.py b/donk/costs/losses.py index 13e1e6c..dddffc9 100644 --- a/donk/costs/losses.py +++ b/donk/costs/losses.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import numpy as np -def loss_combined(x, losses): +def loss_combined(x: np.ndarray, losses) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Evaluates and sums up multiple loss functions. Args: @@ -9,7 +11,6 @@ def loss_combined(x, losses): losses: List of tuples (`loss`, `kwargs`) `loss`: loss function to evaluate. `kwargs`: Addional arguments passed to the loss function (optional). - """ if len(losses) < 1: raise ValueError("loss_combined requred at least one loss function to sum up.") @@ -24,7 +25,7 @@ def loss_combined(x, losses): return l, lx, lxx -def loss_l2(x, t, w): +def loss_l2(x: np.ndarray, t: np.ndarray, w: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Evaluate and compute derivatives for l2 norm penalty. loss = sum(0.5 * (x - t)^2 * w) @@ -38,7 +39,6 @@ def loss_l2(x, t, w): l: (T,) cost at each timestep. lx: (T, D) first order derivative. lxx: (T, D, D) second order derivative. - """ # Get trajectory length. _, dX = x.shape @@ -60,7 +60,7 @@ def loss_l2(x, t, w): return l, lx, lxx -def loss_l1(x, t, w, alpha): +def loss_l1(x: np.ndarray, t: np.ndarray, w: np.ndarray, alpha: float) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Evaluate and compute derivatives for l2 norm penalty. loss = sum(sqrt((x - t)^2 + alpha) * w) @@ -74,7 +74,6 @@ def loss_l1(x, t, w, alpha): l: (T,) cost at each timestep. lx: (T, D) first order derivative. lxx: (T, D, D) second order derivative. - """ # Get trajectory length. _, dX = x.shape @@ -97,7 +96,7 @@ def loss_l1(x, t, w, alpha): return l, lx, lxx -def loss_log_cosh(x, t, w): +def loss_log_cosh(x: np.ndarray, t: np.ndarray, w: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Evaluate and compute derivatives for log-cosh loss. loss = sum(log(cosh(x - t)) * w) diff --git a/donk/costs/quadratic_costs.py b/donk/costs/quadratic_costs.py index c6aed6f..bb5edc3 100644 --- a/donk/costs/quadratic_costs.py +++ b/donk/costs/quadratic_costs.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Any + import numpy as np from donk.costs.cost_function import CostFunction @@ -13,7 +15,7 @@ class QuadraticCosts(CostFunction): cost(x, u) = 1/2 [x u]^T*C*[x u] + [x u]^T*c + cc """ - def __init__(self, C: np.ndarray, c: np.ndarray, cc: np.ndarray): + def __init__(self, C: np.ndarray, c: np.ndarray, cc: np.ndarray) -> None: """Initialize this LinearDynamics object. Args: @@ -34,20 +36,20 @@ def __init__(self, C: np.ndarray, c: np.ndarray, cc: np.ndarray): self.c = c self.cc = cc - def __add__(self, other) -> QuadraticCosts: + def __add__(self, other: Any) -> QuadraticCosts: """Sum two cost functions.""" if isinstance(other, QuadraticCosts): return QuadraticCosts(self.C + other.C, self.c + other.c, self.cc + other.cc) return NotImplemented - def __mul__(self, other) -> QuadraticCosts: - """Scale cost function with vonstant scalar.""" + def __mul__(self, other: Any) -> QuadraticCosts: + """Scale cost function with constant scalar.""" if np.isscalar(other): return QuadraticCosts(self.C * other, self.c * other, self.cc * other) return NotImplemented - def __rmul__(self, other) -> QuadraticCosts: - """Scale cost function with vonstant scalar.""" + def __rmul__(self, other: Any) -> QuadraticCosts: + """Scale cost function with constant scalar.""" return self.__mul__(other) def compute_costs(self, X: np.ndarray, U: np.ndarray) -> np.ndarray: diff --git a/donk/datalogging/shelve.py b/donk/datalogging/shelve.py index f7d6a5a..d246191 100644 --- a/donk/datalogging/shelve.py +++ b/donk/datalogging/shelve.py @@ -35,7 +35,7 @@ def __enter__(self) -> ShelveDataLogger: return super().__enter__() - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type, exc_val, exc_tb) -> None: super().__exit__(exc_type, exc_val, exc_tb) # Close shelve diff --git a/donk/dynamics/prior/prior.py b/donk/dynamics/prior/prior.py index 9bb9f65..f6b623d 100644 --- a/donk/dynamics/prior/prior.py +++ b/donk/dynamics/prior/prior.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC, abstractmethod from dataclasses import dataclass @@ -43,7 +45,7 @@ def map_covar(self) -> np.ndarray: return self.Phi / (self.N_covar + d + 1) @staticmethod - def non_informative_prior(d): + def non_informative_prior(d: int) -> NormalInverseWishart: """Create a non-onformative prior distribution. Args: diff --git a/donk/dynamics/prior/prior_gmm.py b/donk/dynamics/prior/prior_gmm.py index d2e329f..aa3b5ac 100644 --- a/donk/dynamics/prior/prior_gmm.py +++ b/donk/dynamics/prior/prior_gmm.py @@ -7,7 +7,7 @@ class GMMPrior(DynamicsPrior): """A Gaussian Mixture Model (GMM) based prior.""" - def __init__(self, n_clusters, random_state=None) -> None: + def __init__(self, n_clusters: int, random_state=None) -> None: """Initialize this `GMMPrior`.""" self.gmm = GaussianMixture(n_components=n_clusters, random_state=random_state) diff --git a/donk/traj_opt/lqg.py b/donk/traj_opt/lqg.py index 8e444e0..334722e 100644 --- a/donk/traj_opt/lqg.py +++ b/donk/traj_opt/lqg.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass import numpy as np @@ -72,7 +74,7 @@ def step(self, eta: float) -> ILQRStepResult: return ILQRStepResult(eta, pol, kl_div, expected_costs, traj) - def sample_surface(self, min_eta: float = 1e-6, max_eta: float = 1e16, N: int = 100): + def sample_surface(self, min_eta: float = 1e-6, max_eta: float = 1e16, N: int = 100) -> list[ILQRStepResult]: """Sample the Lagrangian at different values for eta. For visualization/debugging purposes. @@ -85,7 +87,9 @@ def sample_surface(self, min_eta: float = 1e-6, max_eta: float = 1e16, N: int = results = [self.step(eta) for eta in np.logspace(np.log10(min_eta), np.log10(max_eta), num=N)] return results - def optimize(self, kl_step: float, min_eta: float = 1e-6, max_eta: float = 1e16, rtol: float = 1e-2): + def optimize( + self, kl_step: float, min_eta: float = 1e-6, max_eta: float = 1e16, rtol: float = 1e-2 + ) -> ILQRStepResult: """Perform iLQG trajectory optimization. Args: @@ -109,7 +113,7 @@ def optimize(self, kl_step: float, min_eta: float = 1e-6, max_eta: float = 1e16, raise ValueError(f"max_eta eta to low ({max_eta})") # Find the point where kl divergence equals the kl_step - def constraint_violation(log_eta): + def constraint_violation(log_eta: float) -> float: return self.step(np.exp(log_eta)).kl_div - kl_step # Search root of the constraint violation @@ -119,7 +123,7 @@ def constraint_violation(log_eta): return self.step(np.exp(log_eta)) -def backward(dynamics: LinearDynamics, C, c, gamma=1) -> LinearGaussianPolicy: +def backward(dynamics: LinearDynamics, C: np.ndarray, c: np.ndarray, gamma: float = 1) -> LinearGaussianPolicy: """Perform LQR backward pass. `C` is required to be symmetric. @@ -182,7 +186,7 @@ def forward( dynamics: LinearDynamics, policy: LinearGaussianPolicy, initial_state: StateDistribution, - regularization=1e-6, + regularization: float = 1e-6, ) -> TrajectoryDistribution: """Perform LQR forward pass. @@ -240,7 +244,7 @@ def forward( return TrajectoryDistribution(traj_mean, traj_covar, dX, dU) -def extended_costs_kl(prev_pol: LinearGaussianPolicy): +def extended_costs_kl(prev_pol: LinearGaussianPolicy) -> tuple[np.ndarray, np.ndarray]: """Compute expansion of extended cost used in the iLQR backward pass. The extended cost function is -log p(u_t | x_t) with p being the previous trajectory distribution. @@ -269,7 +273,7 @@ def extended_costs_kl(prev_pol: LinearGaussianPolicy): return C, c -def kl_divergence_action(X, pol: LinearGaussianPolicy, prev_pol: LinearGaussianPolicy): +def kl_divergence_action(X: np.ndarray, pol: LinearGaussianPolicy, prev_pol: LinearGaussianPolicy) -> float: """Compute KL divergence between new and previous trajectory distributions. Args: @@ -314,7 +318,7 @@ def step_adjust( costs_prev: QuadraticCosts, max_step_mult: float = 10, min_step_mult: float = 0.1, -): +) -> float: """Compute new multiplier for KL divergence constraint based on expected vs. actual improvement. See: diff --git a/donk/utils/batched.py b/donk/utils/batched.py index 8a55fd7..85aa660 100644 --- a/donk/utils/batched.py +++ b/donk/utils/batched.py @@ -59,7 +59,7 @@ def symmetrize(A: np.ndarray) -> np.ndarray: return A -def regularize(A: np.ndarray, regularization) -> np.ndarray: +def regularize(A: np.ndarray, regularization: float) -> np.ndarray: """Regularizes a matrix or a batch of matrices by adding a constant to the diagonal. Modifies the given matrix in-place. @@ -84,7 +84,7 @@ def trace_of_product(A: np.ndarray, B: np.ndarray) -> np.ndarray: return np.einsum("...ij,...ji->...", A, B) -def batched_multivariate_normal(mean: np.ndarray, covar: np.ndarray, N: int, rng) -> np.ndarray: +def batched_multivariate_normal(mean: np.ndarray, covar: np.ndarray, N: int, rng: np.random.Generator) -> np.ndarray: """Draw `N` samples from `T` multivariate normal distributions each. Args: diff --git a/donk/visualization/costs.py b/donk/visualization/costs.py index 6c67375..caa0934 100644 --- a/donk/visualization/costs.py +++ b/donk/visualization/costs.py @@ -11,7 +11,7 @@ def visualize_costs( costs: list[np.ndarray], cost_labels: list[str], include_total: bool = True, -): +) -> None: """Plots mutiple cost curves. Args: diff --git a/donk/visualization/states.py b/donk/visualization/states.py index 7a626ff..83a4b16 100644 --- a/donk/visualization/states.py +++ b/donk/visualization/states.py @@ -1,10 +1,12 @@ """Visualization tool for state spaces.""" +from pathlib import Path + import matplotlib.pyplot as plt import numpy as np import seaborn as sns -def visualize_correlation(output_file, X): +def visualize_correlation(output_file: Path | str | None, X: np.ndarray) -> None: """Visualize the correlation between states. Args: diff --git a/setup.cfg b/setup.cfg index 971cb75..3b2f71d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,7 +1,7 @@ [flake8] ignore = E203,E741,W503,D100,D104,D105,PL123,ANN101 per-file-ignores = - tests/*:D101,D102,ANN201 + tests/*:D101,D102,ANN001,ANN201,ANN202 max-line-length = 120 docstring-convention = google known-modules = donk.ai:[donk,tests] diff --git a/tests/data/__init__.py b/tests/data/__init__.py deleted file mode 100644 index 44e70b5..0000000 --- a/tests/data/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from tests.data import load_state_controller_dataset - -__all__ = [ - "load_state_controller_dataset", -] diff --git a/tests/data/data.py b/tests/data/data.py deleted file mode 100644 index 4feabc6..0000000 --- a/tests/data/data.py +++ /dev/null @@ -1,34 +0,0 @@ -from pathlib import Path - -import numpy as np - -from donk.policy import LinearGaussianPolicy - - -def load_state_controller_dataset(dataset, itr): - """Loads a state_controller dataset. - - Args: - dataset: Id of the dataset - itr: Iteration to return - - Returns: - X: (N, T, dX) Real states - pol: Fitted linear policy - X_mean: (T, dX) Mean of state distribution - X_covar: (T, dX, dX) Covariance of state distribution - """ - file = Path(f"tests/data/state_controller_{dataset:02d}.npz") - if not file.is_file(): - raise ValueError(f"There is no dataset 'state_controller_{dataset:02d}.npz'") - - with np.load(file) as data: - if itr not in range(len(data["X"])): - raise ValueError(f"Invalid iteration {itr}") - - X = data["X"][itr] - pol = LinearGaussianPolicy(K=data["K"][itr], k=data["k"][itr], pol_covar=data["pol_covar"][itr]) - X_mean = data["X_mean"][itr] - X_covar = data["X_covar"][itr] - - return X, pol, X_mean, X_covar diff --git a/tests/utils.py b/tests/utils.py index d769570..bfda671 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,8 @@ """Utilities for unit testing.""" +from __future__ import annotations + +from pathlib import Path + import numpy as np from donk.dynamics import LinearDynamics @@ -6,12 +10,12 @@ from donk.utils import symmetrize -def random_spd(shape, rng): +def random_spd(shape: tuple[int], rng: np.random.Generator): """Generate an arbitrarily random matrix guaranteeed to be s.p.d.""" return symmetrize(rng.uniform(size=shape)) + np.eye(shape[-1]) * shape[-1] -def random_lq_pol(T, dX, dU, rng): +def random_lq_pol(T: int, dX: int, dU: int, rng: np.random.Generator): """Generate an arbitrarily random linear gaussian policy.""" K = rng.normal(size=(T, dU, dX)) k = rng.normal(size=(T, dU)) @@ -19,9 +23,40 @@ def random_lq_pol(T, dX, dU, rng): return LinearGaussianPolicy(K, k, pol_covar) -def random_tvlg(T, dX, dU, rng): +def random_tvlg(T: int, dX: int, dU: int, rng: np.random.Generator): """Generate arbitrarily random linear gaussian dynamics.""" Fm = rng.normal(size=(T, dX, dX + dU)) fv = rng.normal(size=(T, dX)) dyn_covar = random_spd((T, dX, dX), rng) return LinearDynamics(Fm, fv, dyn_covar) + + +def load_state_controller_dataset( + dataset: int, itr: int +) -> tuple[np.ndarray, LinearGaussianPolicy, np.ndarray, np.ndarray]: + """Loads a state_controller dataset. + + Args: + dataset: Id of the dataset + itr: Iteration to return + + Returns: + X: (N, T, dX) Real states + pol: Fitted linear policy + X_mean: (T, dX) Mean of state distribution + X_covar: (T, dX, dX) Covariance of state distribution + """ + file = Path(f"tests/data/state_controller_{dataset:02d}.npz") + if not file.is_file(): + raise ValueError(f"There is no dataset 'state_controller_{dataset:02d}.npz'") + + with np.load(file) as data: + if itr not in range(len(data["X"])): + raise ValueError(f"Invalid iteration {itr}") + + X = data["X"][itr] + pol = LinearGaussianPolicy(K=data["K"][itr], k=data["k"][itr], pol_covar=data["pol_covar"][itr]) + X_mean = data["X_mean"][itr] + X_covar = data["X_covar"][itr] + + return X, pol, X_mean, X_covar