Skip to content

Commit

Permalink
Added more type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
DiddiZ committed May 29, 2023
1 parent 34ac94a commit ec2d787
Show file tree
Hide file tree
Showing 14 changed files with 78 additions and 73 deletions.
4 changes: 2 additions & 2 deletions donk/costs/cost_function_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from donk.costs.quadratic_costs import QuadraticCosts


def _vectorize_cost_function(
    fun: Callable[[np.ndarray, np.ndarray], object],
) -> Callable[[np.ndarray, np.ndarray], np.ndarray]:
    """Wrap a cost function so that it broadcasts over leading batch axes.

    Args:
        fun: Callable taking one 2-D state array and one 2-D action array and
            returning one cost value per row of the state array. The result
            may be any array-like; it is converted with `np.array`.

    Returns:
        A `np.vectorize` wrapper with the generalized-ufunc signature
        "(v,x),(t,u)->(v)": the last two axes of each argument are handed to
        `fun`, all leading axes are looped over.
    """
    # np.array(...) guarantees an ndarray result, which np.vectorize needs to
    # assemble the "(v)" core output.
    return np.vectorize(lambda X, U: np.array(fun(X, U)), signature="(v,x),(t,u)->(v)")


Expand Down Expand Up @@ -113,7 +113,7 @@ def __init__(
"""Initialize this `MultipartSymbolicCostFunction`."""
from sympy import lambdify, symbols

def cost_fun(X, U):
def cost_fun(X: np.ndarray, U: np.ndarray) -> np.ndarray:
"""Sum up individual parts."""
cost = cost_funs[0](X, U)
for fn in cost_funs[1:]:
Expand Down
13 changes: 6 additions & 7 deletions donk/costs/losses.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from __future__ import annotations

import numpy as np


def loss_combined(x, losses):
def loss_combined(x: np.ndarray, losses) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Evaluates and sums up multiple loss functions.
Args:
x: (T, dX) states, actual values.
losses: List of tuples (`loss`, `kwargs`)
`loss`: loss function to evaluate.
`kwargs`: Additional arguments passed to the loss function (optional).
"""
if len(losses) < 1:
raise ValueError("loss_combined requred at least one loss function to sum up.")
Expand All @@ -24,7 +25,7 @@ def loss_combined(x, losses):
return l, lx, lxx


def loss_l2(x, t, w):
def loss_l2(x: np.ndarray, t: np.ndarray, w: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Evaluate and compute derivatives for l2 norm penalty.
loss = sum(0.5 * (x - t)^2 * w)
Expand All @@ -38,7 +39,6 @@ def loss_l2(x, t, w):
l: (T,) cost at each timestep.
lx: (T, D) first order derivative.
lxx: (T, D, D) second order derivative.
"""
# Get trajectory length.
_, dX = x.shape
Expand All @@ -60,7 +60,7 @@ def loss_l2(x, t, w):
return l, lx, lxx


def loss_l1(x, t, w, alpha):
def loss_l1(x: np.ndarray, t: np.ndarray, w: np.ndarray, alpha: float) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Evaluate and compute derivatives for l1 norm penalty.
loss = sum(sqrt((x - t)^2 + alpha) * w)
Expand All @@ -74,7 +74,6 @@ def loss_l1(x, t, w, alpha):
l: (T,) cost at each timestep.
lx: (T, D) first order derivative.
lxx: (T, D, D) second order derivative.
"""
# Get trajectory length.
_, dX = x.shape
Expand All @@ -97,7 +96,7 @@ def loss_l1(x, t, w, alpha):
return l, lx, lxx


def loss_log_cosh(x, t, w):
def loss_log_cosh(x: np.ndarray, t: np.ndarray, w: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Evaluate and compute derivatives for log-cosh loss.
loss = sum(log(cosh(x - t)) * w)
Expand Down
14 changes: 8 additions & 6 deletions donk/costs/quadratic_costs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Any

import numpy as np

from donk.costs.cost_function import CostFunction
Expand All @@ -13,7 +15,7 @@ class QuadraticCosts(CostFunction):
cost(x, u) = 1/2 [x u]^T*C*[x u] + [x u]^T*c + cc
"""

def __init__(self, C: np.ndarray, c: np.ndarray, cc: np.ndarray):
def __init__(self, C: np.ndarray, c: np.ndarray, cc: np.ndarray) -> None:
"""Initialize this QuadraticCosts object.
Args:
Expand All @@ -34,20 +36,20 @@ def __init__(self, C: np.ndarray, c: np.ndarray, cc: np.ndarray):
self.c = c
self.cc = cc

def __add__(self, other: Any) -> QuadraticCosts:
    """Sum two cost functions.

    Args:
        other: Right-hand operand of `+`.

    Returns:
        A new `QuadraticCosts` built from the sums of both operands' `C`, `c`
        and `cc`, or `NotImplemented` when `other` is not a `QuadraticCosts`.
    """
    if isinstance(other, QuadraticCosts):
        return QuadraticCosts(self.C + other.C, self.c + other.c, self.cc + other.cc)
    # Defer to Python's operator protocol (reflected operand / TypeError).
    return NotImplemented

def __mul__(self, other: Any) -> QuadraticCosts:
    """Scale cost function with constant scalar.

    Args:
        other: Right-hand operand of `*`.

    Returns:
        A new `QuadraticCosts` with `C`, `c` and `cc` each multiplied by
        `other`, or `NotImplemented` when `np.isscalar(other)` is false.
    """
    if np.isscalar(other):
        return QuadraticCosts(self.C * other, self.c * other, self.cc * other)
    # Defer to Python's operator protocol (reflected operand / TypeError).
    return NotImplemented

def __rmul__(self, other: Any) -> QuadraticCosts:
    """Scale cost function with constant scalar.

    Reflected variant of `__mul__`, used for `other * self`; it delegates to
    `__mul__` and returns whatever that returns.
    """
    return self.__mul__(other)

def compute_costs(self, X: np.ndarray, U: np.ndarray) -> np.ndarray:
Expand Down
2 changes: 1 addition & 1 deletion donk/datalogging/shelve.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __enter__(self) -> ShelveDataLogger:

return super().__enter__()

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
super().__exit__(exc_type, exc_val, exc_tb)

# Close shelve
Expand Down
4 changes: 3 additions & 1 deletion donk/dynamics/prior/prior.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass

Expand Down Expand Up @@ -43,7 +45,7 @@ def map_covar(self) -> np.ndarray:
return self.Phi / (self.N_covar + d + 1)

@staticmethod
def non_informative_prior(d):
def non_informative_prior(d: int) -> NormalInverseWishart:
"""Create a non-informative prior distribution.
Args:
Expand Down
2 changes: 1 addition & 1 deletion donk/dynamics/prior/prior_gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
class GMMPrior(DynamicsPrior):
"""A Gaussian Mixture Model (GMM) based prior."""

def __init__(self, n_clusters, random_state=None) -> None:
def __init__(self, n_clusters: int, random_state=None) -> None:
"""Initialize this `GMMPrior`."""
self.gmm = GaussianMixture(n_components=n_clusters, random_state=random_state)

Expand Down
20 changes: 12 additions & 8 deletions donk/traj_opt/lqg.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from dataclasses import dataclass

import numpy as np
Expand Down Expand Up @@ -72,7 +74,7 @@ def step(self, eta: float) -> ILQRStepResult:

return ILQRStepResult(eta, pol, kl_div, expected_costs, traj)

def sample_surface(self, min_eta: float = 1e-6, max_eta: float = 1e16, N: int = 100):
def sample_surface(self, min_eta: float = 1e-6, max_eta: float = 1e16, N: int = 100) -> list[ILQRStepResult]:
"""Sample the Lagrangian at different values for eta.
For visualization/debugging purposes.
Expand All @@ -85,7 +87,9 @@ def sample_surface(self, min_eta: float = 1e-6, max_eta: float = 1e16, N: int =
results = [self.step(eta) for eta in np.logspace(np.log10(min_eta), np.log10(max_eta), num=N)]
return results

def optimize(self, kl_step: float, min_eta: float = 1e-6, max_eta: float = 1e16, rtol: float = 1e-2):
def optimize(
self, kl_step: float, min_eta: float = 1e-6, max_eta: float = 1e16, rtol: float = 1e-2
) -> ILQRStepResult:
"""Perform iLQG trajectory optimization.
Args:
Expand All @@ -109,7 +113,7 @@ def optimize(self, kl_step: float, min_eta: float = 1e-6, max_eta: float = 1e16,
raise ValueError(f"max_eta eta to low ({max_eta})")

# Find the point where kl divergence equals the kl_step
def constraint_violation(log_eta):
def constraint_violation(log_eta: float) -> float:
return self.step(np.exp(log_eta)).kl_div - kl_step

# Search root of the constraint violation
Expand All @@ -119,7 +123,7 @@ def constraint_violation(log_eta):
return self.step(np.exp(log_eta))


def backward(dynamics: LinearDynamics, C, c, gamma=1) -> LinearGaussianPolicy:
def backward(dynamics: LinearDynamics, C: np.ndarray, c: np.ndarray, gamma: float = 1) -> LinearGaussianPolicy:
"""Perform LQR backward pass.
`C` is required to be symmetric.
Expand Down Expand Up @@ -182,7 +186,7 @@ def forward(
dynamics: LinearDynamics,
policy: LinearGaussianPolicy,
initial_state: StateDistribution,
regularization=1e-6,
regularization: float = 1e-6,
) -> TrajectoryDistribution:
"""Perform LQR forward pass.
Expand Down Expand Up @@ -240,7 +244,7 @@ def forward(
return TrajectoryDistribution(traj_mean, traj_covar, dX, dU)


def extended_costs_kl(prev_pol: LinearGaussianPolicy):
def extended_costs_kl(prev_pol: LinearGaussianPolicy) -> tuple[np.ndarray, np.ndarray]:
"""Compute expansion of extended cost used in the iLQR backward pass.
The extended cost function is -log p(u_t | x_t) with p being the previous trajectory distribution.
Expand Down Expand Up @@ -269,7 +273,7 @@ def extended_costs_kl(prev_pol: LinearGaussianPolicy):
return C, c


def kl_divergence_action(X, pol: LinearGaussianPolicy, prev_pol: LinearGaussianPolicy):
def kl_divergence_action(X: np.ndarray, pol: LinearGaussianPolicy, prev_pol: LinearGaussianPolicy) -> float:
"""Compute KL divergence between new and previous trajectory distributions.
Args:
Expand Down Expand Up @@ -314,7 +318,7 @@ def step_adjust(
costs_prev: QuadraticCosts,
max_step_mult: float = 10,
min_step_mult: float = 0.1,
):
) -> float:
"""Compute new multiplier for KL divergence constraint based on expected vs. actual improvement.
See:
Expand Down
4 changes: 2 additions & 2 deletions donk/utils/batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def symmetrize(A: np.ndarray) -> np.ndarray:
return A


def regularize(A: np.ndarray, regularization) -> np.ndarray:
def regularize(A: np.ndarray, regularization: float) -> np.ndarray:
"""Regularizes a matrix or a batch of matrices by adding a constant to the diagonal.
Modifies the given matrix in-place.
Expand All @@ -84,7 +84,7 @@ def trace_of_product(A: np.ndarray, B: np.ndarray) -> np.ndarray:
return np.einsum("...ij,...ji->...", A, B)


def batched_multivariate_normal(mean: np.ndarray, covar: np.ndarray, N: int, rng) -> np.ndarray:
def batched_multivariate_normal(mean: np.ndarray, covar: np.ndarray, N: int, rng: np.random.Generator) -> np.ndarray:
"""Draw `N` samples from `T` multivariate normal distributions each.
Args:
Expand Down
2 changes: 1 addition & 1 deletion donk/visualization/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def visualize_costs(
costs: list[np.ndarray],
cost_labels: list[str],
include_total: bool = True,
):
) -> None:
"""Plots multiple cost curves.
Args:
Expand Down
4 changes: 3 additions & 1 deletion donk/visualization/states.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Visualization tool for state spaces."""
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def visualize_correlation(output_file, X):
def visualize_correlation(output_file: Path | str | None, X: np.ndarray) -> None:
"""Visualize the correlation between states.
Args:
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[flake8]
ignore = E203,E741,W503,D100,D104,D105,PL123,ANN101
per-file-ignores =
tests/*:D101,D102,ANN201
tests/*:D101,D102,ANN001,ANN201,ANN202
max-line-length = 120
docstring-convention = google
known-modules = donk.ai:[donk,tests]
Expand Down
5 changes: 0 additions & 5 deletions tests/data/__init__.py

This file was deleted.

34 changes: 0 additions & 34 deletions tests/data/data.py

This file was deleted.

41 changes: 38 additions & 3 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,62 @@
"""Utilities for unit testing."""
from __future__ import annotations

from pathlib import Path

import numpy as np

from donk.dynamics import LinearDynamics
from donk.policy import LinearGaussianPolicy
from donk.utils import symmetrize


def random_spd(shape: tuple[int, ...], rng: np.random.Generator) -> np.ndarray:
    """Generate an arbitrarily random matrix guaranteed to be s.p.d.

    Args:
        shape: Shape of the result; the last entry is used as the matrix size.
        rng: Random generator the uniform entries are drawn from.

    Returns:
        The symmetrized uniform samples plus `shape[-1]` times the identity.
    """
    dim = shape[-1]
    # NOTE(review): the dim * I shift is presumably what makes the result
    # positive definite (diagonal dominance) - depends on `symmetrize`.
    return symmetrize(rng.uniform(size=shape)) + np.eye(dim) * dim


def random_lq_pol(T: int, dX: int, dU: int, rng: np.random.Generator) -> LinearGaussianPolicy:
    """Generate an arbitrarily random linear gaussian policy.

    Args:
        T: Size of the leading (time) axis of all generated arrays.
        dX: State dimension; last axis of the gain matrices.
        dU: Action dimension.
        rng: Random generator to draw from.

    Returns:
        Policy built from normal-distributed `K` (T, dU, dX) and `k` (T, dU)
        and a covariance (T, dU, dU) obtained from `random_spd`.
    """
    # Draw order (K, k, covariance) is kept so seeded generators reproduce
    # the same policies as before.
    K = rng.normal(size=(T, dU, dX))
    k = rng.normal(size=(T, dU))
    pol_covar = random_spd((T, dU, dU), rng)
    return LinearGaussianPolicy(K, k, pol_covar)


def random_tvlg(T: int, dX: int, dU: int, rng: np.random.Generator) -> LinearDynamics:
    """Generate arbitrarily random linear gaussian dynamics.

    Args:
        T: Size of the leading (time) axis of all generated arrays.
        dX: State dimension.
        dU: Action dimension.
        rng: Random generator to draw from.

    Returns:
        Dynamics built from normal-distributed `Fm` (T, dX, dX + dU) and `fv`
        (T, dX) and a covariance (T, dX, dX) obtained from `random_spd`.
    """
    # Draw order (Fm, fv, covariance) is kept so seeded generators reproduce
    # the same dynamics as before.
    Fm = rng.normal(size=(T, dX, dX + dU))
    fv = rng.normal(size=(T, dX))
    dyn_covar = random_spd((T, dX, dX), rng)
    return LinearDynamics(Fm, fv, dyn_covar)


def load_state_controller_dataset(
    dataset: int, itr: int
) -> tuple[np.ndarray, LinearGaussianPolicy, np.ndarray, np.ndarray]:
    """Loads a state_controller dataset.

    Args:
        dataset: Id of the dataset
        itr: Iteration to return

    Returns:
        X: (N, T, dX) Real states
        pol: Fitted linear policy
        X_mean: (T, dX) Mean of state distribution
        X_covar: (T, dX, dX) Covariance of state distribution

    Raises:
        ValueError: If the dataset file does not exist or `itr` is not a valid
            iteration index for it.
    """
    # NOTE: relative path - resolved against the current working directory.
    file = Path(f"tests/data/state_controller_{dataset:02d}.npz")
    if not file.is_file():
        raise ValueError(f"There is no dataset 'state_controller_{dataset:02d}.npz'")

    with np.load(file) as data:
        # Archive members are read on every `data[...]` access, so fetch "X"
        # once and reuse it for both the bounds check and the selection.
        X_all = data["X"]
        if itr not in range(len(X_all)):
            raise ValueError(f"Invalid iteration {itr}")

        X = X_all[itr]
        pol = LinearGaussianPolicy(K=data["K"][itr], k=data["k"][itr], pol_covar=data["pol_covar"][itr])
        X_mean = data["X_mean"][itr]
        X_covar = data["X_covar"][itr]

    return X, pol, X_mean, X_covar

0 comments on commit ec2d787

Please sign in to comment.