diff --git a/doc/pendulum.md b/doc/pendulum.md
index 2bb64a2..2321e72 100644
--- a/doc/pendulum.md
+++ b/doc/pendulum.md
@@ -25,7 +25,7 @@ purely knowledge graph driven AI.
 A full swing takes a couple of years.
 Frank argued that there needs to be a middle ground where machine learning and knowledge graphs come
 together.
-I belive that the implementation of probabilistic models in this package is capable of doing so.
+I believe that the implementation of probabilistic models in this package is capable of doing so.
 Knowledge graphs generate sets that describe possible assignments that match the constraints and
 instance knowledge of ontologies (a random event, so to say).
 Probability distributions describe the likelihoods of every possible solution.
diff --git a/scripts/gmm.py b/scripts/gmm.py
index e76685d..ee51138 100644
--- a/scripts/gmm.py
+++ b/scripts/gmm.py
@@ -28,7 +28,7 @@
 number_of_variables = 2
 number_of_samples_per_component = 100000
 number_of_components = 2
-number_of_mixtures = 100
+number_of_mixtures = 1000
 number_of_iterations = 1000
 
 # model selection
diff --git a/scripts/jpt_speed_comparison.py b/scripts/jpt_speed_comparison.py
index ac01f85..4ee388e 100644
--- a/scripts/jpt_speed_comparison.py
+++ b/scripts/jpt_speed_comparison.py
@@ -110,7 +110,7 @@ def timed_jax_method():
 # ll = jax_model.log_likelihood(samples)
 # assert (ll > -jnp.inf).all()
 
-times_nx, times_jax = eval_performance(nx_model.log_likelihood, (data, ), compiled_ll_jax, (data_jax, ), 10, 5)
+times_nx, times_jax = eval_performance(nx_model.log_likelihood, (data, ), compiled_ll_jax, (data_jax, ), 15, 10)
 # times_nx, times_jax = eval_performance(nx_model.probability_of_simple_event, (event,), jax_model.probability_of_simple_event, (event,), 20, 10)
 # times_nx, times_jax = eval_performance(nx_model.sample, (10000, ), jax_model.sample, (10000,), 10, 5)
 
diff --git a/scripts/nyga_speed_comparison.py b/scripts/nyga_speed_comparison.py
index 65b1cb9..1d97dcf 100644
--- a/scripts/nyga_speed_comparison.py
+++ b/scripts/nyga_speed_comparison.py
@@ -34,7 +34,7 @@ path_prefix = os.path.join(os.path.expanduser("~"), "Documents")
 nx_model_path = os.path.join(path_prefix, "nx_nyga.pm")
 jax_model_path = os.path.join(path_prefix, "jax_nyga.pm")
 
-load_from_disc = True
+load_from_disc = False
 save_to_disc = True
 
 
@@ -103,7 +103,7 @@ def timed_jax_method():
 # times_nx, times_jax = eval_performance(nx_model.log_likelihood, (data, ), compiled_ll_jax, (data_jax, ), 20, 2)
 # times_nx, times_jax = eval_performance(prob_nx, event, prob_jax, event, 15, 10)
 
-times_nx, times_jax = eval_performance(nx_model.sample, (1000, ), jax_model.sample, (1000, ), 10, 5)
+times_nx, times_jax = eval_performance(nx_model.sample, (1000, ), jax_model.sample, (1000, ), 5, 10)
 
 time_jax = np.mean(times_jax), np.std(times_jax)
 time_nx = np.mean(times_nx), np.std(times_nx)
diff --git a/src/probabilistic_model/learning/region_graph/__init__.py b/src/probabilistic_model/learning/region_graph/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/probabilistic_model/learning/region_graph/region_graph.py b/src/probabilistic_model/learning/region_graph/region_graph.py
new file mode 100644
index 0000000..5ab1b1e
--- /dev/null
+++ b/src/probabilistic_model/learning/region_graph/region_graph.py
@@ -0,0 +1,160 @@
+from collections import deque
+
+import networkx as nx
+import numpy as np
+from jax.experimental.sparse import BCOO
+from random_events.variable import Continuous
+from sortedcontainers import SortedSet
+from typing_extensions import List, Self, Type
+
+from ...distributions import GaussianDistribution
+from ...probabilistic_circuit.jax import SumLayer, ProductLayer
+from ...probabilistic_circuit.jax.gaussian_layer import GaussianLayer
+from ...probabilistic_circuit.nx.probabilistic_circuit import ProbabilisticCircuit, SumUnit, ProductUnit
+from ...probabilistic_circuit.nx.distributions.distributions import UnivariateContinuousLeaf
+from ...probabilistic_circuit.jax.probabilistic_circuit import ProbabilisticCircuit as JPC
+import jax.numpy as jnp
+import jax.random
+
+
+class Region:
+    variables: SortedSet
+
+    def __init__(self, variables: SortedSet):
+        self.variables = variables
+
+    def __hash__(self) -> int:
+        return id(self)
+
+    def random_partition(self, k=2) -> List[Self]:
+        indices = np.arange(len(self.variables))
+        np.random.shuffle(indices)
+        partitions = [Region(SortedSet([self.variables[index] for index in split])) for split in np.array_split(indices, k)]
+        return partitions
+
+    def __repr__(self) -> str:
+        return "{" + ", ".join([v.name for v in self.variables]) + "}"
+
+
+class Partition:
+    def __hash__(self) -> int:
+        return id(self)
+
+
+class RegionGraph(nx.DiGraph):
+
+    variables: SortedSet
+
+    def __init__(self, variables: SortedSet,
+                 partitions: int = 2,
+                 depth: int = 2,
+                 repetitions: int = 2):
+        super().__init__()
+        self.variables = variables
+        self.partitions = partitions
+        self.depth = depth
+        self.repetitions = repetitions
+
+    def create_random_region_graph(self):
+        root = Region(self.variables)
+        self.add_node(root)
+        for repetition in range(self.repetitions):
+            self.recursive_split(root)
+        return self
+
+    def regions(self):
+        for node in self.nodes:
+            if isinstance(node, Region):
+                yield node
+
+    def partition_nodes(self):
+        for node in self.nodes:
+            if isinstance(node, Partition):
+                yield node
+
+    def recursive_split(self, node: Region):
+        root_partition = Partition()
+        self.add_edge(node, root_partition)
+        remaining_regions = deque([(node, self.depth, root_partition)])
+
+        while remaining_regions:
+            region, depth, partition = remaining_regions.popleft()
+
+            if len(region.variables) == 1:
+                continue
+
+            if depth == 0:
+                for variable in region.variables:
+                    self.add_edge(partition, Region(SortedSet([variable])))
+                continue
+
+            new_regions = region.random_partition(self.partitions)
+            for new_region in new_regions:
+                self.add_edge(partition, new_region)
+                new_partition = Partition()
+                self.add_edge(new_region, new_partition)
+                remaining_regions.append((new_region, depth - 1, new_partition))
+
+    @property
+    def root(self) -> Region:
+        """
+        The root of the circuit is the node with in-degree 0.
+        This is the output node, that will perform the final computation.
+
+        :return: The root of the circuit.
+        """
+        possible_roots = [node for node in self.nodes() if self.in_degree(node) == 0]
+        if len(possible_roots) > 1:
+            raise ValueError(f"More than one root found. Possible roots are {possible_roots}")
+
+        return possible_roots[0]
+
+    def as_probabilistic_circuit(self, continuous_distribution_type: Type = GaussianLayer,
+                                 input_units: int = 5, sum_units: int = 5, key=jax.random.PRNGKey(69)) -> JPC:
+        root = self.root
+
+        # create nodes for each region
+        for layer in reversed(list(nx.bfs_layers(self, root))):
+            for node in layer:
+                children = list(self.successors(node))
+                parents = list(self.predecessors(node))
+                if isinstance(node, Region):
+                    # if the region is a leaf
+                    if len(children) == 0:
+                        variable = node.variables[0]
+                        variable_index = self.variables.index(variable)
+                        if isinstance(variable, Continuous):
+                            location = jax.random.uniform(key, shape=(input_units,), minval=-1., maxval=1.)
+                            log_scale = jnp.log(jax.random.uniform(key, shape=(input_units,), minval=0.01, maxval=1.))
+                            node.layer = GaussianLayer(variable_index, location=location, log_scale=log_scale, min_scale=jnp.full_like(location, 0.01))
+                            node.layer.validate()
+                        else:
+                            raise NotImplementedError
+
+                    # if the region is root or in the middle
+                    else:
+                        # if the region is root
+                        if len(parents) == 0:
+                            sum_units = 1
+
+                        log_weights = [BCOO.fromdense(jax.random.uniform(key, shape=(sum_units, child.layer.number_of_nodes), minval=0., maxval=1.)) for child in children]
+                        for log_weight in log_weights:
+                            log_weight.data = jnp.log(log_weight.data)
+                        node.layer = SumLayer([child.layer for child in children], log_weights=log_weights)
+                        node.layer.validate()
+
+                elif isinstance(node, Partition):
+                    node_lengths = [child.layer.number_of_nodes for child in children]
+                    assert (len(set(node_lengths)) == 1), "Node lengths must be all equal. Got {}".format(node_lengths)
+
+                    edges = jnp.arange(node_lengths[0]).reshape(1, -1).repeat(len(children), axis=0)
+                    sparse_edges = BCOO.fromdense(jnp.ones_like(edges))
+                    sparse_edges.data = edges.flatten()
+                    node.layer = ProductLayer([child.layer for child in children], sparse_edges)
+                    node.layer.validate()
+
+        model = JPC(self.variables, root.layer)
+        return model
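
The RegionGraph above builds a random hierarchical partition of the variables and compiles it into
a randomly parameterized JAX circuit. A minimal usage sketch (an editorial illustration, not part
of the patch; the parameter values are arbitrary):

    from random_events.variable import Continuous
    from sortedcontainers import SortedSet
    from probabilistic_model.learning.region_graph.region_graph import RegionGraph

    x, y = Continuous("x"), Continuous("y")
    # two repetitions of a random split into 2 parts, recursing up to depth 2
    graph = RegionGraph(SortedSet([x, y]), partitions=2, depth=2, repetitions=2)
    graph.create_random_region_graph()
    model = graph.as_probabilistic_circuit(input_units=5, sum_units=5)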
diff --git a/src/probabilistic_model/probabilistic_circuit/jax/discrete_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/discrete_layer.py
new file mode 100644
index 0000000..1840f61
--- /dev/null
+++ b/src/probabilistic_model/probabilistic_circuit/jax/discrete_layer.py
@@ -0,0 +1,112 @@
+from typing import Tuple, Type, List, Dict, Any
+
+from random_events.set import SetElement
+from random_events.variable import Symbolic, Variable
+from sortedcontainers import SortedSet
+from typing_extensions import Self, Optional
+
+from . import NXConverterLayer
+from .inner_layer import InputLayer
+import jax.numpy as jnp
+
+from ..nx.distributions import UnivariateDiscreteLeaf
+from ..nx.probabilistic_circuit import Unit
+from ...distributions import SymbolicDistribution
+import tqdm
+import numpy as np
+from ..nx.probabilistic_circuit import ProbabilisticCircuit as NXProbabilisticCircuit
+
+from ...utils import MissingDict
+
+
+class DiscreteLayer(InputLayer):
+
+    log_probabilities: jnp.array
+    """
+    The logarithm of probability for each state of the variable.
+
+    The shape is (#nodes, #states).
+    """
+
+    def __init__(self, variable: int, log_probabilities: jnp.array):
+        super().__init__(variable)
+        self.log_probabilities = log_probabilities
+
+    @classmethod
+    def nx_classes(cls) -> Tuple[Type, ...]:
+        return SymbolicDistribution,
+
+    def validate(self):
+        return True
+
+    @property
+    def normalization_constant(self) -> jnp.array:
+        return jnp.exp(self.log_probabilities).sum(1)
+
+    @property
+    def log_normalization_constant(self) -> jnp.array:
+        return jnp.log(self.normalization_constant)
+
+    @property
+    def normalized_log_probabilities(self) -> jnp.array:
+        return self.log_probabilities - self.log_normalization_constant[:, None]
+
+    @property
+    def number_of_nodes(self) -> int:
+        return self.log_probabilities.shape[0]
+
+    def log_likelihood_of_nodes_single(self, x: jnp.array) -> jnp.array:
+        return self.log_probabilities.T[x.astype(int)] - self.log_normalization_constant
+
+    @classmethod
+    def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[UnivariateDiscreteLeaf],
+                                                         child_layers: List[NXConverterLayer],
+                                                         progress_bar: bool = True) -> \
+            NXConverterLayer:
+        hash_remap = {hash(node): index for index, node in enumerate(nodes)}
+
+        variable: Symbolic = nodes[0].variable
+
+        parameters = np.zeros((len(nodes), len(variable.domain.simple_sets)))
+
+        for node in (tqdm.tqdm(nodes, desc=f"Creating discrete layer for variable {variable.name}")
+                     if progress_bar else nodes):
+            for state, value in node.distribution.probabilities.items():
+                parameters[hash_remap[hash(node)], state] = value
+
+        result = cls(nodes[0].probabilistic_circuit.variables.index(variable), jnp.log(parameters))
+        return NXConverterLayer(result, nodes, hash_remap)
+
+    def to_json(self) -> Dict[str, Any]:
+        return {**super().to_json(),
+                "variable": self.variable, "log_probabilities": self.log_probabilities.tolist()}
+
+    @classmethod
+    def _from_json(cls, data: Dict[str, Any]) -> Self:
+        return cls(data["variable"], jnp.array(data["log_probabilities"]))
+
+    def to_nx(self, variables: SortedSet[Variable], result: NXProbabilisticCircuit,
+              progress_bar: Optional[tqdm.tqdm] = None) -> List[Unit]:
+
+        variable = variables[self.variable]
+        domain = variable.domain_type()
+
+        if progress_bar:
+            progress_bar.set_postfix_str(f"Creating discrete distributions for variable {variable.name}")
+
+        nodes = [UnivariateDiscreteLeaf(SymbolicDistribution(variable, MissingDict(float,
+                 {domain(state): value for state, value in enumerate(jnp.exp(log_probabilities))})), result)
+                 for log_probabilities in self.normalized_log_probabilities]
+
+        if progress_bar:
+            progress_bar.update(self.number_of_nodes)
+        return nodes
+
+
+
+
+
+
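
A quick sketch of the DiscreteLayer semantics above (illustrative, not part of the patch): each
row holds unnormalized log-weights per state and the layer normalizes per node, so the values
below mirror the ones used in test_discrete_layer.py further down.

    import jax.numpy as jnp
    from probabilistic_model.probabilistic_circuit.jax.discrete_layer import DiscreteLayer

    layer = DiscreteLayer(0, jnp.log(jnp.array([[0., 1., 2.], [3., 4., 0.]])))
    assert jnp.allclose(layer.normalization_constant, jnp.array([3., 7.]))
    # likelihood of state 0 per node: 0/3 and 3/7
    print(jnp.exp(layer.log_likelihood_of_nodes_single(jnp.array([0.]))))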
diff --git a/src/probabilistic_model/probabilistic_circuit/jax/gaussian_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/gaussian_layer.py
index 6a34948..8388e90 100644
--- a/src/probabilistic_model/probabilistic_circuit/jax/gaussian_layer.py
+++ b/src/probabilistic_model/probabilistic_circuit/jax/gaussian_layer.py
@@ -37,11 +37,11 @@ def nx_classes(cls) -> Tuple[Type, ...]:
     def validate(self):
         assert self.location.shape == self.log_scale.shape, "The shapes of location and scale must match."
         assert self.min_scale.shape == self.log_scale.shape, "The shapes of the min_scale and scale bounds must match."
-        assert jnp.all(self.min_scale >= 0), "The scale must be positive."
+        assert jnp.all(self.min_scale >= 0), "The minimum scale must be non-negative."
 
     @property
     def number_of_nodes(self) -> int:
-        return len(self.location.shape)
+        return self.location.shape[0]
 
     @property
     def scale(self) -> jnp.array:
@@ -80,20 +80,18 @@ def _from_json(cls, data: Dict[str, Any]) -> Self:
         return cls(data["variable"], jnp.array(data["location"]), jnp.array(data["scale"]),
                    jnp.array(data["min_scale"]))
 
-    def to_nx(self, variables: SortedSet[Variable], progress_bar: Optional[tqdm.tqdm] = None) -> List[
-        Unit]:
+    def to_nx(self, variables: SortedSet[Variable], result: NXProbabilisticCircuit,
+              progress_bar: Optional[tqdm.tqdm] = None) -> List[Unit]:
         variable = variables[self.variable]
 
         if progress_bar:
             progress_bar.set_postfix_str(f"Creating Gaussian distributions for variable {variable.name}")
 
-        nx_pc = NXProbabilisticCircuit()
         nodes = [UnivariateContinuousLeaf(
-            GaussianDistribution(variable=variable, location=location.item(), scale=scale.item()))
+            GaussianDistribution(variable=variable, location=location.item(), scale=scale.item()), result)
             for location, scale in zip(self.location, self.scale)]
 
         if progress_bar:
             progress_bar.update(self.number_of_nodes)
 
-        nx_pc.add_nodes_from(nodes)
         return nodes
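
The number_of_nodes fix above is worth spelling out (illustrative, not part of the patch):
location holds one entry per node, so for three Gaussian nodes location.shape is (3,),
len(location.shape) is 1 (the array rank), and location.shape[0] is the node count.

    import jax.numpy as jnp

    location = jnp.array([0., 1., 2.])
    assert len(location.shape) == 1   # the old, wrong value: the rank
    assert location.shape[0] == 3     # the new, correct value: the node count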
diff --git a/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py
index 1ad7b9b..c32f899 100644
--- a/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py
+++ b/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py
@@ -20,7 +20,7 @@
 from sortedcontainers import SortedSet
 from typing_extensions import List, Iterator, Tuple, Union, Type, Dict, Any, Self, Optional
 
-from . import shrink_index_array
+from . import shrink_index_array, embed_sparse_array_in_nan_array
 from .utils import copy_bcoo, sample_from_sparse_probabilities_csc, sparse_remove_rows_and_cols_where_all
 from ..nx.probabilistic_circuit import (SumUnit, ProductUnit, Unit,
                                         ProbabilisticCircuit as NXProbabilisticCircuit)
@@ -58,30 +58,6 @@ def log_likelihood_of_nodes(self, x: jnp.array) -> jnp.array:
         """
         return jax.vmap(self.log_likelihood_of_nodes_single)(x)
 
-    def cdf_of_nodes_single(self, x: jnp.array) -> jnp.array:
-        """
-        Calculate the cumulative distribution function of the distribution if applicable.
-
-        :param x: The input vector.
-        :return: The cumulative distribution function of every node in the layer for x.
-        """
-        raise NotImplementedError
-
-    def cdf_of_nodes(self, x: jnp.array) -> jnp.array:
-        """
-        Vectorized version of :meth:`cdf_of_nodes_single`
-        """
-        return jax.vmap(self.cdf_of_nodes_single)(x)
-
-    def probability_of_simple_event(self, event: SimpleEvent) -> jnp.array:
-        """
-        Calculate the probability of a simple event P(E).
-
-        :param event: The simple event to calculate the probability for. It has to contain every variable.
-        :return: P(E)
-        """
-        raise NotImplementedError
-
     def validate(self):
         """
         Validate the parameters and their layouts.
@@ -107,7 +83,7 @@ def all_layers_with_depth(self, depth: int = 0) -> List[Tuple[int, Layer]]:
         """
         return [(depth, self)]
 
-    @property
+    @cached_property
     @abstractmethod
     def variables(self) -> jax.Array:
         """
@@ -128,8 +104,8 @@ def nx_classes(cls) -> Tuple[Type, ...]:
         """
         return tuple()
 
-    def to_nx(self, variables: SortedSet[Variable], progress_bar: Optional[tqdm.tqdm] = None) -> List[
-        Unit]:
+    def to_nx(self, variables: SortedSet[Variable], result: NXProbabilisticCircuit,
+              progress_bar: Optional[tqdm.tqdm] = None,) -> List[Unit]:
         """
         Convert the layer to a networkx circuit.
         For every node in this circuit, a corresponding node in the networkx circuit
@@ -137,25 +113,10 @@ is created.
         The nodes all belong to the same circuit.
 
         :param variables: The variables of the circuit.
+        :param result: The resulting circuit to write into
         :param progress_bar: A progress bar to show the progress.
-        :return: The nodes of the networkx circuit.
-        """
-        raise NotImplementedError
-
-    @property
-    def impossible_condition_result(self) -> Tuple[None, jax.Array]:
-        """
-        :return: The result that a layer yields if it is conditioned on an event E with P(E) = 0
-        """
-        return None, jnp.full((self.number_of_nodes,), -jnp.inf, dtype=jnp.float32)
-
-    def log_conditional_of_simple_event(self, event: SimpleEvent, ) -> Tuple[Optional[Self], jax.Array]:
-        """
-        Calculate the conditional probability distribution given a simple event P(X|E).
-        Also return the log probability of E log(P(E)).
-
-        :param event: The event to calculate the conditional distribution for.
-        :return: The conditional distribution and the log probability of the event.
+        :return: The nodes of the networkx circuit.
         """
         raise NotImplementedError
 
@@ -183,6 +144,7 @@ def create_layers_from_nodes(nodes: List[Unit], child_layers: List[NXConverterLayer],
         for scope in unique_scopes:
             nodes_of_current_type_and_scope = [node for node in nodes_of_current_type if
                                                tuple(node.variables) == scope]
+
             layer = layer_type.create_layer_from_nodes_with_same_type_and_scope(nodes_of_current_type_and_scope,
                                                                                 child_layers, progress_bar)
             result.append(layer)
@@ -219,42 +181,12 @@ def number_of_trainable_parameters(self):
         return number_of_parameters
 
     @property
-    def number_of_components(self) -> 0:
+    def number_of_components(self) -> int:
         """
         :return: The number of components (leaves + edges) of the entire circuit
         """
         return self.number_of_nodes
 
-    def sample_from_frequencies(self, frequencies: np.array, result: np.array, start_index=0):
-        raise NotImplementedError
-
-    def moment_of_nodes(self, order: jax.Array, center: jax.Array):
-        """
-        Calculate the moment of the nodes.
-        The order and center vectors describe the moments for all variables in the entire model. Hence, they should
-        never be touched by the forward pass.
-
-        :param order: The order of the moment for each variable.
-        :param center: The center of the moment for each variable.
-        :return: The moments of the nodes with shape (#nodes, #variables).
-        """
-        raise NotImplementedError
-
-    def merge_with(self, others: List[Self]) -> Self:
-        """
-        Merge the layer with others of the same type.
-        """
-        raise NotImplementedError
-
-    def remove_nodes(self, remove_mask: jax.Array) -> Self:
-        """
-        Remove nodes from the layer.
-
-        :param remove_mask: A boolean mask of the nodes to remove.
-        :return: The layer with the nodes removed.
-        """
-        raise NotImplementedError
-
 
 class InnerLayer(Layer, ABC):
     """
@@ -270,7 +202,7 @@ def __init__(self, child_layers: List[Layer]):
         super().__init__()
         self.child_layers = child_layers
 
-    @property
+    @cached_property
     @abstractmethod
     def variables(self) -> jax.Array:
         raise NotImplementedError
@@ -298,12 +230,6 @@ def to_json(self) -> Dict[str, Any]:
         result["child_layers"] = [child_layer.to_json() for child_layer in self.child_layers]
         return result
 
-    def clean_up_orphans(self) -> Self:
-        """
-        Clean up the layer by removing orphans in the child layers.
-        """
-        raise NotImplementedError
-
 
 class InputLayer(Layer, ABC):
     """
@@ -322,7 +248,7 @@ def __init__(self, variable: int):
         super().__init__()
         self._variables = jnp.array([variable])
 
-    @property
+    @cached_property
    def variables(self) -> jax.Array:
         return self._variables
 
@@ -337,6 +263,7 @@ def variable(self):
 
 
 class SumLayer(InnerLayer):
+
     log_weights: List[BCOO]
     child_layers: Union[List[[ProductLayer]], List[InputLayer]]
 
@@ -412,168 +339,10 @@ def log_likelihood_of_nodes_single(self, x: jax.Array) -> jax.Array:
             # multiply the weights with the child layer likelihood
             cloned_log_weights.data += child_layer_log_likelihood[cloned_log_weights.indices[:, 1]]
             cloned_log_weights.data = jnp.exp(cloned_log_weights.data)  # exponent weights
+            result = result.at[cloned_log_weights.indices[:, 0]].add(cloned_log_weights.data,
+                                                                     indices_are_sorted=False, unique_indices=False)
 
-            # sum the weights for each node
-            ll = cloned_log_weights.sum(1).todense()
-
-            # sum the child layer result
-            result += ll
-
-        return jnp.where(result > 0, jnp.log(result) - self.log_normalization_constants, -jnp.inf)
-
-    def cdf_of_nodes_single(self, x: jnp.array) -> jnp.array:
-        result = jnp.zeros(self.number_of_nodes, dtype=jnp.float32)
-
-        for log_weights, child_layer in self.log_weighted_child_layers:
-            # get the cdf of the child nodes
-            child_layer_cdf = child_layer.cdf_of_nodes_single(x)
-
-            # weight the cdf of the child nodes by the weight for each node of this layer
-            cloned_log_weights = copy_bcoo(log_weights)  # clone the weights
-
-            # multiply the weights with the child layer cdf
-            cloned_log_weights.data = jnp.exp(cloned_log_weights.data)  # exponent weights
-            cloned_log_weights.data *= child_layer_cdf[cloned_log_weights.indices[:, 1]]
-
-            # sum the weights for each node
-            ll = cloned_log_weights.sum(1).todense()
-
-            # sum the child layer result
-            result += ll
-
-        # normalize the result
-        normalization_constants = jnp.exp(self.log_normalization_constants)
-        return result / normalization_constants
-
-    def probability_of_simple_event(self, event: SimpleEvent) -> jnp.array:
-        result = jnp.zeros(self.number_of_nodes, dtype=jnp.float32)
-
-        for log_weights, child_layer in self.log_weighted_child_layers:
-            # get the probability of the child nodes
-            child_layer_prob = child_layer.probability_of_simple_event(event)
-
-            # weight the probability of the child nodes by the weight for each node of this layer
-            cloned_log_weights = copy_bcoo(log_weights)  # clone the weights
-
-            # multiply the weights with the child layer cdf
-            cloned_log_weights.data = jnp.exp(cloned_log_weights.data)  # exponent weights
-            cloned_log_weights.data *= child_layer_prob[cloned_log_weights.indices[:, 1]]
-
-            # sum the weights for each node
-            ll = cloned_log_weights.sum(1).todense()
-
-            # sum the child layer result
-            result += ll
-
-        # normalize the result
-        normalization_constants = jnp.exp(self.log_normalization_constants)
-        return result / normalization_constants
-
-    def moment_of_nodes(self, order: jax.Array, center: jax.Array):
-        result = jnp.zeros((self.number_of_nodes, len(self.variables)), dtype=jnp.float32)
-
-        for log_weights, child_layer in self.log_weighted_child_layers:
-            # get the moment of the child nodes
-            moment = child_layer.moment_of_nodes(order, center)  # shape (#child_layer_nodes, #variables)
-
-            # weight the moment of the child nodes by the weight for each node of this layer
-            weights = copy_bcoo(log_weights)  # clone the weights, shape (#nodes, #child_layer_nodes)
-            weights.data = jnp.exp(weights.data)  # exponent weights
-
-            # calculate the weighted sum in layer
-            moment = weights @ moment
-
-            # sum the child layer result
-            result += moment
-
-        return result / jnp.exp(self.log_normalization_constants.reshape(-1, 1))
-
-    def sample_from_frequencies(self, frequencies: np.array, result: np.array, start_index=0):
-        node_to_child_frequency_map = self.node_to_child_frequency_map(frequencies)
-
-        # offset for shifting through the frequencies of the node_to_child_frequency_map
-        prev_column_index = 0
-
-        consumed_indices = start_index
-
-        for child_layer in self.child_layers:
-            # extract the frequencies for the child layer
-            current_frequency_block = node_to_child_frequency_map[:,
-                                      prev_column_index:prev_column_index + child_layer.number_of_nodes]
-            frequencies_for_child_nodes = current_frequency_block.sum(0)
-            child_layer.sample_from_frequencies(frequencies_for_child_nodes, result, consumed_indices)
-            consumed_indices += frequencies_for_child_nodes.sum()
-
-            # shift the offset
-            prev_column_index += child_layer.number_of_nodes
-
-    def node_to_child_frequency_map(self, frequencies: np.array):
-        """
-        Sample from the exact distribution of the layer by interpreting every node as latent variable.
-        This is very slow due to BCOO.sum_duplicates being very slow.
-
-        :param frequencies:
-        :param key:
-        :return:
-        """
-        clw = self.normalized_weights
-        csr = coo_matrix((clw.data, clw.indices.T), shape=clw.shape).tocsr(copy=False)
-        return sample_from_sparse_probabilities_csc(csr, frequencies)
-
-    def log_conditional_of_simple_event(self, event: SimpleEvent, ) -> Tuple[Optional[Self], jax.Array]:
-        conditional_child_layers = []
-        conditional_log_weights = []
-
-        probabilities = jnp.zeros(self.number_of_nodes, dtype=jnp.float32)
-
-        for log_weights, child_layer in self.log_weighted_child_layers:
-
-            # get the conditional of the child layer
-            conditional, child_log_prob = child_layer.log_conditional_of_simple_event(event)
-            if conditional is None:
-                continue
-
-            # clone weights
-            log_weights = copy_bcoo(log_weights)
-
-            # calculate the weighted sum of the child log probabilities
-            log_weights.data += child_log_prob[log_weights.indices[:, 1]]
-
-            # skip if this layer is not connected to anything anymore
-            if jnp.all(log_weights.data == -jnp.inf):
-                continue
-
-            log_weights.data = jnp.exp(log_weights.data)
-
-            # calculate the probabilities of the child nodes in total
-            current_probabilities = log_weights.sum(1).todense()
-            probabilities += current_probabilities
-
-            log_weights.data = jnp.log(log_weights.data)
-
-            conditional_child_layers.append(conditional)
-            conditional_log_weights.append(log_weights)
-
-        if len(conditional_child_layers) == 0:
-            return self.impossible_condition_result
-
-        log_probabilities = jnp.log(probabilities)
-
-        concatenated_log_weights = bcoo_concatenate(conditional_log_weights, dimension=1).sort_indices()
-        # remove rows and columns where all weights are -inf
-        cleaned_log_weights = sparse_remove_rows_and_cols_where_all(concatenated_log_weights, -jnp.inf)
-
-        # normalize the weights
-        z = cleaned_log_weights.sum(1).todense()
-        cleaned_log_weights.data -= z[cleaned_log_weights.indices[:, 0]]
-
-        # slice the weights for each child layer
-        log_weight_slices = jnp.array([0] + [ccl.number_of_nodes for ccl in conditional_child_layers])
-        log_weight_slices = jnp.cumsum(log_weight_slices)
-        conditional_log_weights = [cleaned_log_weights[:, log_weight_slices[i]:log_weight_slices[i + 1]].sort_indices()
-                                   for i in range(len(conditional_child_layers))]
-
-        resulting_layer = SumLayer(conditional_child_layers, conditional_log_weights)
-        return resulting_layer, (log_probabilities - self.log_normalization_constants)
+        return jnp.log(result) - self.log_normalization_constants
 
     def __deepcopy__(self):
         child_layers = [child_layer.__deepcopy__() for child_layer in self.child_layers]
@@ -632,46 +401,27 @@ def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[SumUnit],
         sum_layer = cls([cl.layer for cl in filtered_child_layers], log_weights)
         return NXConverterLayer(sum_layer, nodes, result_hash_remap)
 
-    def remove_nodes(self, remove_mask: jax.Array) -> Self:
-        new_log_weights = [lw[~remove_mask] for lw in self.log_weights]
-        return self.__class__(self.child_layers, new_log_weights)
-
-    def clean_up_orphans(self) -> Self:
-        raise NotImplementedError
-
-    def to_nx(self, variables: SortedSet[Variable], progress_bar: Optional[tqdm.tqdm] = None) -> List[
-        Unit]:
+    def to_nx(self, variables: SortedSet[Variable], result: NXProbabilisticCircuit,
+              progress_bar: Optional[tqdm.tqdm] = None) -> List[Unit]:
 
         variables_ = [variables[i] for i in self.variables]
 
         if progress_bar:
             progress_bar.set_postfix_str(f"Parsing Sum Layer for variables {variables_}")
 
-        nx_pc = NXProbabilisticCircuit()
-        units = [SumUnit() for _ in range(self.number_of_nodes)]
-        nx_pc.add_nodes_from(units)
-
-        child_layer_nx = [cl.to_nx(variables, progress_bar) for cl in self.child_layers]
+        units = [SumUnit(result) for _ in range(self.number_of_nodes)]
 
-        clw = self.normalized_weights
-        csc_weights = coo_matrix((clw.data, clw.indices.T), shape=clw.shape).tocsc(copy=False)
+        child_layer_nx = [cl.to_nx(variables, result, progress_bar) for cl in self.child_layers]
 
-        # offset for shifting through the frequencies of the node_to_child_frequency_map
-        prev_column_index = 0
+        for log_weights, child_layer in zip(self.log_weights, child_layer_nx):
 
-        for child_layer in child_layer_nx:
             # extract the weights for the child layer
-            current_weight_block: csc_array = csc_weights[:, prev_column_index:prev_column_index + len(child_layer)]
-            current_weight_block: coo_array = current_weight_block.tocoo(False)
-
-            for row, col, weight in zip(current_weight_block.row, current_weight_block.col, current_weight_block.data):
-                units[row].add_subcircuit(child_layer[col], weight)
-
+            for ((row, col), log_weight) in zip(log_weights.indices, log_weights.data):
+                units[row].add_subcircuit(child_layer[col], jnp.exp(log_weight).item(), False)
                 if progress_bar:
                     progress_bar.update()
 
-            # shift the offset
-            prev_column_index += len(child_layer)
+        [unit.normalize() for unit in units]
 
         return units
@@ -699,6 +449,8 @@ class ProductLayer(InnerLayer):
     The shape is (#child_layers, #nodes).
     """
 
+    _variables: Optional[jnp.array] = None
+
     def __init__(self, child_layers: List[Layer], edges: BCOO):
         """
         Initialize the product layer.
@@ -708,6 +460,7 @@ def __init__(self, child_layers: List[Layer], edges: BCOO):
         """
         super().__init__(child_layers)
         self.edges = edges
+        self.variables
 
     def validate(self):
         assert self.edges.shape == (len(self.child_layers), self.number_of_nodes), \
@@ -726,13 +479,13 @@ def nx_classes(cls) -> Tuple[Type, ...]:
     def number_of_components(self) -> int:
         return sum([cl.number_of_components for cl in self.child_layers]) + self.edges.nse
 
-    @cached_property
+    @property
     def variables(self) -> jax.Array:
-        child_layer_variables = jnp.concatenate([child_layer.variables for child_layer in self.child_layers])
-        max_size = child_layer_variables.shape[0]
-        unique_values = jnp.unique(child_layer_variables, size=max_size, fill_value=-1)
-        unique_values = unique_values[unique_values >= 0]
-        return unique_values.sort()
+        if self._variables is None:
+            variables = jnp.concatenate([child_layer.variables for child_layer in self.child_layers])
+            variables = jnp.unique(variables)
+            self._variables = variables
+        return self._variables
 
     def log_likelihood_of_nodes_single(self, x: jax.Array) -> jax.Array:
         result = jnp.zeros(self.number_of_nodes, dtype=jnp.float32)
@@ -749,62 +502,6 @@ def log_likelihood_of_nodes_single(self, x: jax.Array) -> jax.Array:
 
         return result
 
-    def cdf_of_nodes_single(self, x: jnp.array) -> jnp.array:
-        result = jnp.ones(self.number_of_nodes, dtype=jnp.float32)
-
-        for edges, layer in zip(self.edges, self.child_layers):
-            # calculate the cdf over the columns of the child layer
-            cdf = layer.cdf_of_nodes_single(x[layer.variables])  # shape: #child_nodes
-
-            # gather the cdf at the indices of the nodes that are required for the edges
-            cdf = cdf[edges.data]  # shape: #len(edges.values())
-
-            # multiply the gathered values by the result where the edges define the indices
-            result = result.at[edges.indices[:, 0]].mul(cdf)
-
-        return result
-
-    def probability_of_simple_event(self, event: SimpleEvent) -> jnp.array:
-        result = jnp.ones(self.number_of_nodes, dtype=jnp.float32)
-
-        for edges, layer in zip(self.edges, self.child_layers):
-            # calculate the cdf over the columns of the child layer
-            prob = layer.probability_of_simple_event(event)  # shape: #child_nodes
-
-            # gather the cdf at the indices of the nodes that are required for the edges
-            prob = prob[edges.data]  # shape: #len(edges.values())
-
-            # multiply the gathered values by the result where the edges define the indices
-            result = result.at[edges.indices[:, 0]].mul(prob)
-
-        return result
-
-    def sample_from_frequencies(self, frequencies: np.array, result: np.array, start_index=0):
-        edges_csr = coo_array((self.edges.data, self.edges.indices.T), shape=self.edges.shape).tocsr()
-        for row_index, (start, end, child_layer) in enumerate(
-                zip(edges_csr.indptr[:-1], edges_csr.indptr[1:], self.child_layers)):
-            # get the edges for the current child layer
-            row = edges_csr.data[start:end]
-            column_indices = edges_csr.indices[start:end]
-
-            frequencies_for_child_layer = np.zeros((child_layer.number_of_nodes,), dtype=np.int32)
-            frequencies_for_child_layer[row] = frequencies[column_indices]
-
-            child_layer.sample_from_frequencies(frequencies_for_child_layer, result, start_index)
-
-    def moment_of_nodes(self, order: jax.Array, center: jax.Array):
-        result = jnp.full((self.number_of_nodes, self.variables.shape[0]), jnp.nan)
-        for edges, layer in zip(self.edges, self.child_layers):
-            edges = edges.sum_duplicates(remove_zeros=False)
-
-            # calculate the moments over the columns of the child layer
-            child_layer_moment = layer.moment_of_nodes(order, center)
-
-            # gather the moments at the indices of the nodes that are required for the edges
-            result = result.at[edges.indices[:, 0], layer.variables].set(child_layer_moment[edges.data][:, 0])
-
-        return result
-
     def __deepcopy__(self):
         child_layers = [child_layer.__deepcopy__() for child_layer in self.child_layers]
         edges = copy_bcoo(self.edges)
@@ -815,102 +512,6 @@ def to_json(self) -> Dict[str, Any]:
         result["edges"] = (self.edges.data.tolist(), self.edges.indices.tolist(), self.edges.shape)
         return result
 
-    def log_conditional_of_simple_event(self, event: SimpleEvent, ) -> Tuple[Optional[Self], jax.Array]:
-
-        # initialize the conditional child layers and the log probabilities
-        log_probabilities = jnp.zeros(self.number_of_nodes, dtype=jnp.float32)
-        conditional_child_layers = []
-        remapped_edges = []
-
-        # for edge bundle and child layer
-        for index, (edges, child_layer) in enumerate(zip(self.edges, self.child_layers)):
-            edges: BCOO
-            edges = edges.sum_duplicates(remove_zeros=False)
-
-            # condition the child layer
-            conditional, child_log_prob = child_layer.log_conditional_of_simple_event(event)
-
-            # if it is entirely impossible, this layer also is
-            if conditional is None:
-                continue
-
-            # update the log probabilities and child layers
-            log_probabilities = log_probabilities.at[edges.indices[:, 0]].add(child_log_prob[edges.data])
-            conditional_child_layers.append(conditional)
-
-            # create the remapping of the node indices. nan indicates the node got deleted
-            # enumerate the indices of the conditional child layer nodes
-            new_node_indices = jnp.arange(conditional.number_of_nodes)
-
-            # initialize the remapping of the child layer node indices
-            layer_remap = jnp.full((child_layer.number_of_nodes,), jnp.nan, dtype=jnp.float32)
-            layer_remap = layer_remap.at[child_log_prob > -jnp.inf].set(new_node_indices)
-
-            # update the edges
-            remapped_child_edges = layer_remap[edges.data]
-            valid_edges = ~jnp.isnan(remapped_child_edges)
-
-            # create new indices for the edges
-            new_indices = edges.indices[valid_edges]
-            new_indices = jnp.concatenate([jnp.zeros((len(new_indices), 1), dtype=jnp.int32), new_indices],
-                                          axis=1)
-
-            new_edges = BCOO((remapped_child_edges[valid_edges].astype(jnp.int32),
-                              new_indices),
-                             shape=(1, self.number_of_nodes), indices_sorted=True,
-                             unique_indices=True)
-            remapped_edges.append(new_edges)
-
-        remapped_edges = bcoo_concatenate(remapped_edges, dimension=0).sort_indices()
-
-        # get nodes that should be removed as boolean mask
-        remove_mask = log_probabilities == -jnp.inf  # shape (#nodes, )
-        keep_mask = ~remove_mask
-
-        # remove the nodes that have -inf log probabilities from remapped_edges
-        remapped_edges = coo_array((remapped_edges.data, remapped_edges.indices.T), shape=remapped_edges.shape).tocsc()
-        remapped_edges = remapped_edges[:, keep_mask].tocoo()
-        remapped_edges = BCOO((remapped_edges.data, jnp.stack((remapped_edges.row, remapped_edges.col)).T),
-                              shape=remapped_edges.shape, indices_sorted=True, unique_indices=True)
-
-        # construct result and clean it up
-        result = self.__class__(conditional_child_layers, remapped_edges)
-        result = result.clean_up_orphans()
-        return result, log_probabilities
-
-    def clean_up_orphans(self):
-        """
-        Clean up the layer by removing orphans in the child layers.
-        """
-        new_child_layers = []
-
-        for index, (edges, child_layer) in enumerate(zip(self.edges, self.child_layers)):
-            edges: BCOO
-            edges = edges.sum_duplicates(remove_zeros=False)
-            # mask rather nodes have parent edges or not
-            orphans = jnp.ones(child_layer.number_of_nodes, dtype=jnp.bool)
-
-            # mark nodes that have parents with False
-            data = edges.data
-            if len(data) > 0:
-                orphans = orphans.at[data].set(False)
-
-            # if orphans exist
-            if orphans.any():
-                # remove them from the child layer
-                child_layer = child_layer.remove_nodes(orphans)
-            new_child_layers.append(child_layer)
-
-        # compress edges
-        shrunken_indices = shrink_index_array(self.edges.indices)
-        new_edges = BCOO((self.edges.data, shrunken_indices), shape=self.edges.shape, indices_sorted=True,
-                         unique_indices=True)
-        return self.__class__(new_child_layers, new_edges)
-
-    def remove_nodes(self, remove_mask: jax.Array) -> Self:
-        new_edges = self.edges[:, ~remove_mask]
-        return self.__class__(self.child_layers, new_edges)
-
     @classmethod
     def _from_json(cls, data: Dict[str, Any]) -> Self:
         child_layer = [Layer.from_json(child_layer) for child_layer in data["child_layers"]]
@@ -955,21 +556,22 @@ def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[Unit],
         layer = cls([cl.layer for cl in child_layers], edges)
         return NXConverterLayer(layer, nodes, hash_remap)
 
-    def to_nx(self, variables: SortedSet[Variable], progress_bar: Optional[tqdm.tqdm] = None) -> List[
-        Unit]:
+    def to_nx(self, variables: SortedSet[Variable], result: NXProbabilisticCircuit,
+              progress_bar: Optional[tqdm.tqdm] = None) -> List[Unit]:
+
+        if result is None:
+            result = NXProbabilisticCircuit()
 
         variables_ = [variables[i] for i in self.variables]
 
         if progress_bar:
             progress_bar.set_postfix_str(f"Parsing Product Layer of variables {variables_}")
 
-        nx_pc = NXProbabilisticCircuit()
-        units = [ProductUnit() for _ in range(self.number_of_nodes)]
-        nx_pc.add_nodes_from(units)
-
-        child_layer_nx = [cl.to_nx(variables, progress_bar) for cl in self.child_layers]
+        units = [ProductUnit(result) for _ in range(self.number_of_nodes)]
+        child_layer_nx = [cl.to_nx(variables, result, progress_bar) for cl in self.child_layers]
 
         for (row, col), data in zip(self.edges.indices, self.edges.data):
-            units[col].add_subcircuit(child_layer_nx[row][data])
+            units[col].add_subcircuit(child_layer_nx[row][data], mount=False)
+
             if progress_bar:
                 progress_bar.update()
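
The rewritten SumLayer.log_likelihood_of_nodes_single above replaces the dense row-sum with a
sparse scatter-add over the COO weight entries; dropping the jnp.where guard appears safe for
forward evaluation, since jnp.log(0.) already yields -inf. A scalar sketch of the weighted
mixture it computes (illustrative, not part of the patch):

    import jax.numpy as jnp

    log_w = jnp.log(jnp.array([0.4, 0.6]))     # one node's (normalized) weight row
    child_ll = jnp.log(jnp.array([0.1, 0.2]))  # child log-likelihoods at x
    node_ll = jnp.log(jnp.sum(jnp.exp(log_w + child_ll)))
    assert jnp.isclose(node_ll, jnp.log(0.4 * 0.1 + 0.6 * 0.2))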
""" - def probability_of_simple_event(self, event: SimpleEvent) -> jax.Array: - interval: Interval = list(event.values())[self.variables[0]] - return self.probability_of_interval(interval) - - def probability_of_interval(self, interval: Interval) -> jnp.array: - points = jnp.array([simple_interval_to_open_array(i) for i in interval.simple_sets]) - upper_bound_cdf = self.cdf_of_nodes(points[:, (1,)]) - lower_bound_cdf = self.cdf_of_nodes(points[:, (0,)]) - return (upper_bound_cdf - lower_bound_cdf).sum(axis=0) - - def probability_of_simple_interval(self, interval: SimpleInterval) -> jax.Array: - points = simple_interval_to_open_array(interval) - upper_bound_cdf = self.cdf_of_nodes_single(points[1]) - lower_bound_cdf = self.cdf_of_nodes_single(points[0]) - return upper_bound_cdf - lower_bound_cdf - - def log_conditional_of_simple_event(self, event: SimpleEvent) -> Tuple[ - Optional[Union[Self, DiracDeltaLayer]], jax.Array]: - if event.is_empty(): - return self.impossible_condition_result - - interval: Interval = list(event.values())[self.variable] - - if interval.is_singleton(): - return self.log_conditional_from_singleton(interval.simple_sets[0]) - - if len(interval.simple_sets) == 1: - return self.log_conditional_from_simple_interval(interval.simple_sets[0]) - else: - return self.log_conditional_from_interval(interval) - - def log_conditional_from_singleton(self, singleton: SimpleInterval) -> Tuple[DiracDeltaLayer, jax.Array]: - """ - Calculate the conditional distribution given a singleton interval. - - In this case, the conditional distribution is a Dirac delta distribution and the log-likelihood is chosen - instead of the log-probability. - - This method returns a Dirac delta layer that has at most the same number of nodes as the input layer. - - :param singleton: The singleton event - :return: The dirac delta layer and the log-likelihoods with shape (something <= #singletons, 1). - """ - value = singleton.lower - log_likelihoods = self.log_likelihood_of_nodes( - jnp.array(value).reshape(-1, 1))[:, 0] # shape: (#nodes, ) - - possible_indices = (log_likelihoods > -jnp.inf).nonzero()[0] # shape: (#dirac-nodes, ) - filtered_likelihood = log_likelihoods[possible_indices] - locations = jnp.full_like(filtered_likelihood, value) - layer = DiracDeltaLayer(self.variable, locations, jnp.exp(filtered_likelihood)) - return layer, log_likelihoods - - def log_conditional_from_simple_interval(self, interval: SimpleInterval) -> Tuple[Self, jax.Array]: - """ - Calculate the conditional distribution given a simple interval with p(interval) > 0. - The interval could also be a singleton. - - :param interval: The simple interval - :return: The conditional distribution and the log-probability of the interval. - """ - raise NotImplementedError - - def log_conditional_from_interval(self, interval: Interval) -> Tuple[SumLayer, jax.Array]: - """ - Calculate the conditional distribution given an interval with p(interval) > 0. - - :param interval: The simple interval - :return: The conditional distribution and the log-probability of the interval. 
- """ - - # get conditionals of each simple interval - results = [self.log_conditional_from_simple_interval(simple_interval) for simple_interval in - interval.simple_sets] - - layers, log_probs = zip(*results) - - # stack the log probabilities - stacked_log_probabilities = jnp.stack(log_probs, axis=1) # shape: (#simple_intervals, #nodes) - - # calculate the log probabilities of the entire interval - exp_stacked_log_probabilities = jnp.exp(stacked_log_probabilities) - summed_exp_stacked_log_probabilities = jnp.sum(exp_stacked_log_probabilities, axis=1) - total_log_probabilities = jnp.log(summed_exp_stacked_log_probabilities) # shape: (#nodes, 1) - - # create new input layer - possible_layers = [layer for layer in layers if layer is not None] - input_layer = possible_layers[0] - input_layer = input_layer.merge_with(possible_layers[1:]) - - # remove the rows that are entirely -inf and normalize weights - bcoo_data = remove_rows_and_cols_where_all(exp_stacked_log_probabilities / - summed_exp_stacked_log_probabilities.reshape(-1, 1), - 0) - - log_weights = BCOO.fromdense(bcoo_data) - log_weights.data = jnp.log(log_weights.data) - - resulting_layer = SumLayer([input_layer], [log_weights]) - return resulting_layer, total_log_probabilities - - class ContinuousLayerWithFiniteSupport(ContinuousLayer, ABC): """ Abstract class for continuous univariate input units with finite support. @@ -185,9 +83,6 @@ def to_json(self) -> Dict[str, Any]: def __deepcopy__(self): return self.__class__(self.variables[0].item(), self.interval.copy()) - def remove_nodes(self, remove_mask: jax.Array) -> Self: - return self.__class__(self.variable, self.interval[~remove_mask]) - class DiracDeltaLayer(ContinuousLayer): location: jax.Array = eqx.field(static=True) @@ -234,35 +129,6 @@ def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[Univariate result = cls(nodes[0].probabilistic_circuit.variables.index(nodes[0].variable), locations, density_caps) return NXConverterLayer(result, nodes, hash_remap) - def sample_from_frequencies(self, frequencies: np.array, result: np.array, start_index=0): - values = self.location.repeat(frequencies).reshape(-1, 1) - result[start_index:start_index + len(values), self.variables] = values - - def cdf_of_nodes_single(self, x: jnp.array) -> jnp.array: - return jnp.where(x < self.location, 0., 1.) 
- - def moment_of_nodes(self, order: jax.Array, center: jax.Array): - order = order[self.variables[0]] - center = center[self.variables[0]] - if order == 0: - result = jnp.ones(self.number_of_nodes) - elif order == 1: - result = self.location - center - else: - result = jnp.zeros(self.number_of_nodes) - return result.reshape(-1, 1) - - def log_conditional_from_simple_interval(self, interval: SimpleInterval) -> Tuple[Self, jax.Array]: - log_probs = jnp.log(self.probability_of_simple_interval(interval)) - - valid_log_probs = log_probs > -jnp.inf - - if not valid_log_probs.any(): - return self.impossible_condition_result - - result = self.__class__(self.variable, self.location[valid_log_probs], - self.density_cap[valid_log_probs]) - return result, log_probs def to_json(self) -> Dict[str, Any]: result = super().to_json() @@ -274,24 +140,16 @@ def to_json(self) -> Dict[str, Any]: def _from_json(cls, data: Dict[str, Any]) -> Self: return cls(data["variable"], jnp.array(data["location"]), jnp.array(data["density_cap"])) - def merge_with(self, others: List[Self]) -> Self: - return self.__class__(self.variable, jnp.concatenate([self.location] + [other.location for other in others]), - jnp.concatenate([self.density_cap] + [other.density_cap for other in others])) - - def remove_nodes(self, remove_mask: jax.Array) -> Self: - return self.__class__(self.variable, self.location[~remove_mask], self.density_cap[~remove_mask]) - - def to_nx(self, variables: SortedSet[Variable], progress_bar: Optional[tqdm.tqdm] = None) -> List[ + def to_nx(self, variables: SortedSet[Variable], result: NXProbabilisticCircuit, + progress_bar: Optional[tqdm.tqdm] = None) -> List[ Unit]: - nx_pc = NXProbabilisticCircuit() variable = variables[self.variable] if progress_bar: progress_bar.set_postfix_str(f"Creating Dirac Delta distributions for variable {variable.name}") - nodes = [UnivariateContinuousLeaf(DiracDeltaDistribution(variable, location.item(), density_cap.item())) + nodes = [UnivariateContinuousLeaf(DiracDeltaDistribution(variable, location.item(), density_cap.item()), result) for location, density_cap in zip(self.location, self.density_cap)] progress_bar.update(self.number_of_nodes) - nx_pc.add_nodes_from(nodes) return nodes diff --git a/src/probabilistic_model/probabilistic_circuit/jax/probabilistic_circuit.py b/src/probabilistic_model/probabilistic_circuit/jax/probabilistic_circuit.py index 0bae6ab..a36f34d 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/probabilistic_circuit.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/probabilistic_circuit.py @@ -40,14 +40,6 @@ def __init__(self, variables: SortedSet, root: Layer): def log_likelihood(self, x: jax.Array) -> jax.Array: return self.root.log_likelihood_of_nodes(x)[:, 0] - def sample(self, amount: int) -> np.array: - result_array = np.full((amount, len(self.variables)), np.nan) - self.root.sample_from_frequencies(np.array([amount]), result_array) - return result_array - - def probability_of_simple_event(self, event: SimpleEvent): - return self.root.probability_of_simple_event(event) - @classmethod def from_nx(cls, pc: NXProbabilisticCircuit, progress_bar: bool = False) -> ProbabilisticCircuit: """ @@ -89,7 +81,9 @@ def to_nx(self, progress_bar: bool = True) -> NXProbabilisticCircuit: progress_bar = tqdm.tqdm(total=number_of_edges, desc="Converting to nx") else: progress_bar = None - return self.root.to_nx(self.variables, progress_bar)[0].probabilistic_circuit + result = NXProbabilisticCircuit() + self.root.to_nx(self.variables, result, 
diff --git a/src/probabilistic_model/probabilistic_circuit/jax/probabilistic_circuit.py b/src/probabilistic_model/probabilistic_circuit/jax/probabilistic_circuit.py
index 0bae6ab..a36f34d 100644
--- a/src/probabilistic_model/probabilistic_circuit/jax/probabilistic_circuit.py
+++ b/src/probabilistic_model/probabilistic_circuit/jax/probabilistic_circuit.py
@@ -40,14 +40,6 @@ def __init__(self, variables: SortedSet, root: Layer):
     def log_likelihood(self, x: jax.Array) -> jax.Array:
         return self.root.log_likelihood_of_nodes(x)[:, 0]
 
-    def sample(self, amount: int) -> np.array:
-        result_array = np.full((amount, len(self.variables)), np.nan)
-        self.root.sample_from_frequencies(np.array([amount]), result_array)
-        return result_array
-
-    def probability_of_simple_event(self, event: SimpleEvent):
-        return self.root.probability_of_simple_event(event)
-
     @classmethod
     def from_nx(cls, pc: NXProbabilisticCircuit, progress_bar: bool = False) -> ProbabilisticCircuit:
         """
@@ -89,7 +81,9 @@ def to_nx(self, progress_bar: bool = True) -> NXProbabilisticCircuit:
             progress_bar = tqdm.tqdm(total=number_of_edges, desc="Converting to nx")
         else:
             progress_bar = None
-        return self.root.to_nx(self.variables, progress_bar)[0].probabilistic_circuit
+        result = NXProbabilisticCircuit()
+        self.root.to_nx(self.variables, result, progress_bar)
+        return result
 
     def to_json(self) -> Dict[str, Any]:
         result = super().to_json()
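
With the new plumbing above, to_nx writes all units into one caller-owned circuit instead of
creating throwaway NXProbabilisticCircuit instances per layer. A round-trip sketch (illustrative,
not part of the patch; `jax_model` stands for any JAX circuit, e.g. one built by
RegionGraph.as_probabilistic_circuit):

    nx_model = jax_model.to_nx(progress_bar=False)       # fills a fresh NXProbabilisticCircuit
    jax_model_again = type(jax_model).from_nx(nx_model)  # and back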
diff --git a/src/probabilistic_model/probabilistic_circuit/jax/uniform_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/uniform_layer.py
index d392a7b..12f83e1 100644
--- a/src/probabilistic_model/probabilistic_circuit/jax/uniform_layer.py
+++ b/src/probabilistic_model/probabilistic_circuit/jax/uniform_layer.py
@@ -62,68 +62,26 @@ def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[Univariate
         result = cls(nodes[0].probabilistic_circuit.variables.index(variable), intervals)
         return NXConverterLayer(result, nodes, hash_remap)
 
-    def sample_from_frequencies(self, frequencies: np.array, result: np.array, start_index=0):
-        # sample from U(0,1)
-        standard_uniform_samples = np.random.uniform(size=(sum(frequencies), 1))
-
-        # calculate range for each node
-        range_per_sample = (self.upper - self.lower).repeat(frequencies).reshape(-1, 1)
-
-        # calculate the right shift for each node
-        right_shift_per_sample = self.lower.repeat(frequencies).reshape(-1, 1)
-
-        # apply the transformation to the desired intervals
-        samples = standard_uniform_samples * range_per_sample + right_shift_per_sample
-
-        result[start_index:start_index + len(samples), self.variables] = samples
-
-    def cdf_of_nodes_single(self, x: jnp.array) -> jnp.array:
-        return jnp.clip((x - self.lower) / (self.upper - self.lower), 0, 1)
-
-    def moment_of_nodes(self, order: jax.Array, center: jax.Array):
-        """
-        Calculate the moment of the uniform distribution.
-        """
-        order = order[self.variables[0]]
-        center = center[self.variables[0]]
-        pdf_value = jnp.exp(self.log_pdf_value())
-        lower_integral_value = (pdf_value * (self.lower - center) ** (order + 1)) / (order + 1)
-        upper_integral_value = (pdf_value * (self.upper - center) ** (order + 1)) / (order + 1)
-        return (upper_integral_value - lower_integral_value).reshape(-1, 1)
-
-    def log_conditional_from_simple_interval(self, interval: SimpleInterval) -> Tuple[Self, jax.Array]:
-        probabilities = jnp.log(self.probability_of_simple_interval(interval))
-        open_interval_array = simple_interval_to_open_array(interval)
-        new_lowers = jnp.maximum(self.lower, open_interval_array[0])
-        new_uppers = jnp.minimum(self.upper, open_interval_array[1])
-        valid_intervals = new_lowers < new_uppers
-        new_intervals = jnp.stack([new_lowers[valid_intervals], new_uppers[valid_intervals]]).T
-        return self.__class__(self.variable, new_intervals), probabilities
-
-    def merge_with(self, others: List[Self]) -> Self:
-        return self.__class__(self.variable, jnp.vstack([self.interval] + [other.interval for other in others]))
-
     @classmethod
     def _from_json(cls, data: Dict[str, Any]) -> Self:
         return cls(data["variable"], jnp.array(data["interval"]))
 
-    def to_nx(self, variables: SortedSet[Variable], progress_bar: Optional[tqdm.tqdm] = None) -> List[
-        Unit]:
+    def to_nx(self, variables: SortedSet[Variable], result: NXProbabilisticCircuit,
+              progress_bar: Optional[tqdm.tqdm] = None) -> List[Unit]:
 
         variable = variables[self.variable]
 
         if progress_bar:
             progress_bar.set_postfix_str(f"Creating Uniform distributions for variable {variable.name}")
 
-        nx_pc = NXProbabilisticCircuit()
         nodes = [UnivariateContinuousLeaf(
             UniformDistribution(variable=variable,
                                 interval=random_events.interval.SimpleInterval(lower.item(), upper.item(),
                                                                                random_events.interval.Bound.OPEN,
-                                                                               random_events.interval.Bound.OPEN)))
+                                                                               random_events.interval.Bound.OPEN)),
+            result)
             for lower, upper in self.interval]
 
         if progress_bar:
             progress_bar.update(self.number_of_nodes)
-        nx_pc.add_nodes_from(nodes)
         return nodes
diff --git a/test/test_jax/test_discrete_layer.py b/test/test_jax/test_discrete_layer.py
new file mode 100644
index 0000000..9e976e6
--- /dev/null
+++ b/test/test_jax/test_discrete_layer.py
@@ -0,0 +1,71 @@
+import unittest
+
+from random_events.set import SetElement
+from random_events.variable import Continuous, Symbolic
+from sortedcontainers import SortedSet
+
+from probabilistic_model.distributions import SymbolicDistribution
+from probabilistic_model.probabilistic_circuit.jax.discrete_layer import DiscreteLayer
+from probabilistic_model.probabilistic_circuit.jax.gaussian_layer import GaussianLayer, GaussianDistribution
+from probabilistic_model.probabilistic_circuit.nx.probabilistic_circuit import \
+    ProbabilisticCircuit as NXProbabilisticCircuit, SumUnit
+from probabilistic_model.probabilistic_circuit.nx.distributions import UnivariateContinuousLeaf, UnivariateDiscreteLeaf
+from probabilistic_model.probabilistic_circuit.jax.probabilistic_circuit import ProbabilisticCircuit
+import jax.numpy as jnp
+
+from probabilistic_model.utils import MissingDict
+
+
+class Animal(SetElement):
+    EMPTY_SET = -1
+    CAT = 0
+    DOG = 1
+    FISH = 2
+
+
+class DiscreteLayerTestCase(unittest.TestCase):
+
+    model: DiscreteLayer
+    x = Symbolic("x", Animal)
+
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DiscreteLayer(0, jnp.log(jnp.array([[0, 1, 2], [3, 4, 0]])))
+        cls.model.validate()
+
+    def test_normalization(self):
+        result = self.model.normalization_constant
+        correct = jnp.array([3., 7.])
+        self.assertTrue(jnp.allclose(result, correct, atol=1e-3))
+
+    def test_log_likelihood(self):
+        x = jnp.array([0.0])
+        result = self.model.log_likelihood_of_nodes_single(x)
+        correct = jnp.log(jnp.array([.0, 3/7]))
+        self.assertTrue(jnp.allclose(result, correct, atol=1e-3))
+
+    def test_from_nx(self):
+
+        p1 = MissingDict(float, {Animal.CAT: 0., Animal.DOG: 1, Animal.FISH: 2})
+        d1 = UnivariateDiscreteLeaf(SymbolicDistribution(self.x, p1))
+
+        p2 = MissingDict(float, {Animal.CAT: 3, Animal.DOG: 4, Animal.FISH: 0})
+        d2 = UnivariateDiscreteLeaf(SymbolicDistribution(self.x, p2))
+        s = SumUnit()
+        s.add_subcircuit(d1, 0.5)
+        s.add_subcircuit(d2, 0.5)
+        nx_pc = s.probabilistic_circuit
+        jax_pc = ProbabilisticCircuit.from_nx(nx_pc)
+        discrete_layer = jax_pc.root.child_layers[0]
+        self.assertIsInstance(discrete_layer, DiscreteLayer)
+        self.assertEqual(discrete_layer.variable, 0)
+        self.assertEqual(discrete_layer.log_probabilities.shape, (2, 3))
+        self.assertTrue(jnp.allclose(discrete_layer.log_probabilities, self.model.log_probabilities))
+
+    def test_to_nx(self):
+        nx_circuit = self.model.to_nx(SortedSet([self.x]), NXProbabilisticCircuit())[0].probabilistic_circuit
+        self.assertEqual(len(nx_circuit.nodes()), 2)
+        self.assertEqual(len(nx_circuit.edges()), 0)
+
+
+if __name__ == '__main__':
+    unittest.main()
jnp.array([1.5], dtype=jnp.float32) - moment = self.layer.moment_of_nodes(order, center) - result = jnp.array([-1.5, -0.5], dtype=jnp.float32).reshape(-1, 1) - self.assertTrue(jnp.allclose(moment, result)) - - def test_conditional_of_simple_interval(self): - interval = closed(-0.5, 0.5).simple_sets[0] - layer, ll = self.layer.log_conditional_from_simple_interval(interval) - result = jnp.log(jnp.array([1, 0], dtype=jnp.float32)) - self.assertTrue(jnp.allclose(ll, result)) - layer.validate() - self.assertEqual(layer.number_of_nodes, 1) - self.assertTrue(jnp.allclose(layer.location, jnp.array([0.]))) - self.assertTrue(jnp.allclose(layer.density_cap, jnp.array([1.]))) if __name__ == '__main__': diff --git a/test/test_jax/test_probabilistic_circuit.py b/test/test_jax/test_probabilistic_circuit.py index 9cd04ea..f8700ad 100644 --- a/test/test_jax/test_probabilistic_circuit.py +++ b/test/test_jax/test_probabilistic_circuit.py @@ -89,17 +89,7 @@ def test_trainable_parameters(self): number_of_parameters = sum([len(p) for p in flattened_params]) self.assertEqual(number_of_parameters, 10) - def test_serialization(self): - json = self.jax_model.to_json() - model = ProbabilisticCircuit.from_json(json) - samples = model.sample(1000) - jax_ll = model.log_likelihood(samples) - self.assertTrue((jax_ll > -jnp.inf).all()) - def test_sample(self): - samples = self.jax_model.sample(100) - ll = self.jax_model.log_likelihood(samples) - self.assertTrue((ll > -jnp.inf).all()) class JPTIntegrationTestCase(unittest.TestCase): number_of_variables = 2 diff --git a/test/test_jax/test_product_layer.py b/test/test_jax/test_product_layer.py index efb841c..10d680c 100644 --- a/test/test_jax/test_product_layer.py +++ b/test/test_jax/test_product_layer.py @@ -55,71 +55,8 @@ def test_likelihood(self): self.assertTrue(likelihood[0, 1] == -jnp.inf) self.assertTrue(likelihood[1, 0] == -jnp.inf) - def test_cdf(self): - data = jnp.array([[0, 0, 0], [0, 5, 6], [2, 4, 6], [10, 10, 10]], dtype=jnp.float32) - cdf = self.product_layer.cdf_of_nodes(data) - self.assertEqual(cdf.shape, (4, 2)) - result = jnp.array([[0, 0], [1., 0.], [0., 1], [1, 1]], dtype=jnp.float32) - self.assertTrue(jnp.allclose(cdf, result)) - def test_moment(self): - order = jnp.array([1, 1, 2], dtype=jnp.int32) - center = jnp.array([0., 1., 2], dtype=jnp.float32) - moment = self.product_layer.moment_of_nodes(order, center) - result = jnp.array([[0, 4., 0.], - [2., 3., 0.]], dtype=jnp.float32) - self.assertTrue(jnp.allclose(moment, result)) - - def test_probability(self): - event = SimpleEvent({self.x: closed(0,2), self.y: singleton(5), self.z: singleton(6.)}) - prob = self.product_layer.probability_of_simple_event(event) - result = jnp.array([1, 0], dtype=jnp.float32) - self.assertTrue(jnp.allclose(prob, result)) - - def test_conditioning(self): - - event = SimpleEvent({self.x: closed(-1, 1), - self.y: closed(4.5, 5.5), - self.z: closed(5.5, 6.5)}) - - conditional, log_prob = self.product_layer.log_conditional_of_simple_event(event) - conditional.validate() - self.assertTrue(jnp.allclose(log_prob, jnp.log(jnp.array([1., 0.])))) - self.assertEqual(conditional.number_of_nodes, 1) - self.assertEqual(len(conditional.child_layers), 3) - self.assertEqual(conditional.child_layers[0].number_of_nodes, 1) - self.assertEqual(conditional.child_layers[1].number_of_nodes, 1) - self.assertEqual(conditional.child_layers[2].number_of_nodes, 1) - - -class PCProductLayerTestCase(unittest.TestCase): - - x = Continuous("x") - y = Continuous("y") - z = Continuous("z") - - p1_x = 
diff --git a/test/test_jax/test_product_layer.py b/test/test_jax/test_product_layer.py
index efb841c..10d680c 100644
--- a/test/test_jax/test_product_layer.py
+++ b/test/test_jax/test_product_layer.py
@@ -55,71 +55,8 @@ def test_likelihood(self):
         self.assertTrue(likelihood[0, 1] == -jnp.inf)
         self.assertTrue(likelihood[1, 0] == -jnp.inf)
 
-    def test_cdf(self):
-        data = jnp.array([[0, 0, 0], [0, 5, 6], [2, 4, 6], [10, 10, 10]], dtype=jnp.float32)
-        cdf = self.product_layer.cdf_of_nodes(data)
-        self.assertEqual(cdf.shape, (4, 2))
-        result = jnp.array([[0, 0], [1., 0.], [0., 1], [1, 1]], dtype=jnp.float32)
-        self.assertTrue(jnp.allclose(cdf, result))
 
-    def test_moment(self):
-        order = jnp.array([1, 1, 2], dtype=jnp.int32)
-        center = jnp.array([0., 1., 2], dtype=jnp.float32)
-        moment = self.product_layer.moment_of_nodes(order, center)
-        result = jnp.array([[0, 4., 0.],
-                            [2., 3., 0.]], dtype=jnp.float32)
-        self.assertTrue(jnp.allclose(moment, result))
-
-    def test_probability(self):
-        event = SimpleEvent({self.x: closed(0,2), self.y: singleton(5), self.z: singleton(6.)})
-        prob = self.product_layer.probability_of_simple_event(event)
-        result = jnp.array([1, 0], dtype=jnp.float32)
-        self.assertTrue(jnp.allclose(prob, result))
-
-    def test_conditioning(self):
-
-        event = SimpleEvent({self.x: closed(-1, 1),
-                             self.y: closed(4.5, 5.5),
-                             self.z: closed(5.5, 6.5)})
-
-        conditional, log_prob = self.product_layer.log_conditional_of_simple_event(event)
-        conditional.validate()
-        self.assertTrue(jnp.allclose(log_prob, jnp.log(jnp.array([1., 0.]))))
-        self.assertEqual(conditional.number_of_nodes, 1)
-        self.assertEqual(len(conditional.child_layers), 3)
-        self.assertEqual(conditional.child_layers[0].number_of_nodes, 1)
-        self.assertEqual(conditional.child_layers[1].number_of_nodes, 1)
-        self.assertEqual(conditional.child_layers[2].number_of_nodes, 1)
-
-
-class PCProductLayerTestCase(unittest.TestCase):
-
-    x = Continuous("x")
-    y = Continuous("y")
-    z = Continuous("z")
-
-    p1_x = DiracDeltaLayer(0, jnp.array([0., 1.]), jnp.array([1, 1]))
-    p2_x = DiracDeltaLayer(0, jnp.array([2., 3.]), jnp.array([1, 1]))
-    p_y = DiracDeltaLayer(1, jnp.array([4., 5.]), jnp.array([1, 1]))
-    p_z = DiracDeltaLayer(2, jnp.array([6.]), jnp.array([1]))
-    model: ProbabilisticCircuit
-
-    def setUp(self):
-        indices = jnp.array([[0, 0],
-                             [1, 0],
-                             [3, 0]])
-        values = jnp.array([0, 0, 1])
-        edges = BCOO((values, indices), shape=(4, 2)).sum_duplicates(remove_zeros=False).sort_indices()
-        product_layer = ProductLayer([self.p_z, self.p1_x, self.p2_x, self.p_y, ], edges)
-        self.model = ProbabilisticCircuit(SortedSet([self.x, self.y, self.z]), product_layer)
-
-    def test_sample(self):
-        samples = self.model.sample(3)
-        result = np.array([[0, 5, 6],
-                           [0, 5, 6],
-                           [0, 5, 6]])
-        self.assertTrue(np.allclose(samples, result))
 
 
 if __name__ == '__main__':
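The removed PCProductLayerTestCase was also the only in-repo illustration of the sparse edge encoding that ProductLayer consumes, so a reference sketch is kept here. Semantics as the old setUp used them: each index pairs a child-layer row with a product-node column, and the stored value selects the node inside that child layer.

    import jax.numpy as jnp
    from jax.experimental.sparse import BCOO

    # Edges for two product nodes over four child layers; only node 0 is fully wired.
    indices = jnp.array([[0, 0],   # child layer 0 -> product node 0, using its node 0
                         [1, 0],   # child layer 1 -> product node 0, using its node 0
                         [3, 0]])  # child layer 3 -> product node 0, using its node 1
    values = jnp.array([0, 0, 1])
    edges = BCOO((values, indices), shape=(4, 2)).sum_duplicates(remove_zeros=False).sort_indices()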
diff --git a/test/test_jax/test_sum_layer.py b/test/test_jax/test_sum_layer.py
index 2dda14e..52ef529 100644
--- a/test/test_jax/test_sum_layer.py
+++ b/test/test_jax/test_sum_layer.py
@@ -69,97 +69,11 @@ def test_ll(self):
                                 [0., 0.,]]))
         assert jnp.allclose(ll, result)
 
-    def test_cdf(self):
-        data = jnp.arange(7, dtype=jnp.float32).reshape(-1, 1) - 0.5
-        cdf = self.sum_layer.cdf_of_nodes(data)
-        self.assertEqual(cdf.shape, (7, 2))
-        result = jnp.array([[0, 0],      # -0.5
-                            [0, 0.4],    # 0.5
-                            [0.1, 0.4],  # 1.5
-                            [0.3, 0.7],  # 2.5
-                            [0.6, 0.7],  # 3.5
-                            [0.6, 0.8],  # 4.5
-                            [1, 1],      # 5.5
-                            ], dtype=jnp.float32)
-        self.assertTrue(jnp.allclose(cdf, result))
-
-    def test_moment(self):
-        order = jnp.array([1], dtype=jnp.int32)
-        center = jnp.array([2.5], dtype=jnp.float32)
-        moment = self.sum_layer.moment_of_nodes(order, center)
-        result = jnp.array([0.9, -0.5], dtype=jnp.float32).reshape(-1, 1)
-        self.assertTrue(jnp.allclose(moment, result))
-
-    def test_probability(self):
-        event = SimpleEvent({self.x: closed(0.5, 2.5) | closed(4.5, 10)})
-        prob = self.sum_layer.probability_of_simple_event(event)
-        result = jnp.array([0.7, 0.5], dtype=jnp.float32)
-        self.assertTrue(jnp.allclose(result, prob))
-
-    def test_conditional(self):
-        event = SimpleEvent({self.x: closed(0.5, 1.5)})
-        c, lp = self.sum_layer.log_conditional_of_simple_event(event)
-        c.validate()
-        self.assertEqual(c.number_of_nodes, 1)
-        self.assertEqual(len(c.child_layers), 1)
-        self.assertEqual(c.child_layers[0].number_of_nodes, 1)
-        self.assertTrue(jnp.allclose(c.log_weights[0].todense(), jnp.array([[0.]])))
-        self.assertTrue(jnp.allclose(lp, jnp.log(jnp.array([0.1, 0.]))))
-
-    def test_conditional_2(self):
-        event = SimpleEvent({self.x: closed(1.5, 4.5)})
-        c, lp = self.sum_layer.log_conditional_of_simple_event(event)
-        c.validate()
-        self.assertEqual(c.number_of_nodes, 2)
-        self.assertEqual(len(c.child_layers), 2)
-        self.assertEqual(c.child_layers[0].number_of_nodes, 1)
-        self.assertEqual(c.child_layers[1].number_of_nodes, 2)
-
-    def test_remove(self):
-        result = self.sum_layer.remove_nodes(jnp.array([True, False]))
-        result.validate()
-        self.assertEqual(result.number_of_nodes, 1)
-
-
-class PCSumUnitTestCase(unittest.TestCase):
-    x: Continuous = Continuous("x")
-
-    p1_x = DiracDeltaLayer(0, jnp.array([0., 1.]), jnp.array([1, 2]))
-    p2_x = DiracDeltaLayer(0,jnp.array([2.]), jnp.array([3]))
-    p3_x = DiracDeltaLayer(0, jnp.array([3., 4., 5.]), jnp.array([4, 5, 6]))
-    p4_x = DiracDeltaLayer(0, jnp.array([6.]), jnp.array([1]))
-    model: ProbabilisticCircuit
-
-    @classmethod
-    def setUpClass(cls):
-        weights_p1 = BCOO.fromdense(jnp.array([[0, 0.1]])) * 2
-        weights_p1.data = jnp.log(weights_p1.data)
-
-        weights_p2 = BCOO.fromdense(jnp.array([[0.2]])) * 2
-        weights_p2.data = jnp.log(weights_p2.data)
-
-        weights_p3 = BCOO.fromdense(jnp.array([[0.3, 0, 0.4]])) * 2
-        weights_p3.data = jnp.log(weights_p3.data)
-
-        weights_p4 = BCOO.fromdense(jnp.array([[0]])) * 2
-        weights_p4.data = jnp.log(weights_p4.data)
-
-        sum_layer = SumLayer([cls.p1_x, cls.p2_x, cls.p3_x, cls.p4_x],
-                             log_weights=[weights_p1, weights_p2, weights_p3, weights_p4])
-        sum_layer.validate()
-        cls.model = ProbabilisticCircuit(SortedSet([cls.x]), sum_layer)
-
-    def test_sampling(self):
-        np.random.seed(69)
-        samples = self.model.sample(10)
-        self.assertEqual(samples.shape, (10, 1))
-        result = np.array([2, 2, 2, 3, 3, 5, 5, 5, 5, 5]).reshape(-1, 1)
-        self.assertTrue(np.allclose(samples, result))
-
-    def test_conditioning(self):
-        event = SimpleEvent({self.x: closed(1.5, 4.5)})
-        conditional, log_prob = self.model.root.log_conditional_of_simple_event(event)
-        # conditional.validate()
+    def test_ll_single(self):
+        data = jnp.array([0])
+        l = self.sum_layer.log_likelihood_of_nodes_single(data)
+        result = jnp.log(jnp.array([0., 0.4]))
+        assert jnp.allclose(l, result)
 
 
 class NygaDistributionTestCase(unittest.TestCase):
@@ -180,7 +94,3 @@ def setUpClass(cls):
     def test_log_likelihood(self):
         ll = self.jax_model.log_likelihood(self.data)
         self.assertTrue(jnp.all(ll > -jnp.inf))
-
-    def test_sampling(self):
-        data = self.jax_model.sample(1000)
-        self.assertEqual(data.shape, (1000, 1))
\ No newline at end of file
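The deleted setUpClass multiplied every weight matrix by 2 before taking logs, apparently to exercise unnormalized log-weights; the layer is then expected to normalize at evaluation time. A sketch of that mixture arithmetic under this assumption, for a single sum node over four children:

    import jax
    import jax.numpy as jnp

    # Unnormalized mixture weights (0.1 + 0.2 + 0.3 + 0.4) * 2, stored in log space.
    log_weights = jnp.log(jnp.array([0.1, 0.2, 0.3, 0.4]) * 2)
    child_ll = jnp.log(jnp.array([0., 1., 0., 0.]))  # toy child log-likelihoods

    # log p(x) = logsumexp(log_w + child_ll) - logsumexp(log_w) -> log(0.2)
    log_p = (jax.scipy.special.logsumexp(log_weights + child_ll)
             - jax.scipy.special.logsumexp(log_weights))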
diff --git a/test/test_jax/test_uniform_layer.py b/test/test_jax/test_uniform_layer.py
index de90d55..33e2941 100644
--- a/test/test_jax/test_uniform_layer.py
+++ b/test/test_jax/test_uniform_layer.py
@@ -40,73 +40,9 @@ def test_from_interval(self):
         result = jnp.array([[0, 0, 1, 1], [0, 1, 0, 1]])
         self.assertTrue(jnp.allclose(ll, result))
 
-    def test_cdf(self):
-        data = jnp.array([0.5, 1.5, 4]).reshape(-1, 1)
-        cdf = self.p_x.cdf_of_nodes(data)
-        self.assertEqual(cdf.shape, (3, 2))
-        result = jnp.array([[0.5, 0], [1, 0.25], [1, 1]])
-        self.assertTrue(jnp.allclose(cdf, result))
-
-    def test_moment(self):
-        order = jnp.array([1], dtype=jnp.int32)
-        center = jnp.array([1.], dtype=jnp.float32)
-        moment = self.p_x.moment_of_nodes(order, center)
-        result = jnp.array([[-0.5], [1.]], dtype=jnp.float32)
-        self.assertTrue(jnp.allclose(moment, result))
-
-    def test_probability(self):
-        event = SimpleEvent({self.x: closed(0.5, 2.5) | closed(3, 5)})
-        prob = self.p_x.probability_of_simple_event(event)
-        self.assertEqual(prob.shape, (2,))
-        result = jnp.array([0.5, 0.75])
-        self.assertTrue(jnp.allclose(prob, result))
-
     def test_to_json(self):
         data = self.p_x.to_json()
         json.dumps(data)
         p_x = UniformLayer.from_json(data)
         self.assertTrue(jnp.allclose(self.p_x.interval, p_x.interval))
-
-    def test_conditional_singleton(self):
-        event = SimpleEvent({self.x: closed(0.5, 0.5)})
-        layer, ll = self.p_x.log_conditional_of_simple_event(event)
-        self.assertEqual(layer.number_of_nodes, 1)
-        self.assertTrue(jnp.allclose(jnp.array([0.5]), layer.location))
-        self.assertTrue(jnp.allclose(jnp.array([1.]), layer.density_cap))
-
-    def test_conditional_single_truncation(self):
-        event = SimpleEvent({self.x: closed(0.5, 2.5)})
-        layer, ll = self.p_x.log_conditional_of_simple_event(event)
-        layer.validate()
-        self.assertEqual(layer.number_of_nodes, 2)
-        self.assertTrue(jnp.allclose(layer.interval, jnp.array([[0.5, 1], [1, 2.5]])))
-        self.assertTrue(jnp.allclose(jnp.log(jnp.array([0.5, 0.75])), ll))
-
-    def test_conditional_with_node_removal(self):
-        event = SimpleEvent({self.x: closed(0.25, 0.5)})
-        layer, ll = self.p_x.log_conditional_of_simple_event(event)
-        layer.validate()
-        self.assertEqual(layer.number_of_nodes, 1)
-        self.assertTrue(jnp.allclose(layer.interval, jnp.array([[0.25, 0.5]])))
-        self.assertTrue(jnp.allclose(jnp.log(jnp.array([0.25, 0.])), ll))
-
-    def test_conditional_multiple_truncation(self):
-        event = closed(-1, 0.5) | closed(0.7, 0.8) | closed(2., 3.) | closed(3.5, 4.)
-
-        layer, ll = self.p_x.log_conditional_from_interval(event)
-
-        self.assertTrue(jnp.allclose(jnp.log(jnp.array([0.6, 0.5])), ll))
-        self.assertIsInstance(layer, SumLayer)
-
-        layer.validate()
-        self.assertEqual(layer.number_of_nodes, 2)
-        self.assertEqual(len(layer.child_layers), 1)
-        self.assertTrue(jnp.allclose(layer.child_layers[0].interval, jnp.array([[0., 0.5], [0.7, 0.8], [2., 3.]])))
-
-        log_weights_by_hand = jnp.array([[0.5, 0.1, 0.], [0., 0., 0.5]])
-        log_weights_by_hand /= jnp.sum(log_weights_by_hand, axis=1, keepdims=True)
-        log_weights_by_hand = BCOO.fromdense(log_weights_by_hand)
-        log_weights_by_hand.data = jnp.log(log_weights_by_hand.data)
-        self.assertTrue(jnp.allclose(layer.log_weights[0].data, log_weights_by_hand.data))
-        self.assertTrue(jnp.allclose(layer.log_weights[0].indices, log_weights_by_hand.indices))
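All of the removed uniform-layer tests reduce to the piecewise-linear CDF of each node; from the asserted values, the fixture's two nodes behave like U[0, 1] and U[1, 3] (an inference, not stated in the file). A sketch of the interval-probability computation on that assumption:

    import jax.numpy as jnp

    interval = jnp.array([[0., 1.], [1., 3.]])  # assumed node supports

    def cdf(x):
        # Piecewise-linear CDF of each uniform node, clipped to [0, 1].
        lower, upper = interval[:, 0], interval[:, 1]
        return jnp.clip((x - lower) / (upper - lower), 0., 1.)

    # P([0.5, 2.5]) per node -> [0.5, 0.75], matching the removed test_probability.
    prob = cdf(2.5) - cdf(0.5)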
diff --git a/test/test_rat_spn/__init__.py b/test/test_rat_spn/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/test/test_rat_spn/test_region_graph.py b/test/test_rat_spn/test_region_graph.py
new file mode 100644
index 0000000..4e902a7
--- /dev/null
+++ b/test/test_rat_spn/test_region_graph.py
@@ -0,0 +1,74 @@
+import random
+import unittest
+
+from jax import tree_flatten
+from matplotlib import pyplot as plt
+from networkx.drawing.nx_agraph import graphviz_layout
+from random_events.variable import Continuous
+import pydot
+from probabilistic_model.learning.region_graph.region_graph import *
+from probabilistic_model.probabilistic_circuit.jax.probabilistic_circuit import ProbabilisticCircuit as JPC
+from probabilistic_model.probabilistic_circuit.jax.gaussian_layer import GaussianLayer
+import numpy as np
+import equinox as eqx
+import optax
+import tqdm
+import plotly.graph_objects as go
+
+np.random.seed(420)
+random.seed(420)
+
+
+class RandomRegionGraphTestCase(unittest.TestCase):
+
+    variables = SortedSet([Continuous(str(i)) for i in range(4)])
+
+    region_graph = RegionGraph(variables, partitions=2, depth=1, repetitions=2)
+    region_graph = region_graph.create_random_region_graph()
+
+    def test_region_graph(self):
+        self.assertEqual(len(self.region_graph.nodes()), 19)
+
+    def test_as_jpc(self):
+        model = self.region_graph.as_probabilistic_circuit(input_units=10, sum_units=5)
+        nx_model = model.to_nx()
+        nx_model.plot_structure()
+        # plt.show()
+        self.assertEqual(len(list(node for node in nx_model.nodes() if isinstance(node, SumUnit))), 21)
+
+
+class RandomRegionGraphLearningTestCase(unittest.TestCase):
+
+    variables = SortedSet([Continuous(str(i)) for i in range(4)])
+    region_graph = RegionGraph(variables, partitions=2, depth=1, repetitions=2)
+    region_graph = region_graph.create_random_region_graph()
+
+    def test_learning(self):
+        data = np.random.uniform(0, 1, (10000, len(self.variables)))
+        model = self.region_graph.as_probabilistic_circuit(input_units=5, sum_units=5)
+
+        root = model.root
+
+        @eqx.filter_jit
+        def loss(model, x):
+            ll = model.log_likelihood_of_nodes(x)
+            return -jnp.mean(ll)
+
+        optim = optax.adamw(0.01)
+        opt_state = optim.init(eqx.filter(root, eqx.is_inexact_array))
+
+        for _ in tqdm.trange(50):
+            loss_value, grads = eqx.filter_value_and_grad(loss)(root, data)
+            grads_of_sum_layer = eqx.filter(tree_flatten(grads), eqx.is_inexact_array)[0][0]
+            self.assertTrue(jnp.all(jnp.isfinite(grads_of_sum_layer)))
+
+            updates, opt_state = optim.update(
+                grads, opt_state, eqx.filter(root, eqx.is_inexact_array)
+            )
+            root = eqx.apply_updates(root, updates)
+
+
+if __name__ == '__main__':
+    unittest.main()
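The training loop in test_learning follows the standard equinox/optax recipe: filter the module into trainable (inexact-array) leaves, differentiate the filtered loss, and apply the optimizer updates. A self-contained sketch of that step, assuming any equinox module and a scalar loss_fn(model, x); the helper name make_step is hypothetical:

    import equinox as eqx
    import optax

    def make_step(loss_fn, optim):
        @eqx.filter_jit
        def step(model, opt_state, x):
            # Gradients flow only through inexact-array leaves; all else is static.
            loss_value, grads = eqx.filter_value_and_grad(loss_fn)(model, x)
            params = eqx.filter(model, eqx.is_inexact_array)
            updates, opt_state = optim.update(grads, opt_state, params)
            model = eqx.apply_updates(model, updates)
            return model, opt_state, loss_value
        return step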