-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #13 from tomsch420/jax-dev
Jax dev
- Loading branch information
Showing
20 changed files
with
482 additions
and
904 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
160 changes: 160 additions & 0 deletions
160
src/probabilistic_model/learning/region_graph/region_graph.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
from collections import deque | ||
|
||
import networkx as nx | ||
import numpy as np | ||
from jax.experimental.sparse import BCOO | ||
from random_events.variable import Continuous | ||
from sortedcontainers import SortedSet | ||
from typing_extensions import List, Self, Type | ||
|
||
from ...distributions import GaussianDistribution | ||
from ...probabilistic_circuit.jax import SumLayer, ProductLayer | ||
from ...probabilistic_circuit.jax.gaussian_layer import GaussianLayer | ||
from ...probabilistic_circuit.nx.probabilistic_circuit import ProbabilisticCircuit, SumUnit, ProductUnit | ||
from ...probabilistic_circuit.nx.distributions.distributions import UnivariateContinuousLeaf | ||
from ...probabilistic_circuit.jax.probabilistic_circuit import ProbabilisticCircuit as JPC | ||
import jax.numpy as jnp | ||
import jax.random | ||
|
||
|
||
class Region:
    """
    A node of the region graph that holds a set of variables.

    No ``__eq__`` is defined, so equality is the default identity comparison;
    ``__hash__`` is identity based as well. Two regions over the same
    variables are therefore distinct graph nodes.
    """

    variables: SortedSet

    def __init__(self, variables: SortedSet):
        """
        :param variables: The variables covered by this region.
        """
        self.variables = variables

    def __hash__(self) -> int:
        return id(self)

    def random_partition(self, k=2) -> List[Self]:
        """
        Randomly split the variables of this region into disjoint sub-regions.

        The variable indices are shuffled with ``np.random.shuffle`` and cut
        into ``k`` chunks with ``np.array_split``.

        :param k: The number of chunks to cut the shuffled indices into.
        :return: The non-empty sub-regions. This list is shorter than ``k``
            if the region has fewer than ``k`` variables.
        """
        indices = np.arange(len(self.variables))
        np.random.shuffle(indices)
        # np.array_split pads with empty chunks when k > len(indices); a region
        # without variables is meaningless, so those chunks are dropped.
        return [Region(SortedSet([self.variables[index] for index in split]))
                for split in np.array_split(indices, k) if len(split) > 0]

    def __repr__(self) -> str:
        return "{" + ", ".join(v.name for v in self.variables) + "}"
|
||
class Partition:
    """
    A node of the region graph that groups the sub-regions of one split.

    Partitions carry no data of their own; they are hashed by identity so that
    every instance is a distinct graph node.
    """

    def __hash__(self) -> int:
        return id(self)

    def __repr__(self) -> str:
        # Short, identity based representation so that partitions are
        # distinguishable when graph nodes are printed in error messages.
        return f"Partition({id(self):#x})"
|
||
|
||
class RegionGraph(nx.DiGraph):
    """
    A directed graph whose nodes are :class:`Region` and :class:`Partition`
    objects.

    Edges created by this class go from a region to a partition of it and from
    a partition to its sub-regions.
    """

    variables: SortedSet

    def __init__(self, variables: SortedSet,
                 partitions: int = 2,
                 depth: int = 2,
                 repetitions: int = 2):
        """
        :param variables: The variables of the root region.
        :param partitions: The number of sub-regions requested per split.
        :param depth: The number of random splits before a region is
            decomposed into single-variable regions.
        :param repetitions: The number of independent splits of the root.
        """
        super().__init__()
        self.variables = variables
        self.partitions = partitions
        self.depth = depth
        self.repetitions = repetitions

    def create_random_region_graph(self):
        """
        Add a root region over all variables and split it ``repetitions``
        times.

        :return: This graph.
        """
        root = Region(self.variables)
        self.add_node(root)
        for _ in range(self.repetitions):
            self.recursive_split(root)
        return self

    def regions(self):
        """
        :return: A generator over all nodes that are regions.
        """
        for node in self.nodes:
            if isinstance(node, Region):
                yield node

    def partition_nodes(self):
        """
        :return: A generator over all nodes that are partitions.
        """
        for node in self.nodes:
            if isinstance(node, Partition):
                yield node

    def recursive_split(self, node: Region):
        """
        Split ``node`` breadth first until the depth budget is exhausted.

        Regions with a remaining depth of 0 are decomposed into one region per
        variable. Regions with a single variable are leaves and are not split
        any further.

        :param node: The region to split.
        """
        # A region with at most one variable cannot be split. Attaching a
        # partition to it would create a partition without children.
        if len(node.variables) <= 1:
            return

        root_partition = Partition()
        self.add_edge(node, root_partition)
        remaining_regions = deque([(node, self.depth, root_partition)])

        while remaining_regions:
            region, depth, partition = remaining_regions.popleft()

            if depth == 0:
                for variable in region.variables:
                    self.add_edge(partition, Region(SortedSet([variable])))
                continue

            for new_region in region.random_partition(self.partitions):
                # Skip empty chunks that arise when there are fewer variables
                # than requested partitions.
                if not new_region.variables:
                    continue
                self.add_edge(partition, new_region)

                # Single-variable regions stay leaves. Only regions that can
                # still be split get a partition of their own, so that no
                # partition is left without children.
                if len(new_region.variables) == 1:
                    continue
                new_partition = Partition()
                self.add_edge(new_region, new_partition)
                remaining_regions.append((new_region, depth - 1, new_partition))

    @property
    def root(self) -> Region:
        """
        The root of the circuit is the node with in-degree 0.
        This is the output node, that will perform the final computation.

        :return: The root of the circuit.
        :raises ValueError: If more than one node has in-degree 0.
        :raises IndexError: If no node has in-degree 0, e.g. for an empty graph.
        """
        possible_roots = [node for node in self.nodes() if self.in_degree(node) == 0]
        if len(possible_roots) > 1:
            raise ValueError(f"More than one root found. Possible roots are {possible_roots}")

        return possible_roots[0]

    def as_probabilistic_circuit(self, continuous_distribution_type: Type = GaussianLayer,
                                 input_units: int = 5, sum_units: int = 5, key=None) -> JPC:
        """
        Create a randomly initialized, layered circuit from this graph.

        Nodes are processed bottom-up (reversed BFS layers from the root), and
        the created layer is stored on each node as ``node.layer``. Leaf
        regions become a :class:`GaussianLayer`, inner regions a
        :class:`SumLayer` and partitions a :class:`ProductLayer`.

        :param continuous_distribution_type: Currently unused; leaves over
            continuous variables are always created as ``GaussianLayer``.
        :param input_units: The number of nodes in every leaf layer.
        :param sum_units: The number of nodes in every sum layer except the
            root, which always has one node.
        :param key: The JAX PRNG key used for the random parameters.
            Defaults to ``jax.random.PRNGKey(69)``.
        :return: The circuit over ``self.variables`` rooted in the root layer.
        :raises NotImplementedError: If a leaf variable is not ``Continuous``.
        """
        # Create the default key lazily instead of at class-definition time.
        if key is None:
            key = jax.random.PRNGKey(69)

        root = self.root

        # create nodes for each region
        for layer in reversed(list(nx.bfs_layers(self, root))):
            for node in layer:
                children = list(self.successors(node))

                if isinstance(node, Region):
                    # if the region is a leaf
                    if len(children) == 0:
                        variable = node.variables[0]
                        variable_index = self.variables.index(variable)
                        if not isinstance(variable, Continuous):
                            raise NotImplementedError

                        # Use a fresh subkey per draw. Reusing one key would
                        # give every leaf the same parameters.
                        key, location_key, scale_key = jax.random.split(key, 3)
                        location = jax.random.uniform(location_key, shape=(input_units,),
                                                      minval=-1., maxval=1.)
                        log_scale = jnp.log(jax.random.uniform(scale_key, shape=(input_units,),
                                                               minval=0.01, maxval=1.))
                        node.layer = GaussianLayer(variable_index, location=location,
                                                   log_scale=log_scale,
                                                   min_scale=jnp.full_like(location, 0.01))
                        node.layer.validate()

                    # if the region is root or in the middle
                    else:
                        # the root region gets exactly one output unit
                        units = 1 if self.in_degree(node) == 0 else sum_units

                        key, *weight_keys = jax.random.split(key, len(children) + 1)
                        log_weights = [BCOO.fromdense(jax.random.uniform(
                            weight_key, shape=(units, child.layer.number_of_nodes),
                            minval=0., maxval=1.))
                            for weight_key, child in zip(weight_keys, children)]
                        for log_weight in log_weights:
                            log_weight.data = jnp.log(log_weight.data)
                        node.layer = SumLayer([child.layer for child in children],
                                              log_weights=log_weights)
                        node.layer.validate()

                elif isinstance(node, Partition):
                    node_lengths = [child.layer.number_of_nodes for child in children]
                    assert (len(set(node_lengths)) == 1), "Node lengths must be all equal. Got {}".format(node_lengths)

                    # Row i holds the node indices 0..n-1 of child i.
                    edges = jnp.arange(node_lengths[0]).reshape(1, -1).repeat(len(children), axis=0)
                    # Build the sparse structure from a dense array of ones,
                    # then overwrite its data with the node indices.
                    sparse_edges = BCOO.fromdense(jnp.ones_like(edges))
                    sparse_edges.data = edges.flatten()
                    node.layer = ProductLayer([child.layer for child in children], sparse_edges)
                    node.layer.validate()

        model = JPC(self.variables, root.layer)
        return model
112 changes: 112 additions & 0 deletions
112
src/probabilistic_model/probabilistic_circuit/jax/discrete_layer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
from typing import Tuple, Type, List, Dict, Any | ||
|
||
from random_events.set import SetElement | ||
from random_events.variable import Symbolic, Variable | ||
from sortedcontainers import SortedSet | ||
from typing_extensions import Self, Optional | ||
|
||
from . import NXConverterLayer | ||
from .inner_layer import InputLayer | ||
import jax.numpy as jnp | ||
|
||
from ..nx.distributions import UnivariateDiscreteLeaf | ||
from ..nx.probabilistic_circuit import Unit | ||
from ...distributions import SymbolicDistribution | ||
import tqdm | ||
import numpy as np | ||
from ..nx.probabilistic_circuit import ProbabilisticCircuit as NXProbabilisticCircuit | ||
|
||
from ...utils import MissingDict | ||
|
||
|
||
class DiscreteLayer(InputLayer):
    """
    Input layer that stores one table of (unnormalized) log-probabilities per
    node for a single discrete variable.
    """

    log_probabilities: jnp.array
    """
    The logarithm of probability for each state of the variable.
    The shape is (#nodes, #states).
    """

    def __init__(self, variable: int, log_probabilities: jnp.array):
        """
        :param variable: The index of the variable this layer is defined over.
        :param log_probabilities: The log-probabilities with shape
            (#nodes, #states).
        """
        super().__init__(variable)
        self.log_probabilities = log_probabilities

    @classmethod
    def nx_classes(cls) -> Tuple[Type, ...]:
        return (SymbolicDistribution,)

    def validate(self):
        return True

    @property
    def normalization_constant(self) -> jnp.array:
        """
        :return: The sum of the probabilities of every node, shape (#nodes,).
        """
        return jnp.exp(self.log_probabilities).sum(1)

    @property
    def log_normalization_constant(self) -> jnp.array:
        """
        :return: The logarithm of the normalization constant of every node,
            shape (#nodes,).
        """
        # Imported locally since this module only imports ``jax.numpy``.
        from jax.scipy.special import logsumexp

        # logsumexp avoids the underflow of exp() for very negative
        # log-probabilities, which would turn the constant into log(0) = -inf.
        return logsumexp(self.log_probabilities, axis=1)

    @property
    def normalized_log_probabilities(self) -> jnp.array:
        """
        :return: The log-probabilities with every row shifted by its
            log-normalization constant, shape (#nodes, #states).
        """
        return self.log_probabilities - self.log_normalization_constant[:, None]

    @property
    def number_of_nodes(self) -> int:
        return self.log_probabilities.shape[0]

    def log_likelihood_of_nodes_single(self, x: jnp.array) -> jnp.array:
        """
        :param x: The state to evaluate; it is cast to int and used as an
            index into the state axis.
        :return: The normalized log-likelihood of ``x`` for every node.
        """
        return self.log_probabilities.T[x.astype(int)] - self.log_normalization_constant

    @classmethod
    def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[UnivariateDiscreteLeaf],
                                                         child_layers: List[NXConverterLayer],
                                                         progress_bar: bool = True) -> \
            NXConverterLayer:
        """
        Create one layer from leaves that share the same variable.

        :param nodes: The leaves to merge; the variable is taken from the first.
        :param child_layers: Unused by this implementation.
        :param progress_bar: Whether to show a tqdm progress bar.
        :return: The converter layer wrapping the created layer, the nodes and
            the mapping from node hash to row index.
        """
        hash_remap = {hash(node): index for index, node in enumerate(nodes)}

        variable: Symbolic = nodes[0].variable

        # States that a node does not mention keep probability 0.
        parameters = np.zeros((len(nodes), len(variable.domain.simple_sets)))

        for node in (tqdm.tqdm(nodes, desc=f"Creating discrete layer for variable {variable.name}")
                     if progress_bar else nodes):
            row = hash_remap[hash(node)]
            for state, value in node.distribution.probabilities.items():
                parameters[row, state] = value

        result = cls(nodes[0].probabilistic_circuit.variables.index(variable), jnp.log(parameters))
        return NXConverterLayer(result, nodes, hash_remap)

    def to_json(self) -> Dict[str, Any]:
        return {**super().to_json(),
                "variable": self.variable, "log_probabilities": self.log_probabilities.tolist()}

    @classmethod
    def _from_json(cls, data: Dict[str, Any]) -> Self:
        return cls(data["variable"], jnp.array(data["log_probabilities"]))

    def to_nx(self, variables: SortedSet[Variable], result: NXProbabilisticCircuit,
              progress_bar: Optional[tqdm.tqdm] = None) -> List[Unit]:
        """
        Create one leaf with a normalized distribution per node of this layer.

        :param variables: The variables that ``self.variable`` indexes into.
        :param result: The circuit that is passed to every created leaf.
        :param progress_bar: An optional progress bar to report to.
        :return: The created leaves, one per node.
        """
        variable = variables[self.variable]
        domain = variable.domain_type()

        # Compare against None: the truth value of a tqdm bar depends on its
        # total/iterable and can be False or even raise a TypeError.
        if progress_bar is not None:
            progress_bar.set_postfix_str(f"Creating discrete distributions for variable {variable.name}")

        nodes = []
        for log_probabilities in self.normalized_log_probabilities:
            probabilities = MissingDict(float, {domain(state): value for state, value
                                                in enumerate(jnp.exp(log_probabilities))})
            nodes.append(UnivariateDiscreteLeaf(SymbolicDistribution(variable, probabilities), result))

        if progress_bar is not None:
            progress_bar.update(self.number_of_nodes)
        return nodes
|
||
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.