diff --git a/jaxutils/parameters.py b/jaxutils/parameters.py index 3656024..c9fa665 100644 --- a/jaxutils/parameters.py +++ b/jaxutils/parameters.py @@ -1,14 +1,8 @@ -import warnings -from copy import deepcopy -from typing import Dict, Tuple +from typing import Dict from warnings import warn -import distrax as dx -import jax -import jax.numpy as jnp import jax.random as jr from jax.random import KeyArray -from jaxtyping import Array, Float from jaxutils import PyTree