From 7e9205149e37c44ec7b98206ccd964244ce33834 Mon Sep 17 00:00:00 2001 From: Tom Schierenbeck Date: Fri, 13 Dec 2024 13:08:57 +0100 Subject: [PATCH 01/10] Removed unsupported things from jax circuits --- .../probabilistic_circuit/jax/inner_layer.py | 393 +----------------- .../probabilistic_circuit/jax/input_layer.py | 141 ------- .../jax/probabilistic_circuit.py | 8 - test/test_jax/test_input_layer.py | 22 - test/test_jax/test_probabilistic_circuit.py | 10 - test/test_jax/test_product_layer.py | 63 --- test/test_jax/test_sum_layer.py | 96 ----- test/test_jax/test_uniform_layer.py | 64 --- 8 files changed, 2 insertions(+), 795 deletions(-) diff --git a/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py index 1ad7b9b..7c11ac7 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py @@ -58,30 +58,6 @@ def log_likelihood_of_nodes(self, x: jnp.array) -> jnp.array: """ return jax.vmap(self.log_likelihood_of_nodes_single)(x) - def cdf_of_nodes_single(self, x: jnp.array) -> jnp.array: - """ - Calculate the cumulative distribution function of the distribution if applicable. - - :param x: The input vector. - :return: The cumulative distribution function of every node in the layer for x. - """ - raise NotImplementedError - - def cdf_of_nodes(self, x: jnp.array) -> jnp.array: - """ - Vectorized version of :meth:`cdf_of_nodes_single` - """ - return jax.vmap(self.cdf_of_nodes_single)(x) - - def probability_of_simple_event(self, event: SimpleEvent) -> jnp.array: - """ - Calculate the probability of a simple event P(E). - - :param event: The simple event to calculate the probability for. It has to contain every variable. - :return: P(E) - """ - raise NotImplementedError - def validate(self): """ Validate the parameters and their layouts. @@ -142,23 +118,6 @@ def to_nx(self, variables: SortedSet[Variable], progress_bar: Optional[tqdm.tqdm """ raise NotImplementedError - @property - def impossible_condition_result(self) -> Tuple[None, jax.Array]: - """ - :return: The result that a layer yields if it is conditioned on an event E with P(E) = 0 - """ - return None, jnp.full((self.number_of_nodes,), -jnp.inf, dtype=jnp.float32) - - def log_conditional_of_simple_event(self, event: SimpleEvent, ) -> Tuple[Optional[Self], jax.Array]: - """ - Calculate the conditional probability distribution given a simple event P(X|E). - Also return the log probability of E log(P(E)). - - :param event: The event to calculate the conditional distribution for. - :return: The conditional distribution and the log probability of the event. - """ - raise NotImplementedError - @staticmethod def create_layers_from_nodes(nodes: List[Unit], child_layers: List[NXConverterLayer], progress_bar: bool = True) \ @@ -219,42 +178,12 @@ def number_of_trainable_parameters(self): return number_of_parameters @property - def number_of_components(self) -> 0: + def number_of_components(self) -> int: """ :return: The number of components (leaves + edges) of the entire circuit """ return self.number_of_nodes - def sample_from_frequencies(self, frequencies: np.array, result: np.array, start_index=0): - raise NotImplementedError - - def moment_of_nodes(self, order: jax.Array, center: jax.Array): - """ - Calculate the moment of the nodes. - The order and center vectors describe the moments for all variables in the entire model. Hence, they should - never be touched by the forward pass. - - :param order: The order of the moment for each variable. - :param center: The center of the moment for each variable. - :return: The moments of the nodes with shape (#nodes, #variables). - """ - raise NotImplementedError - - def merge_with(self, others: List[Self]) -> Self: - """ - Merge the layer with others of the same type. - """ - raise NotImplementedError - - def remove_nodes(self, remove_mask: jax.Array) -> Self: - """ - Remove nodes from the layer. - - :param remove_mask: A boolean mask of the nodes to remove. - :return: The layer with the nodes removed. - """ - raise NotImplementedError - class InnerLayer(Layer, ABC): """ @@ -298,12 +227,6 @@ def to_json(self) -> Dict[str, Any]: result["child_layers"] = [child_layer.to_json() for child_layer in self.child_layers] return result - def clean_up_orphans(self) -> Self: - """ - Clean up the layer by removing orphans in the child layers. - """ - raise NotImplementedError - class InputLayer(Layer, ABC): """ @@ -337,6 +260,7 @@ def variable(self): class SumLayer(InnerLayer): + log_weights: List[BCOO] child_layers: Union[List[[ProductLayer]], List[InputLayer]] @@ -421,160 +345,6 @@ def log_likelihood_of_nodes_single(self, x: jax.Array) -> jax.Array: return jnp.where(result > 0, jnp.log(result) - self.log_normalization_constants, -jnp.inf) - def cdf_of_nodes_single(self, x: jnp.array) -> jnp.array: - result = jnp.zeros(self.number_of_nodes, dtype=jnp.float32) - - for log_weights, child_layer in self.log_weighted_child_layers: - # get the cdf of the child nodes - child_layer_cdf = child_layer.cdf_of_nodes_single(x) - - # weight the cdf of the child nodes by the weight for each node of this layer - cloned_log_weights = copy_bcoo(log_weights) # clone the weights - - # multiply the weights with the child layer cdf - cloned_log_weights.data = jnp.exp(cloned_log_weights.data) # exponent weights - cloned_log_weights.data *= child_layer_cdf[cloned_log_weights.indices[:, 1]] - - # sum the weights for each node - ll = cloned_log_weights.sum(1).todense() - - # sum the child layer result - result += ll - - # normalize the result - normalization_constants = jnp.exp(self.log_normalization_constants) - return result / normalization_constants - - def probability_of_simple_event(self, event: SimpleEvent) -> jnp.array: - result = jnp.zeros(self.number_of_nodes, dtype=jnp.float32) - - for log_weights, child_layer in self.log_weighted_child_layers: - # get the probability of the child nodes - child_layer_prob = child_layer.probability_of_simple_event(event) - - # weight the probability of the child nodes by the weight for each node of this layer - cloned_log_weights = copy_bcoo(log_weights) # clone the weights - - # multiply the weights with the child layer cdf - cloned_log_weights.data = jnp.exp(cloned_log_weights.data) # exponent weights - cloned_log_weights.data *= child_layer_prob[cloned_log_weights.indices[:, 1]] - - # sum the weights for each node - ll = cloned_log_weights.sum(1).todense() - - # sum the child layer result - result += ll - - # normalize the result - normalization_constants = jnp.exp(self.log_normalization_constants) - return result / normalization_constants - - def moment_of_nodes(self, order: jax.Array, center: jax.Array): - result = jnp.zeros((self.number_of_nodes, len(self.variables)), dtype=jnp.float32) - - for log_weights, child_layer in self.log_weighted_child_layers: - # get the moment of the child nodes - moment = child_layer.moment_of_nodes(order, center) # shape (#child_layer_nodes, #variables) - - # weight the moment of the child nodes by the weight for each node of this layer - weights = copy_bcoo(log_weights) # clone the weights, shape (#nodes, #child_layer_nodes) - weights.data = jnp.exp(weights.data) # exponent weights - - # calculate the weighted sum in layer - moment = weights @ moment - - # sum the child layer result - result += moment - - return result / jnp.exp(self.log_normalization_constants.reshape(-1, 1)) - - def sample_from_frequencies(self, frequencies: np.array, result: np.array, start_index=0): - node_to_child_frequency_map = self.node_to_child_frequency_map(frequencies) - - # offset for shifting through the frequencies of the node_to_child_frequency_map - prev_column_index = 0 - - consumed_indices = start_index - - for child_layer in self.child_layers: - # extract the frequencies for the child layer - current_frequency_block = node_to_child_frequency_map[:, - prev_column_index:prev_column_index + child_layer.number_of_nodes] - frequencies_for_child_nodes = current_frequency_block.sum(0) - child_layer.sample_from_frequencies(frequencies_for_child_nodes, result, consumed_indices) - consumed_indices += frequencies_for_child_nodes.sum() - - # shift the offset - prev_column_index += child_layer.number_of_nodes - - def node_to_child_frequency_map(self, frequencies: np.array): - """ - Sample from the exact distribution of the layer by interpreting every node as latent variable. - This is very slow due to BCOO.sum_duplicates being very slow. - - :param frequencies: - :param key: - :return: - """ - clw = self.normalized_weights - csr = coo_matrix((clw.data, clw.indices.T), shape=clw.shape).tocsr(copy=False) - return sample_from_sparse_probabilities_csc(csr, frequencies) - - def log_conditional_of_simple_event(self, event: SimpleEvent, ) -> Tuple[Optional[Self], jax.Array]: - conditional_child_layers = [] - conditional_log_weights = [] - - probabilities = jnp.zeros(self.number_of_nodes, dtype=jnp.float32) - - for log_weights, child_layer in self.log_weighted_child_layers: - # get the conditional of the child layer - conditional, child_log_prob = child_layer.log_conditional_of_simple_event(event) - if conditional is None: - continue - - # clone weights - log_weights = copy_bcoo(log_weights) - - # calculate the weighted sum of the child log probabilities - log_weights.data += child_log_prob[log_weights.indices[:, 1]] - - # skip if this layer is not connected to anything anymore - if jnp.all(log_weights.data == -jnp.inf): - continue - - log_weights.data = jnp.exp(log_weights.data) - - # calculate the probabilities of the child nodes in total - current_probabilities = log_weights.sum(1).todense() - probabilities += current_probabilities - - log_weights.data = jnp.log(log_weights.data) - - conditional_child_layers.append(conditional) - conditional_log_weights.append(log_weights) - - if len(conditional_child_layers) == 0: - return self.impossible_condition_result - - log_probabilities = jnp.log(probabilities) - - concatenated_log_weights = bcoo_concatenate(conditional_log_weights, dimension=1).sort_indices() - # remove rows and columns where all weights are -inf - cleaned_log_weights = sparse_remove_rows_and_cols_where_all(concatenated_log_weights, -jnp.inf) - - # normalize the weights - z = cleaned_log_weights.sum(1).todense() - cleaned_log_weights.data -= z[cleaned_log_weights.indices[:, 0]] - - # slice the weights for each child layer - log_weight_slices = jnp.array([0] + [ccl.number_of_nodes for ccl in conditional_child_layers]) - log_weight_slices = jnp.cumsum(log_weight_slices) - conditional_log_weights = [cleaned_log_weights[:, log_weight_slices[i]:log_weight_slices[i + 1]].sort_indices() - for i in range(len(conditional_child_layers))] - - resulting_layer = SumLayer(conditional_child_layers, conditional_log_weights) - return resulting_layer, (log_probabilities - self.log_normalization_constants) - def __deepcopy__(self): child_layers = [child_layer.__deepcopy__() for child_layer in self.child_layers] log_weights = [copy_bcoo(log_weight) for log_weight in self.log_weights] @@ -632,13 +402,6 @@ def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[SumUnit], sum_layer = cls([cl.layer for cl in filtered_child_layers], log_weights) return NXConverterLayer(sum_layer, nodes, result_hash_remap) - def remove_nodes(self, remove_mask: jax.Array) -> Self: - new_log_weights = [lw[~remove_mask] for lw in self.log_weights] - return self.__class__(self.child_layers, new_log_weights) - - def clean_up_orphans(self) -> Self: - raise NotImplementedError - def to_nx(self, variables: SortedSet[Variable], progress_bar: Optional[tqdm.tqdm] = None) -> List[ Unit]: @@ -749,62 +512,6 @@ def log_likelihood_of_nodes_single(self, x: jax.Array) -> jax.Array: return result - def cdf_of_nodes_single(self, x: jnp.array) -> jnp.array: - result = jnp.ones(self.number_of_nodes, dtype=jnp.float32) - - for edges, layer in zip(self.edges, self.child_layers): - # calculate the cdf over the columns of the child layer - cdf = layer.cdf_of_nodes_single(x[layer.variables]) # shape: #child_nodes - - # gather the cdf at the indices of the nodes that are required for the edges - cdf = cdf[edges.data] # shape: #len(edges.values()) - - # multiply the gathered values by the result where the edges define the indices - result = result.at[edges.indices[:, 0]].mul(cdf) - - return result - - def probability_of_simple_event(self, event: SimpleEvent) -> jnp.array: - result = jnp.ones(self.number_of_nodes, dtype=jnp.float32) - - for edges, layer in zip(self.edges, self.child_layers): - # calculate the cdf over the columns of the child layer - prob = layer.probability_of_simple_event(event) # shape: #child_nodes - - # gather the cdf at the indices of the nodes that are required for the edges - prob = prob[edges.data] # shape: #len(edges.values()) - - # multiply the gathered values by the result where the edges define the indices - result = result.at[edges.indices[:, 0]].mul(prob) - - return result - - def sample_from_frequencies(self, frequencies: np.array, result: np.array, start_index=0): - edges_csr = coo_array((self.edges.data, self.edges.indices.T), shape=self.edges.shape).tocsr() - for row_index, (start, end, child_layer) in enumerate( - zip(edges_csr.indptr[:-1], edges_csr.indptr[1:], self.child_layers)): - # get the edges for the current child layer - row = edges_csr.data[start:end] - column_indices = edges_csr.indices[start:end] - - frequencies_for_child_layer = np.zeros((child_layer.number_of_nodes,), dtype=np.int32) - frequencies_for_child_layer[row] = frequencies[column_indices] - - child_layer.sample_from_frequencies(frequencies_for_child_layer, result, start_index) - - def moment_of_nodes(self, order: jax.Array, center: jax.Array): - result = jnp.full((self.number_of_nodes, self.variables.shape[0]), jnp.nan) - for edges, layer in zip(self.edges, self.child_layers): - edges = edges.sum_duplicates(remove_zeros=False) - - # calculate the moments over the columns of the child layer - child_layer_moment = layer.moment_of_nodes(order, center) - - # gather the moments at the indices of the nodes that are required for the edges - result = result.at[edges.indices[:, 0], layer.variables].set(child_layer_moment[edges.data][:, 0]) - - return result - def __deepcopy__(self): child_layers = [child_layer.__deepcopy__() for child_layer in self.child_layers] edges = copy_bcoo(self.edges) @@ -815,102 +522,6 @@ def to_json(self) -> Dict[str, Any]: result["edges"] = (self.edges.data.tolist(), self.edges.indices.tolist(), self.edges.shape) return result - def log_conditional_of_simple_event(self, event: SimpleEvent, ) -> Tuple[Optional[Self], jax.Array]: - - # initialize the conditional child layers and the log probabilities - log_probabilities = jnp.zeros(self.number_of_nodes, dtype=jnp.float32) - conditional_child_layers = [] - remapped_edges = [] - - # for edge bundle and child layer - for index, (edges, child_layer) in enumerate(zip(self.edges, self.child_layers)): - edges: BCOO - edges = edges.sum_duplicates(remove_zeros=False) - - # condition the child layer - conditional, child_log_prob = child_layer.log_conditional_of_simple_event(event) - - # if it is entirely impossible, this layer also is - if conditional is None: - continue - - # update the log probabilities and child layers - log_probabilities = log_probabilities.at[edges.indices[:, 0]].add(child_log_prob[edges.data]) - conditional_child_layers.append(conditional) - - # create the remapping of the node indices. nan indicates the node got deleted - # enumerate the indices of the conditional child layer nodes - new_node_indices = jnp.arange(conditional.number_of_nodes) - - # initialize the remapping of the child layer node indices - layer_remap = jnp.full((child_layer.number_of_nodes,), jnp.nan, dtype=jnp.float32) - layer_remap = layer_remap.at[child_log_prob > -jnp.inf].set(new_node_indices) - - # update the edges - remapped_child_edges = layer_remap[edges.data] - valid_edges = ~jnp.isnan(remapped_child_edges) - - # create new indices for the edges - new_indices = edges.indices[valid_edges] - new_indices = jnp.concatenate([jnp.zeros((len(new_indices), 1), dtype=jnp.int32), new_indices], - axis=1) - - new_edges = BCOO((remapped_child_edges[valid_edges].astype(jnp.int32), - new_indices), - shape=(1, self.number_of_nodes), indices_sorted=True, - unique_indices=True) - remapped_edges.append(new_edges) - - remapped_edges = bcoo_concatenate(remapped_edges, dimension=0).sort_indices() - - # get nodes that should be removed as boolean mask - remove_mask = log_probabilities == -jnp.inf # shape (#nodes, ) - keep_mask = ~remove_mask - - # remove the nodes that have -inf log probabilities from remapped_edges - remapped_edges = coo_array((remapped_edges.data, remapped_edges.indices.T), shape=remapped_edges.shape).tocsc() - remapped_edges = remapped_edges[:, keep_mask].tocoo() - remapped_edges = BCOO((remapped_edges.data, jnp.stack((remapped_edges.row, remapped_edges.col)).T), - shape=remapped_edges.shape, indices_sorted=True, unique_indices=True) - - # construct result and clean it up - result = self.__class__(conditional_child_layers, remapped_edges) - result = result.clean_up_orphans() - return result, log_probabilities - - def clean_up_orphans(self): - """ - Clean up the layer by removing orphans in the child layers. - """ - new_child_layers = [] - - for index, (edges, child_layer) in enumerate(zip(self.edges, self.child_layers)): - edges: BCOO - edges = edges.sum_duplicates(remove_zeros=False) - # mask rather nodes have parent edges or not - orphans = jnp.ones(child_layer.number_of_nodes, dtype=jnp.bool) - - # mark nodes that have parents with False - data = edges.data - if len(data) > 0: - orphans = orphans.at[data].set(False) - - # if orphans exist - if orphans.any(): - # remove them from the child layer - child_layer = child_layer.remove_nodes(orphans) - new_child_layers.append(child_layer) - - # compress edges - shrunken_indices = shrink_index_array(self.edges.indices) - new_edges = BCOO((self.edges.data, shrunken_indices), shape=self.edges.shape, indices_sorted=True, - unique_indices=True) - return self.__class__(new_child_layers, new_edges) - - def remove_nodes(self, remove_mask: jax.Array) -> Self: - new_edges = self.edges[:, ~remove_mask] - return self.__class__(self.child_layers, new_edges) - @classmethod def _from_json(cls, data: Dict[str, Any]) -> Self: child_layer = [Layer.from_json(child_layer) for child_layer in data["child_layers"]] diff --git a/src/probabilistic_model/probabilistic_circuit/jax/input_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/input_layer.py index 05833d9..f47c07f 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/input_layer.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/input_layer.py @@ -27,108 +27,6 @@ class ContinuousLayer(InputLayer, ABC): Abstract base class for continuous univariate input units. """ - def probability_of_simple_event(self, event: SimpleEvent) -> jax.Array: - interval: Interval = list(event.values())[self.variables[0]] - return self.probability_of_interval(interval) - - def probability_of_interval(self, interval: Interval) -> jnp.array: - points = jnp.array([simple_interval_to_open_array(i) for i in interval.simple_sets]) - upper_bound_cdf = self.cdf_of_nodes(points[:, (1,)]) - lower_bound_cdf = self.cdf_of_nodes(points[:, (0,)]) - return (upper_bound_cdf - lower_bound_cdf).sum(axis=0) - - def probability_of_simple_interval(self, interval: SimpleInterval) -> jax.Array: - points = simple_interval_to_open_array(interval) - upper_bound_cdf = self.cdf_of_nodes_single(points[1]) - lower_bound_cdf = self.cdf_of_nodes_single(points[0]) - return upper_bound_cdf - lower_bound_cdf - - def log_conditional_of_simple_event(self, event: SimpleEvent) -> Tuple[ - Optional[Union[Self, DiracDeltaLayer]], jax.Array]: - if event.is_empty(): - return self.impossible_condition_result - - interval: Interval = list(event.values())[self.variable] - - if interval.is_singleton(): - return self.log_conditional_from_singleton(interval.simple_sets[0]) - - if len(interval.simple_sets) == 1: - return self.log_conditional_from_simple_interval(interval.simple_sets[0]) - else: - return self.log_conditional_from_interval(interval) - - def log_conditional_from_singleton(self, singleton: SimpleInterval) -> Tuple[DiracDeltaLayer, jax.Array]: - """ - Calculate the conditional distribution given a singleton interval. - - In this case, the conditional distribution is a Dirac delta distribution and the log-likelihood is chosen - instead of the log-probability. - - This method returns a Dirac delta layer that has at most the same number of nodes as the input layer. - - :param singleton: The singleton event - :return: The dirac delta layer and the log-likelihoods with shape (something <= #singletons, 1). - """ - value = singleton.lower - log_likelihoods = self.log_likelihood_of_nodes( - jnp.array(value).reshape(-1, 1))[:, 0] # shape: (#nodes, ) - - possible_indices = (log_likelihoods > -jnp.inf).nonzero()[0] # shape: (#dirac-nodes, ) - filtered_likelihood = log_likelihoods[possible_indices] - locations = jnp.full_like(filtered_likelihood, value) - layer = DiracDeltaLayer(self.variable, locations, jnp.exp(filtered_likelihood)) - return layer, log_likelihoods - - def log_conditional_from_simple_interval(self, interval: SimpleInterval) -> Tuple[Self, jax.Array]: - """ - Calculate the conditional distribution given a simple interval with p(interval) > 0. - The interval could also be a singleton. - - :param interval: The simple interval - :return: The conditional distribution and the log-probability of the interval. - """ - raise NotImplementedError - - def log_conditional_from_interval(self, interval: Interval) -> Tuple[SumLayer, jax.Array]: - """ - Calculate the conditional distribution given an interval with p(interval) > 0. - - :param interval: The simple interval - :return: The conditional distribution and the log-probability of the interval. - """ - - # get conditionals of each simple interval - results = [self.log_conditional_from_simple_interval(simple_interval) for simple_interval in - interval.simple_sets] - - layers, log_probs = zip(*results) - - # stack the log probabilities - stacked_log_probabilities = jnp.stack(log_probs, axis=1) # shape: (#simple_intervals, #nodes) - - # calculate the log probabilities of the entire interval - exp_stacked_log_probabilities = jnp.exp(stacked_log_probabilities) - summed_exp_stacked_log_probabilities = jnp.sum(exp_stacked_log_probabilities, axis=1) - total_log_probabilities = jnp.log(summed_exp_stacked_log_probabilities) # shape: (#nodes, 1) - - # create new input layer - possible_layers = [layer for layer in layers if layer is not None] - input_layer = possible_layers[0] - input_layer = input_layer.merge_with(possible_layers[1:]) - - # remove the rows that are entirely -inf and normalize weights - bcoo_data = remove_rows_and_cols_where_all(exp_stacked_log_probabilities / - summed_exp_stacked_log_probabilities.reshape(-1, 1), - 0) - - log_weights = BCOO.fromdense(bcoo_data) - log_weights.data = jnp.log(log_weights.data) - - resulting_layer = SumLayer([input_layer], [log_weights]) - return resulting_layer, total_log_probabilities - - class ContinuousLayerWithFiniteSupport(ContinuousLayer, ABC): """ Abstract class for continuous univariate input units with finite support. @@ -185,9 +83,6 @@ def to_json(self) -> Dict[str, Any]: def __deepcopy__(self): return self.__class__(self.variables[0].item(), self.interval.copy()) - def remove_nodes(self, remove_mask: jax.Array) -> Self: - return self.__class__(self.variable, self.interval[~remove_mask]) - class DiracDeltaLayer(ContinuousLayer): location: jax.Array = eqx.field(static=True) @@ -234,35 +129,6 @@ def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[Univariate result = cls(nodes[0].probabilistic_circuit.variables.index(nodes[0].variable), locations, density_caps) return NXConverterLayer(result, nodes, hash_remap) - def sample_from_frequencies(self, frequencies: np.array, result: np.array, start_index=0): - values = self.location.repeat(frequencies).reshape(-1, 1) - result[start_index:start_index + len(values), self.variables] = values - - def cdf_of_nodes_single(self, x: jnp.array) -> jnp.array: - return jnp.where(x < self.location, 0., 1.) - - def moment_of_nodes(self, order: jax.Array, center: jax.Array): - order = order[self.variables[0]] - center = center[self.variables[0]] - if order == 0: - result = jnp.ones(self.number_of_nodes) - elif order == 1: - result = self.location - center - else: - result = jnp.zeros(self.number_of_nodes) - return result.reshape(-1, 1) - - def log_conditional_from_simple_interval(self, interval: SimpleInterval) -> Tuple[Self, jax.Array]: - log_probs = jnp.log(self.probability_of_simple_interval(interval)) - - valid_log_probs = log_probs > -jnp.inf - - if not valid_log_probs.any(): - return self.impossible_condition_result - - result = self.__class__(self.variable, self.location[valid_log_probs], - self.density_cap[valid_log_probs]) - return result, log_probs def to_json(self) -> Dict[str, Any]: result = super().to_json() @@ -274,13 +140,6 @@ def to_json(self) -> Dict[str, Any]: def _from_json(cls, data: Dict[str, Any]) -> Self: return cls(data["variable"], jnp.array(data["location"]), jnp.array(data["density_cap"])) - def merge_with(self, others: List[Self]) -> Self: - return self.__class__(self.variable, jnp.concatenate([self.location] + [other.location for other in others]), - jnp.concatenate([self.density_cap] + [other.density_cap for other in others])) - - def remove_nodes(self, remove_mask: jax.Array) -> Self: - return self.__class__(self.variable, self.location[~remove_mask], self.density_cap[~remove_mask]) - def to_nx(self, variables: SortedSet[Variable], progress_bar: Optional[tqdm.tqdm] = None) -> List[ Unit]: nx_pc = NXProbabilisticCircuit() diff --git a/src/probabilistic_model/probabilistic_circuit/jax/probabilistic_circuit.py b/src/probabilistic_model/probabilistic_circuit/jax/probabilistic_circuit.py index 0bae6ab..4092d04 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/probabilistic_circuit.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/probabilistic_circuit.py @@ -40,14 +40,6 @@ def __init__(self, variables: SortedSet, root: Layer): def log_likelihood(self, x: jax.Array) -> jax.Array: return self.root.log_likelihood_of_nodes(x)[:, 0] - def sample(self, amount: int) -> np.array: - result_array = np.full((amount, len(self.variables)), np.nan) - self.root.sample_from_frequencies(np.array([amount]), result_array) - return result_array - - def probability_of_simple_event(self, event: SimpleEvent): - return self.root.probability_of_simple_event(event) - @classmethod def from_nx(cls, pc: NXProbabilisticCircuit, progress_bar: bool = False) -> ProbabilisticCircuit: """ diff --git a/test/test_jax/test_input_layer.py b/test/test_jax/test_input_layer.py index 9a57000..a0d73c9 100644 --- a/test/test_jax/test_input_layer.py +++ b/test/test_jax/test_input_layer.py @@ -20,28 +20,6 @@ def test_likelihood(self): [-jnp.inf, -jnp.inf]] assert jnp.allclose(ll, jnp.array(result)) - def test_cdf(self): - data = jnp.array([-1, 0, 1, 2], dtype=jnp.float32).reshape(-1, 1) - cdf = self.layer.cdf_of_nodes(data) - result = jnp.array([[0, 0], [1, 0], [1, 1], [1, 1]], dtype=jnp.float32) - self.assertTrue(jnp.allclose(cdf, result)) - - def test_moment(self): - order = jnp.array([1.], dtype=jnp.int32) - center = jnp.array([1.5], dtype=jnp.float32) - moment = self.layer.moment_of_nodes(order, center) - result = jnp.array([-1.5, -0.5], dtype=jnp.float32).reshape(-1, 1) - self.assertTrue(jnp.allclose(moment, result)) - - def test_conditional_of_simple_interval(self): - interval = closed(-0.5, 0.5).simple_sets[0] - layer, ll = self.layer.log_conditional_from_simple_interval(interval) - result = jnp.log(jnp.array([1, 0], dtype=jnp.float32)) - self.assertTrue(jnp.allclose(ll, result)) - layer.validate() - self.assertEqual(layer.number_of_nodes, 1) - self.assertTrue(jnp.allclose(layer.location, jnp.array([0.]))) - self.assertTrue(jnp.allclose(layer.density_cap, jnp.array([1.]))) if __name__ == '__main__': diff --git a/test/test_jax/test_probabilistic_circuit.py b/test/test_jax/test_probabilistic_circuit.py index 9cd04ea..f8700ad 100644 --- a/test/test_jax/test_probabilistic_circuit.py +++ b/test/test_jax/test_probabilistic_circuit.py @@ -89,17 +89,7 @@ def test_trainable_parameters(self): number_of_parameters = sum([len(p) for p in flattened_params]) self.assertEqual(number_of_parameters, 10) - def test_serialization(self): - json = self.jax_model.to_json() - model = ProbabilisticCircuit.from_json(json) - samples = model.sample(1000) - jax_ll = model.log_likelihood(samples) - self.assertTrue((jax_ll > -jnp.inf).all()) - def test_sample(self): - samples = self.jax_model.sample(100) - ll = self.jax_model.log_likelihood(samples) - self.assertTrue((ll > -jnp.inf).all()) class JPTIntegrationTestCase(unittest.TestCase): number_of_variables = 2 diff --git a/test/test_jax/test_product_layer.py b/test/test_jax/test_product_layer.py index efb841c..10d680c 100644 --- a/test/test_jax/test_product_layer.py +++ b/test/test_jax/test_product_layer.py @@ -55,71 +55,8 @@ def test_likelihood(self): self.assertTrue(likelihood[0, 1] == -jnp.inf) self.assertTrue(likelihood[1, 0] == -jnp.inf) - def test_cdf(self): - data = jnp.array([[0, 0, 0], [0, 5, 6], [2, 4, 6], [10, 10, 10]], dtype=jnp.float32) - cdf = self.product_layer.cdf_of_nodes(data) - self.assertEqual(cdf.shape, (4, 2)) - result = jnp.array([[0, 0], [1., 0.], [0., 1], [1, 1]], dtype=jnp.float32) - self.assertTrue(jnp.allclose(cdf, result)) - def test_moment(self): - order = jnp.array([1, 1, 2], dtype=jnp.int32) - center = jnp.array([0., 1., 2], dtype=jnp.float32) - moment = self.product_layer.moment_of_nodes(order, center) - result = jnp.array([[0, 4., 0.], - [2., 3., 0.]], dtype=jnp.float32) - self.assertTrue(jnp.allclose(moment, result)) - - def test_probability(self): - event = SimpleEvent({self.x: closed(0,2), self.y: singleton(5), self.z: singleton(6.)}) - prob = self.product_layer.probability_of_simple_event(event) - result = jnp.array([1, 0], dtype=jnp.float32) - self.assertTrue(jnp.allclose(prob, result)) - - def test_conditioning(self): - - event = SimpleEvent({self.x: closed(-1, 1), - self.y: closed(4.5, 5.5), - self.z: closed(5.5, 6.5)}) - - conditional, log_prob = self.product_layer.log_conditional_of_simple_event(event) - conditional.validate() - self.assertTrue(jnp.allclose(log_prob, jnp.log(jnp.array([1., 0.])))) - self.assertEqual(conditional.number_of_nodes, 1) - self.assertEqual(len(conditional.child_layers), 3) - self.assertEqual(conditional.child_layers[0].number_of_nodes, 1) - self.assertEqual(conditional.child_layers[1].number_of_nodes, 1) - self.assertEqual(conditional.child_layers[2].number_of_nodes, 1) - - -class PCProductLayerTestCase(unittest.TestCase): - - x = Continuous("x") - y = Continuous("y") - z = Continuous("z") - - p1_x = DiracDeltaLayer(0, jnp.array([0., 1.]), jnp.array([1, 1])) - p2_x = DiracDeltaLayer(0, jnp.array([2., 3.]), jnp.array([1, 1])) - p_y = DiracDeltaLayer(1, jnp.array([4., 5.]), jnp.array([1, 1])) - p_z = DiracDeltaLayer(2, jnp.array([6.]), jnp.array([1])) - model: ProbabilisticCircuit - - def setUp(self): - indices = jnp.array([[0, 0], - [1, 0], - [3, 0]]) - values = jnp.array([0, 0, 1]) - edges = BCOO((values, indices), shape=(4, 2)).sum_duplicates(remove_zeros=False).sort_indices() - product_layer = ProductLayer([self.p_z, self.p1_x, self.p2_x, self.p_y, ], edges) - self.model = ProbabilisticCircuit(SortedSet([self.x, self.y, self.z]), product_layer) - - def test_sample(self): - samples = self.model.sample(3) - result = np.array([[0, 5, 6], - [0, 5, 6], - [0, 5, 6]]) - self.assertTrue(np.allclose(samples, result)) if __name__ == '__main__': diff --git a/test/test_jax/test_sum_layer.py b/test/test_jax/test_sum_layer.py index 2dda14e..f5a16fa 100644 --- a/test/test_jax/test_sum_layer.py +++ b/test/test_jax/test_sum_layer.py @@ -69,98 +69,6 @@ def test_ll(self): [0., 0.,]])) assert jnp.allclose(ll, result) - def test_cdf(self): - data = jnp.arange(7, dtype=jnp.float32).reshape(-1, 1) - 0.5 - cdf = self.sum_layer.cdf_of_nodes(data) - self.assertEqual(cdf.shape, (7, 2)) - result = jnp.array([[0, 0], # -0.5 - [0, 0.4], # 0.5 - [0.1, 0.4], # 1.5 - [0.3, 0.7], # 2.5 - [0.6, 0.7], # 3.5 - [0.6, 0.8], # 4.5 - [1, 1], # 5.5 - ], dtype=jnp.float32) - self.assertTrue(jnp.allclose(cdf, result)) - - def test_moment(self): - order = jnp.array([1], dtype=jnp.int32) - center = jnp.array([2.5], dtype=jnp.float32) - moment = self.sum_layer.moment_of_nodes(order, center) - result = jnp.array([0.9, -0.5], dtype=jnp.float32).reshape(-1, 1) - self.assertTrue(jnp.allclose(moment, result)) - - def test_probability(self): - event = SimpleEvent({self.x: closed(0.5, 2.5) | closed(4.5, 10)}) - prob = self.sum_layer.probability_of_simple_event(event) - result = jnp.array([0.7, 0.5], dtype=jnp.float32) - self.assertTrue(jnp.allclose(result, prob)) - - def test_conditional(self): - event = SimpleEvent({self.x: closed(0.5, 1.5)}) - c, lp = self.sum_layer.log_conditional_of_simple_event(event) - c.validate() - self.assertEqual(c.number_of_nodes, 1) - self.assertEqual(len(c.child_layers), 1) - self.assertEqual(c.child_layers[0].number_of_nodes, 1) - self.assertTrue(jnp.allclose(c.log_weights[0].todense(), jnp.array([[0.]]))) - self.assertTrue(jnp.allclose(lp, jnp.log(jnp.array([0.1, 0.])))) - - def test_conditional_2(self): - event = SimpleEvent({self.x: closed(1.5, 4.5)}) - c, lp = self.sum_layer.log_conditional_of_simple_event(event) - c.validate() - self.assertEqual(c.number_of_nodes, 2) - self.assertEqual(len(c.child_layers), 2) - self.assertEqual(c.child_layers[0].number_of_nodes, 1) - self.assertEqual(c.child_layers[1].number_of_nodes, 2) - - def test_remove(self): - result = self.sum_layer.remove_nodes(jnp.array([True, False])) - result.validate() - self.assertEqual(result.number_of_nodes, 1) - - -class PCSumUnitTestCase(unittest.TestCase): - x: Continuous = Continuous("x") - - p1_x = DiracDeltaLayer(0, jnp.array([0., 1.]), jnp.array([1, 2])) - p2_x = DiracDeltaLayer(0,jnp.array([2.]), jnp.array([3])) - p3_x = DiracDeltaLayer(0, jnp.array([3., 4., 5.]), jnp.array([4, 5, 6])) - p4_x = DiracDeltaLayer(0, jnp.array([6.]), jnp.array([1])) - model: ProbabilisticCircuit - - @classmethod - def setUpClass(cls): - weights_p1 = BCOO.fromdense(jnp.array([[0, 0.1]])) * 2 - weights_p1.data = jnp.log(weights_p1.data) - - weights_p2 = BCOO.fromdense(jnp.array([[0.2]])) * 2 - weights_p2.data = jnp.log(weights_p2.data) - - weights_p3 = BCOO.fromdense(jnp.array([[0.3, 0, 0.4]])) * 2 - weights_p3.data = jnp.log(weights_p3.data) - - weights_p4 = BCOO.fromdense(jnp.array([[0]])) * 2 - weights_p4.data = jnp.log(weights_p4.data) - - sum_layer = SumLayer([cls.p1_x, cls.p2_x, cls.p3_x, cls.p4_x], - log_weights=[weights_p1, weights_p2, weights_p3, weights_p4]) - sum_layer.validate() - cls.model = ProbabilisticCircuit(SortedSet([cls.x]), sum_layer) - - def test_sampling(self): - np.random.seed(69) - samples = self.model.sample(10) - self.assertEqual(samples.shape, (10, 1)) - result = np.array([2, 2, 2, 3, 3, 5, 5, 5, 5, 5]).reshape(-1, 1) - self.assertTrue(np.allclose(samples, result)) - - def test_conditioning(self): - event = SimpleEvent({self.x: closed(1.5, 4.5)}) - conditional, log_prob = self.model.root.log_conditional_of_simple_event(event) - # conditional.validate() - class NygaDistributionTestCase(unittest.TestCase): @@ -180,7 +88,3 @@ def setUpClass(cls): def test_log_likelihood(self): ll = self.jax_model.log_likelihood(self.data) self.assertTrue(jnp.all(ll > -jnp.inf)) - - def test_sampling(self): - data = self.jax_model.sample(1000) - self.assertEqual(data.shape, (1000, 1)) \ No newline at end of file diff --git a/test/test_jax/test_uniform_layer.py b/test/test_jax/test_uniform_layer.py index de90d55..33e2941 100644 --- a/test/test_jax/test_uniform_layer.py +++ b/test/test_jax/test_uniform_layer.py @@ -40,73 +40,9 @@ def test_from_interval(self): result = jnp.array([[0, 0, 1, 1], [0, 1, 0, 1]]) self.assertTrue(jnp.allclose(ll, result)) - def test_cdf(self): - data = jnp.array([0.5, 1.5, 4]).reshape(-1, 1) - cdf = self.p_x.cdf_of_nodes(data) - self.assertEqual(cdf.shape, (3, 2)) - result = jnp.array([[0.5, 0], [1, 0.25], [1, 1]]) - self.assertTrue(jnp.allclose(cdf, result)) - - def test_moment(self): - order = jnp.array([1], dtype=jnp.int32) - center = jnp.array([1.], dtype=jnp.float32) - moment = self.p_x.moment_of_nodes(order, center) - result = jnp.array([[-0.5], [1.]], dtype=jnp.float32) - self.assertTrue(jnp.allclose(moment, result)) - - def test_probability(self): - event = SimpleEvent({self.x: closed(0.5, 2.5) | closed(3, 5)}) - prob = self.p_x.probability_of_simple_event(event) - self.assertEqual(prob.shape, (2,)) - result = jnp.array([0.5, 0.75]) - self.assertTrue(jnp.allclose(prob, result)) - def test_to_json(self): data = self.p_x.to_json() json.dumps(data) p_x = UniformLayer.from_json(data) self.assertTrue(jnp.allclose(self.p_x.interval, p_x.interval)) - - def test_conditional_singleton(self): - event = SimpleEvent({self.x: closed(0.5, 0.5)}) - layer, ll = self.p_x.log_conditional_of_simple_event(event) - self.assertEqual(layer.number_of_nodes, 1) - self.assertTrue(jnp.allclose(jnp.array([0.5]), layer.location)) - self.assertTrue(jnp.allclose(jnp.array([1.]), layer.density_cap)) - - def test_conditional_single_truncation(self): - event = SimpleEvent({self.x: closed(0.5, 2.5)}) - layer, ll = self.p_x.log_conditional_of_simple_event(event) - layer.validate() - self.assertEqual(layer.number_of_nodes, 2) - self.assertTrue(jnp.allclose(layer.interval, jnp.array([[0.5, 1], [1, 2.5]]))) - self.assertTrue(jnp.allclose(jnp.log(jnp.array([0.5, 0.75])), ll)) - - def test_conditional_with_node_removal(self): - event = SimpleEvent({self.x: closed(0.25, 0.5)}) - layer, ll = self.p_x.log_conditional_of_simple_event(event) - layer.validate() - self.assertEqual(layer.number_of_nodes, 1) - self.assertTrue(jnp.allclose(layer.interval, jnp.array([[0.25, 0.5]]))) - self.assertTrue(jnp.allclose(jnp.log(jnp.array([0.25, 0.])), ll)) - - def test_conditional_multiple_truncation(self): - event = closed(-1, 0.5) | closed(0.7, 0.8) | closed(2., 3.) | closed(3.5, 4.) - - layer, ll = self.p_x.log_conditional_from_interval(event) - - self.assertTrue(jnp.allclose(jnp.log(jnp.array([0.6, 0.5])), ll)) - self.assertIsInstance(layer, SumLayer) - - layer.validate() - self.assertEqual(layer.number_of_nodes, 2) - self.assertEqual(len(layer.child_layers), 1) - self.assertTrue(jnp.allclose(layer.child_layers[0].interval, jnp.array([[0., 0.5], [0.7, 0.8], [2., 3.]]))) - - log_weights_by_hand = jnp.array([[0.5, 0.1, 0.], [0., 0., 0.5]]) - log_weights_by_hand /= jnp.sum(log_weights_by_hand, axis=1, keepdims=True) - log_weights_by_hand = BCOO.fromdense(log_weights_by_hand) - log_weights_by_hand.data = jnp.log(log_weights_by_hand.data) - self.assertTrue(jnp.allclose(layer.log_weights[0].data, log_weights_by_hand.data)) - self.assertTrue(jnp.allclose(layer.log_weights[0].indices, log_weights_by_hand.indices)) From c818cf5c40493c390ae8b2dfc18da76819cc284f Mon Sep 17 00:00:00 2001 From: Tom Schierenbeck Date: Fri, 13 Dec 2024 17:15:38 +0100 Subject: [PATCH 02/10] Removed unsupported things from jax circuits --- scripts/gmm.py | 2 +- scripts/jpt_speed_comparison.py | 2 +- scripts/nyga_speed_comparison.py | 4 ++-- .../probabilistic_circuit/jax/inner_layer.py | 13 +++++++------ test/test_jax/test_sum_layer.py | 6 ++++++ 5 files changed, 17 insertions(+), 10 deletions(-) diff --git a/scripts/gmm.py b/scripts/gmm.py index e76685d..ee51138 100644 --- a/scripts/gmm.py +++ b/scripts/gmm.py @@ -28,7 +28,7 @@ number_of_variables = 2 number_of_samples_per_component = 100000 number_of_components = 2 -number_of_mixtures = 100 +number_of_mixtures = 1000 number_of_iterations = 1000 # model selection diff --git a/scripts/jpt_speed_comparison.py b/scripts/jpt_speed_comparison.py index ac01f85..4ee388e 100644 --- a/scripts/jpt_speed_comparison.py +++ b/scripts/jpt_speed_comparison.py @@ -110,7 +110,7 @@ def timed_jax_method(): # ll = jax_model.log_likelihood(samples) # assert (ll > -jnp.inf).all() -times_nx, times_jax = eval_performance(nx_model.log_likelihood, (data, ), compiled_ll_jax, (data_jax, ), 10, 5) +times_nx, times_jax = eval_performance(nx_model.log_likelihood, (data, ), compiled_ll_jax, (data_jax, ), 15, 10) # times_nx, times_jax = eval_performance(nx_model.probability_of_simple_event, (event,), jax_model.probability_of_simple_event, (event,), 20, 10) # times_nx, times_jax = eval_performance(nx_model.sample, (10000, ), jax_model.sample, (10000,), 10, 5) diff --git a/scripts/nyga_speed_comparison.py b/scripts/nyga_speed_comparison.py index 65b1cb9..1d97dcf 100644 --- a/scripts/nyga_speed_comparison.py +++ b/scripts/nyga_speed_comparison.py @@ -34,7 +34,7 @@ path_prefix = os.path.join(os.path.expanduser("~"), "Documents") nx_model_path = os.path.join(path_prefix, "nx_nyga.pm") jax_model_path = os.path.join(path_prefix, "jax_nyga.pm") -load_from_disc = True +load_from_disc = False save_to_disc = True @@ -103,7 +103,7 @@ def timed_jax_method(): # times_nx, times_jax = eval_performance(nx_model.log_likelihood, (data, ), compiled_ll_jax, (data_jax, ), 20, 2) # times_nx, times_jax = eval_performance(prob_nx, event, prob_jax, event, 15, 10) -times_nx, times_jax = eval_performance(nx_model.sample, (1000, ), jax_model.sample, (1000, ), 10, 5) +times_nx, times_jax = eval_performance(nx_model.sample, (1000, ), jax_model.sample, (1000, ), 5, 10) time_jax = np.mean(times_jax), np.std(times_jax) time_nx = np.mean(times_nx), np.std(times_nx) diff --git a/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py index 7c11ac7..35e135a 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py @@ -20,7 +20,7 @@ from sortedcontainers import SortedSet from typing_extensions import List, Iterator, Tuple, Union, Type, Dict, Any, Self, Optional -from . import shrink_index_array +from . import shrink_index_array, embed_sparse_array_in_nan_array from .utils import copy_bcoo, sample_from_sparse_probabilities_csc, sparse_remove_rows_and_cols_where_all from ..nx.probabilistic_circuit import (SumUnit, ProductUnit, Unit, ProbabilisticCircuit as NXProbabilisticCircuit) @@ -310,7 +310,9 @@ def concatenated_log_weights(self) -> BCOO: def log_normalization_constants(self) -> jax.Array: result = self.concatenated_log_weights result.data = jnp.exp(result.data) + jax.debug.print("dense time") result = result.sum(1).todense() + jax.debug.print("post dense time") return jnp.log(result) @property @@ -337,11 +339,10 @@ def log_likelihood_of_nodes_single(self, x: jax.Array) -> jax.Array: cloned_log_weights.data += child_layer_log_likelihood[cloned_log_weights.indices[:, 1]] cloned_log_weights.data = jnp.exp(cloned_log_weights.data) # exponent weights - # sum the weights for each node - ll = cloned_log_weights.sum(1).todense() - - # sum the child layer result - result += ll + jax.debug.print("pre scatter add time") + result = result.at[cloned_log_weights.indices[:, 0]].add(cloned_log_weights.data, + indices_are_sorted=True, unique_indices=True) + jax.debug.print("post scatter add time") return jnp.where(result > 0, jnp.log(result) - self.log_normalization_constants, -jnp.inf) diff --git a/test/test_jax/test_sum_layer.py b/test/test_jax/test_sum_layer.py index f5a16fa..52ef529 100644 --- a/test/test_jax/test_sum_layer.py +++ b/test/test_jax/test_sum_layer.py @@ -69,6 +69,12 @@ def test_ll(self): [0., 0.,]])) assert jnp.allclose(ll, result) + def test_ll_single(self): + data = jnp.array([0]) + l = self.sum_layer.log_likelihood_of_nodes_single(data) + result = jnp.log(jnp.array([0., 0.4])) + assert jnp.allclose(l, result) + class NygaDistributionTestCase(unittest.TestCase): From 1c4947dc6be6178d68966520dea79a781bce34a8 Mon Sep 17 00:00:00 2001 From: Tom Schierenbeck Date: Fri, 20 Dec 2024 20:03:24 +0100 Subject: [PATCH 03/10] Added region graphs and conversion to jax circuits --- doc/pendulum.md | 2 +- .../learning/region_graph/__init__.py | 0 .../learning/region_graph/region_graph.py | 204 ++++++++++++++++++ .../jax/gaussian_layer.py | 2 +- .../probabilistic_circuit/jax/inner_layer.py | 23 +- test/test_rat_spn/__init__.py | 0 test/test_rat_spn/test_region_graph.py | 41 ++++ 7 files changed, 258 insertions(+), 14 deletions(-) create mode 100644 src/probabilistic_model/learning/region_graph/__init__.py create mode 100644 src/probabilistic_model/learning/region_graph/region_graph.py create mode 100644 test/test_rat_spn/__init__.py create mode 100644 test/test_rat_spn/test_region_graph.py diff --git a/doc/pendulum.md b/doc/pendulum.md index 2bb64a2..2321e72 100644 --- a/doc/pendulum.md +++ b/doc/pendulum.md @@ -25,7 +25,7 @@ purely knowledge graph driven AI. A full swing takes a couple of years. Frank argued that there needs to be a middle ground where machine learning and knowledge graphs come together. -I belive that the implementation of probabilistic models in this package is capable of doing so. +I believe that the implementation of probabilistic models in this package is capable of doing so. Knowledge graphs generate sets that describe possible assignments that match the constraints and instance knowledge of ontologies (a random event, so to say). Probability distributions describe the likelihoods of every possible solution. diff --git a/src/probabilistic_model/learning/region_graph/__init__.py b/src/probabilistic_model/learning/region_graph/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/probabilistic_model/learning/region_graph/region_graph.py b/src/probabilistic_model/learning/region_graph/region_graph.py new file mode 100644 index 0000000..a360d20 --- /dev/null +++ b/src/probabilistic_model/learning/region_graph/region_graph.py @@ -0,0 +1,204 @@ +from collections import deque + +import networkx as nx +import numpy as np +from jax.experimental.sparse import BCOO +from random_events.variable import Continuous +from sortedcontainers import SortedSet +from typing_extensions import List, Self, Type + +from ...distributions import GaussianDistribution +from ...probabilistic_circuit.jax import SumLayer, ProductLayer +from ...probabilistic_circuit.jax.gaussian_layer import GaussianLayer +from ...probabilistic_circuit.nx.probabilistic_circuit import ProbabilisticCircuit, SumUnit, ProductUnit +from ...probabilistic_circuit.nx.distributions.distributions import UnivariateContinuousLeaf +from ...probabilistic_circuit.jax.probabilistic_circuit import ProbabilisticCircuit as JPC +import jax.numpy as jnp +import jax.random + + +class Region: + variables: SortedSet + + def __init__(self, variables: SortedSet): + self.variables = variables + + def __hash__(self) -> int: + return id(self) + + def random_partition(self, k=2) -> List[Self]: + indices = np.arange(len(self.variables)) + np.random.shuffle(indices) + partitions = [Region(SortedSet([self.variables[index] for index in split])) for split in np.array_split(indices, k)] + return partitions + + def __repr__(self) -> str: + return "{" + ", ".join([v.name for v in self.variables]) + "}" + +class Partition: + def __hash__(self) -> int: + return id(self) + + +class RegionGraph(nx.DiGraph): + + variables: SortedSet + + def __init__(self, variables: SortedSet, + partitions: int = 2, + depth:int = 2, + repetitions:int = 2): + super().__init__() + self.variables = variables + self.partitions = partitions + self.depth = depth + self.repetitions = repetitions + + + def create_random_region_graph(self): + root = Region(self.variables) + self.add_node(root) + for repetition in range(self.repetitions): + self.recursive_split(root) + return self + + def regions(self): + for node in self.nodes: + if isinstance(node, Region): + yield node + + def partition_nodes(self): + for node in self.nodes: + if isinstance(node, Partition): + yield node + + def recursive_split(self, node: Region): + root_partition = Partition() + self.add_edge(node, root_partition) + remaining_regions = deque([(node, self.depth, root_partition)]) + + while remaining_regions: + region, depth, partition = remaining_regions.popleft() + + + if len(region.variables) == 1: + continue + + if depth == 0: + for variable in region.variables: + self.add_edge(partition, Region(SortedSet([variable]))) + continue + + new_regions = region.random_partition(self.partitions) + for new_region in new_regions: + self.add_edge(partition, new_region) + new_partition = Partition() + self.add_edge(new_region, new_partition) + remaining_regions.append((new_region, depth - 1, new_partition)) + + @property + def root(self) -> Region: + """ + The root of the circuit is the node with in-degree 0. + This is the output node, that will perform the final computation. + + :return: The root of the circuit. + """ + possible_roots = [node for node in self.nodes() if self.in_degree(node) == 0] + if len(possible_roots) > 1: + raise ValueError(f"More than one root found. Possible roots are {possible_roots}") + + return possible_roots[0] + + def as_probabilistic_circuit(self, continuous_distribution_type: Type = GaussianDistribution, + input_units: int = 5, sum_units: int = 5) -> ProbabilisticCircuit: + model = ProbabilisticCircuit() + + # create nodes for each region + for region in self.regions(): + region: Region + children = list(self.successors(region)) + parents = list(self.predecessors(region)) + + # if the region is a leaf + if len(children) == 0: + variable = region.variables[0] + if isinstance(variable, Continuous): + region.nodes = [UnivariateContinuousLeaf(GaussianDistribution(variable, 0, 1), model) + for _ in range(input_units)] + + # region is root + elif len(parents) == 0: + region.nodes = [SumUnit(model)] + + # region is in the middle + else: + region.nodes = [SumUnit(model) for _ in range(sum_units)] + + # create nodes for each partition + for partition in self.partition_nodes(): + children = list(self.successors(partition)) + parent = list(self.predecessors(partition)) + assert len(parent) == 1, "Partition should only have one parent." + parent = parent[0] + + node_lengths = [len(child.nodes) for child in children] + assert (len(set(node_lengths)) == 1), "Node lengths must be all equal. Got {}".format(node_lengths) + + for index in range(node_lengths[0]): + prod = ProductUnit(model) + for child in children: + prod.add_subcircuit(child.nodes[index], mount=False) + for node in parent.nodes: + node.add_subcircuit(prod, 1., mount=False) + + return model + + def as_jax_pc(self, continuous_distribution_type: Type = GaussianDistribution, + input_units: int = 5, sum_units: int = 5, key=jax.random.PRNGKey(69)) -> JPC: + root = self.root + + # create nodes for each region + for layer in reversed(list(nx.bfs_layers(self, root))): + for node in layer: + children = list(self.successors(node)) + parents = list(self.predecessors(node)) + if isinstance(node, Region): + # if the region is a leaf + if len(children) == 0: + variable = node.variables[0] + variable_index = self.variables.index(variable) + if isinstance(variable, Continuous): + location = jax.random.uniform(key, shape=(input_units,), minval=-1., maxval=1.) + log_scale = jnp.log(jax.random.uniform(key, shape=(input_units,), minval=0.01, maxval=1.)) + node.layer = GaussianLayer(variable_index, location=location, log_scale=log_scale, min_scale=jnp.full_like(location, 0.01)) + node.layer.validate() + else: + raise NotImplementedError + + else: + # if the region is root + if len(parents) == 0: + current_sum_units = 1 + # if the region is in the middle + else: + current_sum_units = sum_units + log_weights = [BCOO.fromdense(jax.random.uniform(key, shape=(current_sum_units, child.layer.number_of_nodes), minval=0., maxval=1.)) for child in children] + for log_weight in log_weights: + log_weight.data = jnp.log(log_weight.data) + node.layer = SumLayer([child.layer for child in children], log_weights=log_weights) + node.layer.validate() + + + elif isinstance(node, Partition): + node_lengths = [child.layer.number_of_nodes for child in children] + assert (len(set(node_lengths)) == 1), "Node lengths must be all equal. Got {}".format(node_lengths) + + edges = jnp.arange(node_lengths[0]).reshape(1, -1).repeat(len(children), axis=0) + sparse_edges = BCOO.fromdense(jnp.ones_like(edges)) + sparse_edges.values = edges + node.layer = ProductLayer([child.layer for child in children], sparse_edges) + node.layer.validate() + + model = JPC(self.variables, root.layer) + return model diff --git a/src/probabilistic_model/probabilistic_circuit/jax/gaussian_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/gaussian_layer.py index 6a34948..9653587 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/gaussian_layer.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/gaussian_layer.py @@ -41,7 +41,7 @@ def validate(self): @property def number_of_nodes(self) -> int: - return len(self.location.shape) + return self.location.shape[0] @property def scale(self) -> jnp.array: diff --git a/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py index 35e135a..4b2a72e 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py @@ -142,6 +142,7 @@ def create_layers_from_nodes(nodes: List[Unit], child_layers: List[NXConverterLa for scope in unique_scopes: nodes_of_current_type_and_scope = [node for node in nodes_of_current_type if tuple(node.variables) == scope] + layer = layer_type.create_layer_from_nodes_with_same_type_and_scope(nodes_of_current_type_and_scope, child_layers, progress_bar) result.append(layer) @@ -310,9 +311,7 @@ def concatenated_log_weights(self) -> BCOO: def log_normalization_constants(self) -> jax.Array: result = self.concatenated_log_weights result.data = jnp.exp(result.data) - jax.debug.print("dense time") result = result.sum(1).todense() - jax.debug.print("post dense time") return jnp.log(result) @property @@ -338,13 +337,10 @@ def log_likelihood_of_nodes_single(self, x: jax.Array) -> jax.Array: # multiply the weights with the child layer likelihood cloned_log_weights.data += child_layer_log_likelihood[cloned_log_weights.indices[:, 1]] cloned_log_weights.data = jnp.exp(cloned_log_weights.data) # exponent weights - - jax.debug.print("pre scatter add time") result = result.at[cloned_log_weights.indices[:, 0]].add(cloned_log_weights.data, - indices_are_sorted=True, unique_indices=True) - jax.debug.print("post scatter add time") + indices_are_sorted=False, unique_indices=False) - return jnp.where(result > 0, jnp.log(result) - self.log_normalization_constants, -jnp.inf) + return jnp.log(result) - self.log_normalization_constants def __deepcopy__(self): child_layers = [child_layer.__deepcopy__() for child_layer in self.child_layers] @@ -376,6 +372,11 @@ def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[SumUnit], number_of_nodes = len(nodes) # filter the child layers to only contain layers with the same scope as this one + print(variables) + print([child_layer.layer.variables for child_layer in child_layers]) + print(len(child_layers)) + print([n.variables for n in child_layers[0].nodes]) + print([n.variables for n in child_layers[1].nodes]) filtered_child_layers = [child_layer for child_layer in child_layers if (child_layer.layer.variables == variables).all()] log_weights = [] @@ -492,11 +493,9 @@ def number_of_components(self) -> int: @cached_property def variables(self) -> jax.Array: - child_layer_variables = jnp.concatenate([child_layer.variables for child_layer in self.child_layers]) - max_size = child_layer_variables.shape[0] - unique_values = jnp.unique(child_layer_variables, size=max_size, fill_value=-1) - unique_values = unique_values[unique_values >= 0] - return unique_values.sort() + variables = jnp.concatenate([child_layer.variables for child_layer in self.child_layers]) + variables = jnp.unique(variables) + return variables def log_likelihood_of_nodes_single(self, x: jax.Array) -> jax.Array: result = jnp.zeros(self.number_of_nodes, dtype=jnp.float32) diff --git a/test/test_rat_spn/__init__.py b/test/test_rat_spn/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_rat_spn/test_region_graph.py b/test/test_rat_spn/test_region_graph.py new file mode 100644 index 0000000..441c247 --- /dev/null +++ b/test/test_rat_spn/test_region_graph.py @@ -0,0 +1,41 @@ +import random +import unittest + +from matplotlib import pyplot as plt +from networkx.drawing.nx_agraph import graphviz_layout +from random_events.variable import Continuous +import pydot +from probabilistic_model.learning.region_graph.region_graph import * +from probabilistic_model.probabilistic_circuit.jax.probabilistic_circuit import ProbabilisticCircuit as JPC +from probabilistic_model.probabilistic_circuit.jax.gaussian_layer import GaussianLayer +import numpy as np + +np.random.seed(420) +random.seed(420) + +class RandomRegionGraphTestCase(unittest.TestCase): + + variables = SortedSet([Continuous(str(i)) for i in range(4)]) + + region_graph = RegionGraph(variables, partitions=2, depth=1, repetitions=2) + region_graph = region_graph.create_random_region_graph() + + def test_region_graph(self): + self.assertEqual(len(self.region_graph.nodes()), len(self.region_graph.nodes())) + + def test_as_pc(self): + model = self.region_graph.as_probabilistic_circuit(input_units=1, sum_units=1) + model.plot_structure() + plt.show() + jax_model = JPC.from_nx(model) + + def test_as_jpc(self): + model = self.region_graph.as_jax_pc(input_units=10, sum_units=5) + print(model) + nx_model = model.to_nx() + nx_model.plot_structure() + plt.show() + + +if __name__ == '__main__': + unittest.main() From e9aae7753e90124b7f6025e868f1dc256943c670 Mon Sep 17 00:00:00 2001 From: Tom Schierenbeck Date: Fri, 20 Dec 2024 20:04:03 +0100 Subject: [PATCH 04/10] Added region graphs and conversion to jax circuits --- .../learning/region_graph/region_graph.py | 45 +------------------ 1 file changed, 1 insertion(+), 44 deletions(-) diff --git a/src/probabilistic_model/learning/region_graph/region_graph.py b/src/probabilistic_model/learning/region_graph/region_graph.py index a360d20..87f073e 100644 --- a/src/probabilistic_model/learning/region_graph/region_graph.py +++ b/src/probabilistic_model/learning/region_graph/region_graph.py @@ -110,51 +110,8 @@ def root(self) -> Region: return possible_roots[0] - def as_probabilistic_circuit(self, continuous_distribution_type: Type = GaussianDistribution, - input_units: int = 5, sum_units: int = 5) -> ProbabilisticCircuit: - model = ProbabilisticCircuit() - # create nodes for each region - for region in self.regions(): - region: Region - children = list(self.successors(region)) - parents = list(self.predecessors(region)) - - # if the region is a leaf - if len(children) == 0: - variable = region.variables[0] - if isinstance(variable, Continuous): - region.nodes = [UnivariateContinuousLeaf(GaussianDistribution(variable, 0, 1), model) - for _ in range(input_units)] - - # region is root - elif len(parents) == 0: - region.nodes = [SumUnit(model)] - - # region is in the middle - else: - region.nodes = [SumUnit(model) for _ in range(sum_units)] - - # create nodes for each partition - for partition in self.partition_nodes(): - children = list(self.successors(partition)) - parent = list(self.predecessors(partition)) - assert len(parent) == 1, "Partition should only have one parent." - parent = parent[0] - - node_lengths = [len(child.nodes) for child in children] - assert (len(set(node_lengths)) == 1), "Node lengths must be all equal. Got {}".format(node_lengths) - - for index in range(node_lengths[0]): - prod = ProductUnit(model) - for child in children: - prod.add_subcircuit(child.nodes[index], mount=False) - for node in parent.nodes: - node.add_subcircuit(prod, 1., mount=False) - - return model - - def as_jax_pc(self, continuous_distribution_type: Type = GaussianDistribution, + def as_probabilistic_circuit(self, continuous_distribution_type: Type = GaussianLayer, input_units: int = 5, sum_units: int = 5, key=jax.random.PRNGKey(69)) -> JPC: root = self.root From 332973f7c5d2524979a63a9cc5d56c658318baaf Mon Sep 17 00:00:00 2001 From: Tom Schierenbeck Date: Fri, 20 Dec 2024 20:16:51 +0100 Subject: [PATCH 05/10] Added region graphs and conversion to jax circuits --- .../probabilistic_circuit/jax/inner_layer.py | 21 +++++-------------- test/test_rat_spn/test_region_graph.py | 10 ++------- 2 files changed, 7 insertions(+), 24 deletions(-) diff --git a/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py index 4b2a72e..06da38e 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py @@ -418,25 +418,15 @@ def to_nx(self, variables: SortedSet[Variable], progress_bar: Optional[tqdm.tqdm child_layer_nx = [cl.to_nx(variables, progress_bar) for cl in self.child_layers] - clw = self.normalized_weights - csc_weights = coo_matrix((clw.data, clw.indices.T), shape=clw.shape).tocsc(copy=False) + for log_weights, child_layer in zip(self.log_weights, child_layer_nx): - # offset for shifting through the frequencies of the node_to_child_frequency_map - prev_column_index = 0 - - for child_layer in child_layer_nx: # extract the weights for the child layer - current_weight_block: csc_array = csc_weights[:, prev_column_index:prev_column_index + len(child_layer)] - current_weight_block: coo_array = current_weight_block.tocoo(False) - - for row, col, weight in zip(current_weight_block.row, current_weight_block.col, current_weight_block.data): - units[row].add_subcircuit(child_layer[col], weight) - + for ((row, col), log_weight) in zip(log_weights.indices, log_weights.data): + units[row].add_subcircuit(child_layer[col], jnp.exp(log_weight).item()) if progress_bar: progress_bar.update() - # shift the offset - prev_column_index += len(child_layer) + [unit.normalize() for unit in units] return units @@ -574,8 +564,7 @@ def to_nx(self, variables: SortedSet[Variable], progress_bar: Optional[tqdm.tqdm progress_bar.set_postfix_str(f"Parsing Product Layer of variables {variables_}") nx_pc = NXProbabilisticCircuit() - units = [ProductUnit() for _ in range(self.number_of_nodes)] - nx_pc.add_nodes_from(units) + units = [ProductUnit(nx_pc) for _ in range(self.number_of_nodes)] child_layer_nx = [cl.to_nx(variables, progress_bar) for cl in self.child_layers] diff --git a/test/test_rat_spn/test_region_graph.py b/test/test_rat_spn/test_region_graph.py index 441c247..b9c8def 100644 --- a/test/test_rat_spn/test_region_graph.py +++ b/test/test_rat_spn/test_region_graph.py @@ -23,18 +23,12 @@ class RandomRegionGraphTestCase(unittest.TestCase): def test_region_graph(self): self.assertEqual(len(self.region_graph.nodes()), len(self.region_graph.nodes())) - def test_as_pc(self): - model = self.region_graph.as_probabilistic_circuit(input_units=1, sum_units=1) - model.plot_structure() - plt.show() - jax_model = JPC.from_nx(model) - def test_as_jpc(self): - model = self.region_graph.as_jax_pc(input_units=10, sum_units=5) - print(model) + model = self.region_graph.as_probabilistic_circuit(input_units=10, sum_units=5) nx_model = model.to_nx() nx_model.plot_structure() plt.show() + print(len(list(node for node in nx_model.nodes() if isinstance(node, SumUnit)))) if __name__ == '__main__': From 43ef8c5a0ccdaff9b9d1de867bb806799d6baa80 Mon Sep 17 00:00:00 2001 From: Tom Schierenbeck Date: Sat, 21 Dec 2024 09:07:08 +0100 Subject: [PATCH 06/10] Fixed region graph to jax conversion --- .../learning/region_graph/region_graph.py | 11 +++--- .../jax/gaussian_layer.py | 8 ++--- .../probabilistic_circuit/jax/inner_layer.py | 34 ++++++++++--------- .../probabilistic_circuit/jax/input_layer.py | 7 ++-- .../jax/probabilistic_circuit.py | 4 ++- test/test_rat_spn/test_region_graph.py | 6 ++-- 6 files changed, 35 insertions(+), 35 deletions(-) diff --git a/src/probabilistic_model/learning/region_graph/region_graph.py b/src/probabilistic_model/learning/region_graph/region_graph.py index 87f073e..5ab1b1e 100644 --- a/src/probabilistic_model/learning/region_graph/region_graph.py +++ b/src/probabilistic_model/learning/region_graph/region_graph.py @@ -133,14 +133,13 @@ def as_probabilistic_circuit(self, continuous_distribution_type: Type = Gaussian else: raise NotImplementedError + # if the region is root or in the middle else: # if the region is root if len(parents) == 0: - current_sum_units = 1 - # if the region is in the middle - else: - current_sum_units = sum_units - log_weights = [BCOO.fromdense(jax.random.uniform(key, shape=(current_sum_units, child.layer.number_of_nodes), minval=0., maxval=1.)) for child in children] + sum_units = 1 + + log_weights = [BCOO.fromdense(jax.random.uniform(key, shape=(sum_units, child.layer.number_of_nodes), minval=0., maxval=1.)) for child in children] for log_weight in log_weights: log_weight.data = jnp.log(log_weight.data) node.layer = SumLayer([child.layer for child in children], log_weights=log_weights) @@ -153,7 +152,7 @@ def as_probabilistic_circuit(self, continuous_distribution_type: Type = Gaussian edges = jnp.arange(node_lengths[0]).reshape(1, -1).repeat(len(children), axis=0) sparse_edges = BCOO.fromdense(jnp.ones_like(edges)) - sparse_edges.values = edges + sparse_edges.data = edges.flatten() node.layer = ProductLayer([child.layer for child in children], sparse_edges) node.layer.validate() diff --git a/src/probabilistic_model/probabilistic_circuit/jax/gaussian_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/gaussian_layer.py index 9653587..00da4ec 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/gaussian_layer.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/gaussian_layer.py @@ -80,20 +80,18 @@ def _from_json(cls, data: Dict[str, Any]) -> Self: return cls(data["variable"], jnp.array(data["location"]), jnp.array(data["scale"]), jnp.array(data["min_scale"])) - def to_nx(self, variables: SortedSet[Variable], progress_bar: Optional[tqdm.tqdm] = None) -> List[ - Unit]: + def to_nx(self, variables: SortedSet[Variable], result: NXProbabilisticCircuit, + progress_bar: Optional[tqdm.tqdm] = None) -> List[Unit]: variable = variables[self.variable] if progress_bar: progress_bar.set_postfix_str(f"Creating Gaussian distributions for variable {variable.name}") - nx_pc = NXProbabilisticCircuit() nodes = [UnivariateContinuousLeaf( - GaussianDistribution(variable=variable, location=location.item(), scale=scale.item())) + GaussianDistribution(variable=variable, location=location.item(), scale=scale.item()),result) for location, scale in zip(self.location, self.scale)] if progress_bar: progress_bar.update(self.number_of_nodes) - nx_pc.add_nodes_from(nodes) return nodes diff --git a/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py index 06da38e..c8f6e4f 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py @@ -104,8 +104,8 @@ def nx_classes(cls) -> Tuple[Type, ...]: """ return tuple() - def to_nx(self, variables: SortedSet[Variable], progress_bar: Optional[tqdm.tqdm] = None) -> List[ - Unit]: + def to_nx(self, variables: SortedSet[Variable], result: NXProbabilisticCircuit, + progress_bar: Optional[tqdm.tqdm] = None,) -> List[Unit]: """ Convert the layer to a networkx circuit. For every node in this circuit, a corresponding node in the networkx circuit @@ -113,7 +113,9 @@ def to_nx(self, variables: SortedSet[Variable], progress_bar: Optional[tqdm.tqdm The nodes all belong to the same circuit. :param variables: The variables of the circuit. + :param result: The resulting circuit to write into :param progress_bar: A progress bar to show the progress. + :return: The nodes of the networkx circuit. """ raise NotImplementedError @@ -404,25 +406,23 @@ def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[SumUnit], sum_layer = cls([cl.layer for cl in filtered_child_layers], log_weights) return NXConverterLayer(sum_layer, nodes, result_hash_remap) - def to_nx(self, variables: SortedSet[Variable], progress_bar: Optional[tqdm.tqdm] = None) -> List[ - Unit]: + def to_nx(self, variables: SortedSet[Variable], result: NXProbabilisticCircuit, + progress_bar: Optional[tqdm.tqdm] = None) -> List[Unit]: variables_ = [variables[i] for i in self.variables] if progress_bar: progress_bar.set_postfix_str(f"Parsing Sum Layer for variables {variables_}") - nx_pc = NXProbabilisticCircuit() - units = [SumUnit() for _ in range(self.number_of_nodes)] - nx_pc.add_nodes_from(units) + units = [SumUnit(result) for _ in range(self.number_of_nodes)] - child_layer_nx = [cl.to_nx(variables, progress_bar) for cl in self.child_layers] + child_layer_nx = [cl.to_nx(variables, result, progress_bar) for cl in self.child_layers] for log_weights, child_layer in zip(self.log_weights, child_layer_nx): # extract the weights for the child layer for ((row, col), log_weight) in zip(log_weights.indices, log_weights.data): - units[row].add_subcircuit(child_layer[col], jnp.exp(log_weight).item()) + units[row].add_subcircuit(child_layer[col], jnp.exp(log_weight).item(), False) if progress_bar: progress_bar.update() @@ -556,20 +556,22 @@ def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[Unit], layer = cls([cl.layer for cl in child_layers], edges) return NXConverterLayer(layer, nodes, hash_remap) - def to_nx(self, variables: SortedSet[Variable], progress_bar: Optional[tqdm.tqdm] = None) -> List[ - Unit]: + def to_nx(self, variables: SortedSet[Variable], result: NXProbabilisticCircuit, + progress_bar: Optional[tqdm.tqdm] = None) -> List[Unit]: + + if result is None: + result = NXProbabilisticCircuit() variables_ = [variables[i] for i in self.variables] if progress_bar: progress_bar.set_postfix_str(f"Parsing Product Layer of variables {variables_}") - nx_pc = NXProbabilisticCircuit() - units = [ProductUnit(nx_pc) for _ in range(self.number_of_nodes)] - - child_layer_nx = [cl.to_nx(variables, progress_bar) for cl in self.child_layers] + units = [ProductUnit(result) for _ in range(self.number_of_nodes)] + child_layer_nx = [cl.to_nx(variables, result, progress_bar) for cl in self.child_layers] for (row, col), data in zip(self.edges.indices, self.edges.data): - units[col].add_subcircuit(child_layer_nx[row][data]) + units[col].add_subcircuit(child_layer_nx[row][data], mount=False) + if progress_bar: progress_bar.update() diff --git a/src/probabilistic_model/probabilistic_circuit/jax/input_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/input_layer.py index f47c07f..6c4040f 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/input_layer.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/input_layer.py @@ -140,17 +140,16 @@ def to_json(self) -> Dict[str, Any]: def _from_json(cls, data: Dict[str, Any]) -> Self: return cls(data["variable"], jnp.array(data["location"]), jnp.array(data["density_cap"])) - def to_nx(self, variables: SortedSet[Variable], progress_bar: Optional[tqdm.tqdm] = None) -> List[ + def to_nx(self, variables: SortedSet[Variable], result: NXProbabilisticCircuit, + progress_bar: Optional[tqdm.tqdm] = None) -> List[ Unit]: - nx_pc = NXProbabilisticCircuit() variable = variables[self.variable] if progress_bar: progress_bar.set_postfix_str(f"Creating Dirac Delta distributions for variable {variable.name}") - nodes = [UnivariateContinuousLeaf(DiracDeltaDistribution(variable, location.item(), density_cap.item())) + nodes = [UnivariateContinuousLeaf(DiracDeltaDistribution(variable, location.item(), density_cap.item()), result) for location, density_cap in zip(self.location, self.density_cap)] progress_bar.update(self.number_of_nodes) - nx_pc.add_nodes_from(nodes) return nodes diff --git a/src/probabilistic_model/probabilistic_circuit/jax/probabilistic_circuit.py b/src/probabilistic_model/probabilistic_circuit/jax/probabilistic_circuit.py index 4092d04..a36f34d 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/probabilistic_circuit.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/probabilistic_circuit.py @@ -81,7 +81,9 @@ def to_nx(self, progress_bar: bool = True) -> NXProbabilisticCircuit: progress_bar = tqdm.tqdm(total=number_of_edges, desc="Converting to nx") else: progress_bar = None - return self.root.to_nx(self.variables, progress_bar)[0].probabilistic_circuit + result = NXProbabilisticCircuit() + self.root.to_nx(self.variables, result, progress_bar) + return result def to_json(self) -> Dict[str, Any]: result = super().to_json() diff --git a/test/test_rat_spn/test_region_graph.py b/test/test_rat_spn/test_region_graph.py index b9c8def..8b7638a 100644 --- a/test/test_rat_spn/test_region_graph.py +++ b/test/test_rat_spn/test_region_graph.py @@ -21,14 +21,14 @@ class RandomRegionGraphTestCase(unittest.TestCase): region_graph = region_graph.create_random_region_graph() def test_region_graph(self): - self.assertEqual(len(self.region_graph.nodes()), len(self.region_graph.nodes())) + self.assertEqual(len(self.region_graph.nodes()), 19) def test_as_jpc(self): model = self.region_graph.as_probabilistic_circuit(input_units=10, sum_units=5) nx_model = model.to_nx() nx_model.plot_structure() - plt.show() - print(len(list(node for node in nx_model.nodes() if isinstance(node, SumUnit)))) + # plt.show() + self.assertEqual(len(list(node for node in nx_model.nodes() if isinstance(node, SumUnit))), 21) if __name__ == '__main__': From 58b17cf25bac2613fab0cf7dcf35aa3f186e2a72 Mon Sep 17 00:00:00 2001 From: Tom Schierenbeck Date: Sat, 21 Dec 2024 09:36:26 +0100 Subject: [PATCH 07/10] Finished region graph to jax conversion --- .../probabilistic_circuit/jax/inner_layer.py | 19 +++++---- test/test_rat_spn/test_region_graph.py | 39 +++++++++++++++++++ 2 files changed, 51 insertions(+), 7 deletions(-) diff --git a/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py index c8f6e4f..f57e443 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py @@ -83,7 +83,7 @@ def all_layers_with_depth(self, depth: int = 0) -> List[Tuple[int, Layer]]: """ return [(depth, self)] - @property + @cached_property @abstractmethod def variables(self) -> jax.Array: """ @@ -202,7 +202,7 @@ def __init__(self, child_layers: List[Layer]): super().__init__() self.child_layers = child_layers - @property + @cached_property @abstractmethod def variables(self) -> jax.Array: raise NotImplementedError @@ -248,7 +248,7 @@ def __init__(self, variable: int): super().__init__() self._variables = jnp.array([variable]) - @property + @cached_property def variables(self) -> jax.Array: return self._variables @@ -454,6 +454,8 @@ class ProductLayer(InnerLayer): The shape is (#child_layers, #nodes). """ + _variables: Optional[jnp.array] = None + def __init__(self, child_layers: List[Layer], edges: BCOO): """ Initialize the product layer. @@ -463,6 +465,7 @@ def __init__(self, child_layers: List[Layer], edges: BCOO): """ super().__init__(child_layers) self.edges = edges + self.variables def validate(self): assert self.edges.shape == (len(self.child_layers), self.number_of_nodes), \ @@ -481,11 +484,13 @@ def nx_classes(cls) -> Tuple[Type, ...]: def number_of_components(self) -> int: return sum([cl.number_of_components for cl in self.child_layers]) + self.edges.nse - @cached_property + @property def variables(self) -> jax.Array: - variables = jnp.concatenate([child_layer.variables for child_layer in self.child_layers]) - variables = jnp.unique(variables) - return variables + if self._variables is None: + variables = jnp.concatenate([child_layer.variables for child_layer in self.child_layers]) + variables = jnp.unique(variables) + self._variables = variables + return self._variables def log_likelihood_of_nodes_single(self, x: jax.Array) -> jax.Array: result = jnp.zeros(self.number_of_nodes, dtype=jnp.float32) diff --git a/test/test_rat_spn/test_region_graph.py b/test/test_rat_spn/test_region_graph.py index 8b7638a..4e902a7 100644 --- a/test/test_rat_spn/test_region_graph.py +++ b/test/test_rat_spn/test_region_graph.py @@ -1,6 +1,7 @@ import random import unittest +from jax import tree_flatten from matplotlib import pyplot as plt from networkx.drawing.nx_agraph import graphviz_layout from random_events.variable import Continuous @@ -9,6 +10,10 @@ from probabilistic_model.probabilistic_circuit.jax.probabilistic_circuit import ProbabilisticCircuit as JPC from probabilistic_model.probabilistic_circuit.jax.gaussian_layer import GaussianLayer import numpy as np +import equinox as eqx +import optax +import tqdm +import plotly.graph_objects as go np.random.seed(420) random.seed(420) @@ -31,5 +36,39 @@ def test_as_jpc(self): self.assertEqual(len(list(node for node in nx_model.nodes() if isinstance(node, SumUnit))), 21) + +class RandomRegionGraphLearningTestCase(unittest.TestCase): + + variables = SortedSet([Continuous(str(i)) for i in range(4)]) + region_graph = RegionGraph(variables, partitions=2, depth=1, repetitions=2) + region_graph = region_graph.create_random_region_graph() + + def test_learning(self): + data = np.random.uniform(0, 1, (10000, len(self.variables))) + model = self.region_graph.as_probabilistic_circuit(input_units=5, sum_units=5) + + root = model.root + + @eqx.filter_jit + def loss(model, x): + ll = model.log_likelihood_of_nodes(x) + return -jnp.mean(ll) + + optim = optax.adamw(0.01) + opt_state = optim.init(eqx.filter(root, eqx.is_inexact_array)) + + for _ in tqdm.trange(50): + loss_value, grads = eqx.filter_value_and_grad(loss)(root, data) + grads_of_sum_layer = eqx.filter(tree_flatten(grads), eqx.is_inexact_array)[0][0] + self.assertTrue(jnp.all(jnp.isfinite(grads_of_sum_layer))) + + updates, opt_state = optim.update( + grads, opt_state, eqx.filter(root, eqx.is_inexact_array) + ) + root = eqx.apply_updates(root, updates) + + + + if __name__ == '__main__': unittest.main() From ee5b7210813a05766064814de8df822bb4fb530f Mon Sep 17 00:00:00 2001 From: Tom Schierenbeck Date: Sat, 21 Dec 2024 09:38:59 +0100 Subject: [PATCH 08/10] Fixed bug in uniform layer --- .../probabilistic_circuit/jax/inner_layer.py | 5 ----- .../probabilistic_circuit/jax/uniform_layer.py | 9 ++++----- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py index f57e443..c32f899 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py @@ -374,11 +374,6 @@ def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[SumUnit], number_of_nodes = len(nodes) # filter the child layers to only contain layers with the same scope as this one - print(variables) - print([child_layer.layer.variables for child_layer in child_layers]) - print(len(child_layers)) - print([n.variables for n in child_layers[0].nodes]) - print([n.variables for n in child_layers[1].nodes]) filtered_child_layers = [child_layer for child_layer in child_layers if (child_layer.layer.variables == variables).all()] log_weights = [] diff --git a/src/probabilistic_model/probabilistic_circuit/jax/uniform_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/uniform_layer.py index d392a7b..8ed4c33 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/uniform_layer.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/uniform_layer.py @@ -107,23 +107,22 @@ def merge_with(self, others: List[Self]) -> Self: def _from_json(cls, data: Dict[str, Any]) -> Self: return cls(data["variable"], jnp.array(data["interval"])) - def to_nx(self, variables: SortedSet[Variable], progress_bar: Optional[tqdm.tqdm] = None) -> List[ - Unit]: + def to_nx(self, variables: SortedSet[Variable], result: NXProbabilisticCircuit, + progress_bar: Optional[tqdm.tqdm] = None) -> List[Unit]: variable = variables[self.variable] if progress_bar: progress_bar.set_postfix_str(f"Creating Uniform distributions for variable {variable.name}") - nx_pc = NXProbabilisticCircuit() nodes = [UnivariateContinuousLeaf( UniformDistribution(variable=variable, interval=random_events.interval.SimpleInterval(lower.item(), upper.item(), random_events.interval.Bound.OPEN, - random_events.interval.Bound.OPEN))) + random_events.interval.Bound.OPEN)), + result) for lower, upper in self.interval] if progress_bar: progress_bar.update(self.number_of_nodes) - nx_pc.add_nodes_from(nodes) return nodes From 12646bcffd98caabb5d3d4e1aa78898b6a9fef8a Mon Sep 17 00:00:00 2001 From: Tom Schierenbeck Date: Sat, 21 Dec 2024 11:50:32 +0100 Subject: [PATCH 09/10] Removed dead code in uniform_layer.py --- .../jax/uniform_layer.py | 41 ------------------- 1 file changed, 41 deletions(-) diff --git a/src/probabilistic_model/probabilistic_circuit/jax/uniform_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/uniform_layer.py index 8ed4c33..12f83e1 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/uniform_layer.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/uniform_layer.py @@ -62,47 +62,6 @@ def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[Univariate result = cls(nodes[0].probabilistic_circuit.variables.index(variable), intervals) return NXConverterLayer(result, nodes, hash_remap) - def sample_from_frequencies(self, frequencies: np.array, result: np.array, start_index=0): - # sample from U(0,1) - standard_uniform_samples = np.random.uniform(size=(sum(frequencies), 1)) - - # calculate range for each node - range_per_sample = (self.upper - self.lower).repeat(frequencies).reshape(-1, 1) - - # calculate the right shift for each node - right_shift_per_sample = self.lower.repeat(frequencies).reshape(-1, 1) - - # apply the transformation to the desired intervals - samples = standard_uniform_samples * range_per_sample + right_shift_per_sample - - result[start_index:start_index + len(samples), self.variables] = samples - - def cdf_of_nodes_single(self, x: jnp.array) -> jnp.array: - return jnp.clip((x - self.lower) / (self.upper - self.lower), 0, 1) - - def moment_of_nodes(self, order: jax.Array, center: jax.Array): - """ - Calculate the moment of the uniform distribution. - """ - order = order[self.variables[0]] - center = center[self.variables[0]] - pdf_value = jnp.exp(self.log_pdf_value()) - lower_integral_value = (pdf_value * (self.lower - center) ** (order + 1)) / (order + 1) - upper_integral_value = (pdf_value * (self.upper - center) ** (order + 1)) / (order + 1) - return (upper_integral_value - lower_integral_value).reshape(-1, 1) - - def log_conditional_from_simple_interval(self, interval: SimpleInterval) -> Tuple[Self, jax.Array]: - probabilities = jnp.log(self.probability_of_simple_interval(interval)) - open_interval_array = simple_interval_to_open_array(interval) - new_lowers = jnp.maximum(self.lower, open_interval_array[0]) - new_uppers = jnp.minimum(self.upper, open_interval_array[1]) - valid_intervals = new_lowers < new_uppers - new_intervals = jnp.stack([new_lowers[valid_intervals], new_uppers[valid_intervals]]).T - return self.__class__(self.variable, new_intervals), probabilities - - def merge_with(self, others: List[Self]) -> Self: - return self.__class__(self.variable, jnp.vstack([self.interval] + [other.interval for other in others])) - @classmethod def _from_json(cls, data: Dict[str, Any]) -> Self: return cls(data["variable"], jnp.array(data["interval"])) From 348eaff3519b6327c95fe1b0b523154f89a0dbbc Mon Sep 17 00:00:00 2001 From: Tom Schierenbeck Date: Sat, 21 Dec 2024 12:39:37 +0100 Subject: [PATCH 10/10] Adeded discrete_layer.py --- .../jax/discrete_layer.py | 112 ++++++++++++++++++ .../jax/gaussian_layer.py | 2 +- test/test_jax/test_discrete_layer.py | 71 +++++++++++ 3 files changed, 184 insertions(+), 1 deletion(-) create mode 100644 src/probabilistic_model/probabilistic_circuit/jax/discrete_layer.py create mode 100644 test/test_jax/test_discrete_layer.py diff --git a/src/probabilistic_model/probabilistic_circuit/jax/discrete_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/discrete_layer.py new file mode 100644 index 0000000..1840f61 --- /dev/null +++ b/src/probabilistic_model/probabilistic_circuit/jax/discrete_layer.py @@ -0,0 +1,112 @@ +from typing import Tuple, Type, List, Dict, Any + +from random_events.set import SetElement +from random_events.variable import Symbolic, Variable +from sortedcontainers import SortedSet +from typing_extensions import Self, Optional + +from . import NXConverterLayer +from .inner_layer import InputLayer +import jax.numpy as jnp + +from ..nx.distributions import UnivariateDiscreteLeaf +from ..nx.probabilistic_circuit import Unit +from ...distributions import SymbolicDistribution +import tqdm +import numpy as np +from ..nx.probabilistic_circuit import ProbabilisticCircuit as NXProbabilisticCircuit + +from ...utils import MissingDict + + +class DiscreteLayer(InputLayer): + + log_probabilities: jnp.array + """ + The logarithm of probability for each state of the variable. + + The shape is (#nodes, #states). + """ + + def __init__(self, variable: int, log_probabilities: jnp.array): + super().__init__(variable) + self.log_probabilities = log_probabilities + + @classmethod + def nx_classes(cls) -> Tuple[Type, ...]: + return SymbolicDistribution, + + def validate(self): + return True + + @property + def normalization_constant(self) -> jnp.array: + return jnp.exp(self.log_probabilities).sum(1) + + @property + def log_normalization_constant(self) -> jnp.array: + return jnp.log(self.normalization_constant) + + @property + def normalized_log_probabilities(self) -> jnp.array: + return self.log_probabilities - self.log_normalization_constant[:, None] + + @property + def number_of_nodes(self) -> int: + return self.log_probabilities.shape[0] + + def log_likelihood_of_nodes_single(self, x: jnp.array) -> jnp.array: + return self.log_probabilities.T[x.astype(int)] - self.log_normalization_constant + + + @classmethod + def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[UnivariateDiscreteLeaf], + child_layers: List[NXConverterLayer], + progress_bar: bool = True) -> \ + NXConverterLayer: + hash_remap = {hash(node): index for index, node in enumerate(nodes)} + + variable: Symbolic = nodes[0].variable + + parameters = np.zeros((len(nodes), len(variable.domain.simple_sets))) + + for node in (tqdm.tqdm(nodes, desc=f"Creating discrete layer for variable {variable.name}") + if progress_bar else nodes): + for state, value in node.distribution.probabilities.items(): + parameters[hash_remap[hash(node)], state] = value + + + result = cls(nodes[0].probabilistic_circuit.variables.index(variable), jnp.log(parameters)) + return NXConverterLayer(result, nodes, hash_remap) + + def to_json(self) -> Dict[str, Any]: + return {**super().to_json(), + "variable": self.variable, "log_probabilities": self.log_probabilities.tolist()} + + @classmethod + def _from_json(cls, data: Dict[str, Any]) -> Self: + return cls(data["variable"], jnp.array(data["log_probabilities"])) + + def to_nx(self, variables: SortedSet[Variable], result: NXProbabilisticCircuit, + progress_bar: Optional[tqdm.tqdm] = None) -> List[Unit]: + + variable = variables[self.variable] + domain = variable.domain_type() + + if progress_bar: + progress_bar.set_postfix_str(f"Creating discrete distributions for variable {variable.name}") + + nodes = [UnivariateDiscreteLeaf(SymbolicDistribution(variable, MissingDict(float, + {domain(state): value for state, value in enumerate(jnp.exp(log_probabilities))})), result) + for log_probabilities in self.normalized_log_probabilities] + + if progress_bar: + progress_bar.update(self.number_of_nodes) + return nodes + + + + + + + diff --git a/src/probabilistic_model/probabilistic_circuit/jax/gaussian_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/gaussian_layer.py index 00da4ec..8388e90 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/gaussian_layer.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/gaussian_layer.py @@ -37,7 +37,7 @@ def nx_classes(cls) -> Tuple[Type, ...]: def validate(self): assert self.location.shape == self.log_scale.shape, "The shapes of location and scale must match." assert self.min_scale.shape == self.log_scale.shape, "The shapes of the min_scale and scale bounds must match." - assert jnp.all(self.min_scale >= 0), "The scale must be positive." + assert jnp.all(self.min_scale >= 0), "The minimum scale must be positive." @property def number_of_nodes(self) -> int: diff --git a/test/test_jax/test_discrete_layer.py b/test/test_jax/test_discrete_layer.py new file mode 100644 index 0000000..9e976e6 --- /dev/null +++ b/test/test_jax/test_discrete_layer.py @@ -0,0 +1,71 @@ +import unittest + +from random_events.set import SetElement +from random_events.variable import Continuous, Symbolic +from sortedcontainers import SortedSet + +from probabilistic_model.distributions import SymbolicDistribution +from probabilistic_model.probabilistic_circuit.jax.discrete_layer import DiscreteLayer +from probabilistic_model.probabilistic_circuit.jax.gaussian_layer import GaussianLayer, GaussianDistribution +from probabilistic_model.probabilistic_circuit.nx.probabilistic_circuit import \ + ProbabilisticCircuit as NXProbabilisticCircuit, SumUnit +from probabilistic_model.probabilistic_circuit.nx.distributions import UnivariateContinuousLeaf, UnivariateDiscreteLeaf +from probabilistic_model.probabilistic_circuit.jax.probabilistic_circuit import ProbabilisticCircuit +import jax.numpy as jnp + +from probabilistic_model.utils import MissingDict + + +class Animal(SetElement): + EMPTY_SET = -1 + CAT = 0 + DOG = 1 + FISH = 2 + +class DiscreteLayerTestCase(unittest.TestCase): + + model: DiscreteLayer + x = Symbolic("x", Animal) + + @classmethod + def setUpClass(cls): + cls.model = DiscreteLayer(0, jnp.log(jnp.array([[0, 1, 2], [3, 4, 0]]))) + cls.model.validate() + + def test_normalization(self): + result = self.model.normalization_constant + correct = jnp.array([3., 7.]) + self.assertTrue(jnp.allclose(result, correct, atol=1e-3)) + + def test_log_likelihood(self): + x = jnp.array([0.0]) + result = self.model.log_likelihood_of_nodes_single(x) + correct = jnp.log(jnp.array([.0, 3/7])) + self.assertTrue(jnp.allclose(result, correct, atol=1e-3)) + + def test_from_nx(self): + + p1 = MissingDict(float, {Animal.CAT: 0., Animal.DOG: 1, Animal.FISH: 2}) + d1 = UnivariateDiscreteLeaf(SymbolicDistribution(self.x, p1)) + + p2 = MissingDict(float, {Animal.CAT: 3, Animal.DOG: 4, Animal.FISH: 0}) + d2 = UnivariateDiscreteLeaf(SymbolicDistribution(self.x, p2)) + s = SumUnit() + s.add_subcircuit(d1, 0.5) + s.add_subcircuit(d2, 0.5) + nx_pc = s.probabilistic_circuit + jax_pc = ProbabilisticCircuit.from_nx(nx_pc) + discrete_layer = jax_pc.root.child_layers[0] + self.assertIsInstance(discrete_layer, DiscreteLayer) + self.assertEqual(discrete_layer.variable, 0) + self.assertEqual(discrete_layer.log_probabilities.shape, (2, 3)) + self.assertTrue(jnp.allclose(discrete_layer.log_probabilities, self.model.log_probabilities)) + + def test_to_nx(self): + nx_circuit = self.model.to_nx(SortedSet([self.x]), NXProbabilisticCircuit())[0].probabilistic_circuit + self.assertEqual(len(nx_circuit.nodes()), 2) + self.assertEqual(len(nx_circuit.edges()), 0) + + +if __name__ == '__main__': + unittest.main()