Skip to content

Commit

Permalink
Merge pull request #13 from tomsch420/jax-dev
Browse files Browse the repository at this point in the history
Jax dev
  • Loading branch information
tomsch420 authored Dec 21, 2024
2 parents 5714067 + 348eaff commit bf02e39
Show file tree
Hide file tree
Showing 20 changed files with 482 additions and 904 deletions.
2 changes: 1 addition & 1 deletion doc/pendulum.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ purely knowledge graph driven AI.
A full swing takes a couple of years.
Frank argued that there needs to be a middle ground where machine learning and knowledge graphs come together.

I belive that the implementation of probabilistic models in this package is capable of doing so.
I believe that the implementation of probabilistic models in this package is capable of doing so.
Knowledge graphs generate sets that describe possible assignments
that match the constraints and instance knowledge of ontologies (a random event, so to say).
Probability distributions describe the likelihoods of every possible solution.
Expand Down
2 changes: 1 addition & 1 deletion scripts/gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
number_of_variables = 2
number_of_samples_per_component = 100000
number_of_components = 2
number_of_mixtures = 100
number_of_mixtures = 1000
number_of_iterations = 1000

# model selection
Expand Down
2 changes: 1 addition & 1 deletion scripts/jpt_speed_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def timed_jax_method():
# ll = jax_model.log_likelihood(samples)
# assert (ll > -jnp.inf).all()

times_nx, times_jax = eval_performance(nx_model.log_likelihood, (data, ), compiled_ll_jax, (data_jax, ), 10, 5)
times_nx, times_jax = eval_performance(nx_model.log_likelihood, (data, ), compiled_ll_jax, (data_jax, ), 15, 10)
# times_nx, times_jax = eval_performance(nx_model.probability_of_simple_event, (event,), jax_model.probability_of_simple_event, (event,), 20, 10)
# times_nx, times_jax = eval_performance(nx_model.sample, (10000, ), jax_model.sample, (10000,), 10, 5)

Expand Down
4 changes: 2 additions & 2 deletions scripts/nyga_speed_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
path_prefix = os.path.join(os.path.expanduser("~"), "Documents")
nx_model_path = os.path.join(path_prefix, "nx_nyga.pm")
jax_model_path = os.path.join(path_prefix, "jax_nyga.pm")
load_from_disc = True
load_from_disc = False
save_to_disc = True


Expand Down Expand Up @@ -103,7 +103,7 @@ def timed_jax_method():

# times_nx, times_jax = eval_performance(nx_model.log_likelihood, (data, ), compiled_ll_jax, (data_jax, ), 20, 2)
# times_nx, times_jax = eval_performance(prob_nx, event, prob_jax, event, 15, 10)
times_nx, times_jax = eval_performance(nx_model.sample, (1000, ), jax_model.sample, (1000, ), 10, 5)
times_nx, times_jax = eval_performance(nx_model.sample, (1000, ), jax_model.sample, (1000, ), 5, 10)

time_jax = np.mean(times_jax), np.std(times_jax)
time_nx = np.mean(times_nx), np.std(times_nx)
Expand Down
Empty file.
160 changes: 160 additions & 0 deletions src/probabilistic_model/learning/region_graph/region_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from collections import deque

import networkx as nx
import numpy as np
from jax.experimental.sparse import BCOO
from random_events.variable import Continuous
from sortedcontainers import SortedSet
from typing_extensions import List, Self, Type

from ...distributions import GaussianDistribution
from ...probabilistic_circuit.jax import SumLayer, ProductLayer
from ...probabilistic_circuit.jax.gaussian_layer import GaussianLayer
from ...probabilistic_circuit.nx.probabilistic_circuit import ProbabilisticCircuit, SumUnit, ProductUnit
from ...probabilistic_circuit.nx.distributions.distributions import UnivariateContinuousLeaf
from ...probabilistic_circuit.jax.probabilistic_circuit import ProbabilisticCircuit as JPC
import jax.numpy as jnp
import jax.random


class Region:
    """
    A node of a region graph that holds a set of variables.

    Regions are hashed by their identity, so two regions over the same variables
    are distinct nodes when they are inserted into a graph.
    """

    variables: SortedSet
    # The variables that are covered by this region.

    def __init__(self, variables: SortedSet):
        """
        :param variables: The variables that are covered by this region.
        """
        self.variables = variables

    def __hash__(self) -> int:
        return id(self)

    def random_partition(self, k=2) -> List[Self]:
        """
        Randomly split the variables of this region into disjoint, non-empty regions.

        :param k: The number of parts to split the variables into.
        :return: The new regions. Fewer than `k` regions are returned if this region
            has fewer than `k` variables.
        """
        indices = np.arange(len(self.variables))
        np.random.shuffle(indices)

        # np.array_split pads the result with empty chunks if k exceeds the number of
        # variables. A region without variables cannot be turned into a layer, so
        # those chunks are dropped.
        return [Region(SortedSet([self.variables[index] for index in split]))
                for split in np.array_split(indices, k) if len(split) > 0]

    def __repr__(self) -> str:
        return "{" + ", ".join(v.name for v in self.variables) + "}"

class Partition:
    """
    A node of a region graph that groups the regions created by one split.

    A partition carries no data of its own and is hashed by its identity.
    """

    def __hash__(self) -> int:
        # Every instance is its own graph node, hence the identity based hash.
        identity = id(self)
        return identity


class RegionGraph(nx.DiGraph):
    """
    A directed graph of alternating :class:`Region` and :class:`Partition` nodes.

    Edges point from a region to the partitions that decompose it and from a partition
    to the regions it consists of.
    """

    variables: SortedSet
    # The variables of the root region.

    def __init__(self, variables: SortedSet,
                 partitions: int = 2,
                 depth: int = 2,
                 repetitions: int = 2):
        """
        :param variables: The variables of the root region.
        :param partitions: The number of parts a region is randomly split into.
        :param depth: The number of random splits that are applied before the remaining
            regions are decomposed into single-variable regions.
        :param repetitions: The number of random decompositions that are attached to the root.
        """
        super().__init__()
        self.variables = variables
        self.partitions = partitions
        self.depth = depth
        self.repetitions = repetitions

    def create_random_region_graph(self):
        """
        Fill this graph with a root region and `repetitions` random decompositions of it.

        :return: This graph.
        """
        root = Region(self.variables)
        self.add_node(root)
        for _ in range(self.repetitions):
            self.recursive_split(root)
        return self

    def regions(self):
        """
        :return: A generator over all nodes that are regions.
        """
        for node in self.nodes:
            if isinstance(node, Region):
                yield node

    def partition_nodes(self):
        """
        :return: A generator over all nodes that are partitions.
        """
        for node in self.nodes:
            if isinstance(node, Partition):
                yield node

    def recursive_split(self, node: Region):
        """
        Attach one random decomposition below `node`.

        Regions are split breadth first into `self.partitions` parts until `self.depth`
        splits have been applied. Regions that still contain more than one variable at
        that point are decomposed into one region per variable.

        :param node: The region to decompose.
        """
        remaining_regions = deque([(node, self.depth)])

        while remaining_regions:
            region, depth = remaining_regions.popleft()

            # Regions with a single variable are leaves. They must not receive a child
            # partition, since a partition without children cannot be converted to a
            # product layer.
            if len(region.variables) <= 1:
                continue

            # The partition is only created once it is known that it will get children.
            partition = Partition()
            self.add_edge(region, partition)

            if depth == 0:
                for variable in region.variables:
                    self.add_edge(partition, Region(SortedSet([variable])))
                continue

            for new_region in region.random_partition(self.partitions):
                # Guard against empty parts, which occur if there are more parts than variables.
                if len(new_region.variables) == 0:
                    continue
                self.add_edge(partition, new_region)
                remaining_regions.append((new_region, depth - 1))

    @property
    def root(self) -> Region:
        """
        The root of the region graph is the node with in-degree 0.

        :return: The root of the region graph.
        :raises ValueError: If the graph has no root or more than one root.
        """
        possible_roots = [node for node in self.nodes() if self.in_degree(node) == 0]
        if len(possible_roots) == 0:
            raise ValueError("No root found. The region graph has no node with in-degree 0.")
        if len(possible_roots) > 1:
            raise ValueError(f"More than one root found. Possible roots are {possible_roots}")

        return possible_roots[0]

    def as_probabilistic_circuit(self, continuous_distribution_type: Type = GaussianLayer,
                                 input_units: int = 5, sum_units: int = 5, key=jax.random.PRNGKey(69)) -> JPC:
        """
        Create a layered probabilistic circuit with random parameters from this region graph.

        Leaf regions become input layers, inner regions become sum layers and partitions
        become product layers. The created layer is stored in the `layer` attribute of
        every node.

        :param continuous_distribution_type: Currently unused. Leaves of continuous
            variables are always created as Gaussian layers.
        :param input_units: The number of nodes in every input layer.
        :param sum_units: The number of nodes in every sum layer except the root, which has one node.
        :param key: The random key used to draw the parameters. It is split for every
            draw such that no two layers share their random numbers.
        :return: The probabilistic circuit over the variables of this graph.
        :raises NotImplementedError: If a leaf region holds a variable that is not continuous.
        """
        root = self.root

        # Traverse bottom up such that the layers of the children exist before their parents.
        for layer in reversed(list(nx.bfs_layers(self, root))):
            for node in layer:
                children = list(self.successors(node))

                if isinstance(node, Region):
                    # Reusing one key would give every layer the same parameters.
                    key, node_key = jax.random.split(key)

                    if len(children) == 0:
                        node.layer = self._create_input_layer(node, input_units, node_key)
                    else:
                        is_root = self.in_degree(node) == 0
                        number_of_sum_units = 1 if is_root else sum_units
                        node.layer = self._create_sum_layer(children, number_of_sum_units, node_key)
                    node.layer.validate()

                elif isinstance(node, Partition):
                    node.layer = self._create_product_layer(children)
                    node.layer.validate()

        model = JPC(self.variables, root.layer)
        return model

    def _create_input_layer(self, region: Region, input_units: int, key):
        """
        Create a randomly initialized input layer for a leaf region.
        """
        variable = region.variables[0]
        variable_index = self.variables.index(variable)
        if not isinstance(variable, Continuous):
            raise NotImplementedError

        # Separate keys keep location and scale independent of each other.
        location_key, scale_key = jax.random.split(key)
        location = jax.random.uniform(location_key, shape=(input_units,), minval=-1., maxval=1.)
        log_scale = jnp.log(jax.random.uniform(scale_key, shape=(input_units,), minval=0.01, maxval=1.))
        return GaussianLayer(variable_index, location=location, log_scale=log_scale,
                             min_scale=jnp.full_like(location, 0.01))

    @staticmethod
    def _create_sum_layer(children: List[Partition], sum_units: int, key):
        """
        Create a sum layer with random weights that is fully connected to the layers of `children`.
        """
        log_weights = []
        for child, child_key in zip(children, jax.random.split(key, len(children))):
            weights = BCOO.fromdense(jax.random.uniform(child_key,
                                                        shape=(sum_units, child.layer.number_of_nodes),
                                                        minval=0., maxval=1.))
            weights.data = jnp.log(weights.data)
            log_weights.append(weights)
        return SumLayer([child.layer for child in children], log_weights=log_weights)

    @staticmethod
    def _create_product_layer(children: List[Region]):
        """
        Create a product layer whose i-th node multiplies the i-th node of every child layer.
        """
        node_lengths = [child.layer.number_of_nodes for child in children]
        assert (len(set(node_lengths)) == 1), "Node lengths must be all equal. Got {}".format(node_lengths)

        edges = jnp.arange(node_lengths[0]).reshape(1, -1).repeat(len(children), axis=0)
        # Build a dense pattern first and overwrite the data afterwards, such that the
        # edge to node 0 is kept as an explicit entry.
        sparse_edges = BCOO.fromdense(jnp.ones_like(edges))
        sparse_edges.data = edges.flatten()
        return ProductLayer([child.layer for child in children], sparse_edges)
112 changes: 112 additions & 0 deletions src/probabilistic_model/probabilistic_circuit/jax/discrete_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from typing import Tuple, Type, List, Dict, Any

from random_events.set import SetElement
from random_events.variable import Symbolic, Variable
from sortedcontainers import SortedSet
from typing_extensions import Self, Optional

from . import NXConverterLayer
from .inner_layer import InputLayer
import jax.numpy as jnp

from ..nx.distributions import UnivariateDiscreteLeaf
from ..nx.probabilistic_circuit import Unit
from ...distributions import SymbolicDistribution
import tqdm
import numpy as np
from ..nx.probabilistic_circuit import ProbabilisticCircuit as NXProbabilisticCircuit

from ...utils import MissingDict


class DiscreteLayer(InputLayer):
    """
    Input layer that stores the log-probabilities of the states of one variable for every node.
    """

    log_probabilities: jnp.array
    # The logarithm of probability for each state of the variable.
    # The shape is (#nodes, #states).

    def __init__(self, variable: int, log_probabilities: jnp.array):
        """
        :param variable: The index of the variable of this layer.
        :param log_probabilities: The log-probabilities with shape (#nodes, #states).
        """
        super().__init__(variable)
        self.log_probabilities = log_probabilities

    @classmethod
    def nx_classes(cls) -> Tuple[Type, ...]:
        """
        :return: The classes that this layer is associated with in the nx implementation.
        """
        return (SymbolicDistribution,)

    def validate(self):
        """
        No checks are performed for this layer.

        :return: Always True.
        """
        return True

    @property
    def normalization_constant(self) -> jnp.array:
        """
        :return: The sum of the (unnormalized) probabilities of every node.
        """
        probabilities = jnp.exp(self.log_probabilities)
        return probabilities.sum(1)

    @property
    def log_normalization_constant(self) -> jnp.array:
        """
        :return: The logarithm of the normalization constant of every node.
        """
        constant = self.normalization_constant
        return jnp.log(constant)

    @property
    def normalized_log_probabilities(self) -> jnp.array:
        """
        :return: The log-probabilities with the log normalization constant of each node subtracted.
        """
        log_constant_per_node = self.log_normalization_constant[:, None]
        return self.log_probabilities - log_constant_per_node

    @property
    def number_of_nodes(self) -> int:
        """
        :return: The size of the first axis of the log-probabilities.
        """
        number_of_rows = self.log_probabilities.shape[0]
        return number_of_rows

    def log_likelihood_of_nodes_single(self, x: jnp.array) -> jnp.array:
        """
        :param x: The state to evaluate; it is cast to int and used as index.
        :return: The normalized log-probability of `x` in every node.
        """
        state_index = x.astype(int)
        log_probabilities_by_state = self.log_probabilities.T
        return log_probabilities_by_state[state_index] - self.log_normalization_constant

    @classmethod
    def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[UnivariateDiscreteLeaf],
                                                         child_layers: List[NXConverterLayer],
                                                         progress_bar: bool = True) -> \
            NXConverterLayer:
        """
        Create a layer from discrete leaves that share one variable.

        :param nodes: The leaves to convert.
        :param child_layers: Not used by this layer.
        :param progress_bar: Whether to show a progress bar while reading the nodes.
        :return: The converter layer holding the new layer, the nodes and the hash remap.
        """
        hash_remap = {hash(node): index for index, node in enumerate(nodes)}

        variable: Symbolic = nodes[0].variable

        number_of_states = len(variable.domain.simple_sets)
        parameters = np.zeros((len(nodes), number_of_states))

        iterator = nodes
        if progress_bar:
            iterator = tqdm.tqdm(nodes, desc=f"Creating discrete layer for variable {variable.name}")

        for node in iterator:
            row = hash_remap[hash(node)]
            for state, value in node.distribution.probabilities.items():
                parameters[row, state] = value

        variable_index = nodes[0].probabilistic_circuit.variables.index(variable)
        layer = cls(variable_index, jnp.log(parameters))
        return NXConverterLayer(layer, nodes, hash_remap)

    def to_json(self) -> Dict[str, Any]:
        """
        :return: The json data of the parent class extended by the variable index and the log-probabilities.
        """
        data = dict(super().to_json())
        data["variable"] = self.variable
        data["log_probabilities"] = self.log_probabilities.tolist()
        return data

    @classmethod
    def _from_json(cls, data: Dict[str, Any]) -> Self:
        """
        :param data: The json data containing the keys "variable" and "log_probabilities".
        :return: The layer described by the data.
        """
        log_probabilities = jnp.array(data["log_probabilities"])
        return cls(data["variable"], log_probabilities)

    def to_nx(self, variables: SortedSet[Variable], result: NXProbabilisticCircuit,
              progress_bar: Optional[tqdm.tqdm] = None) -> List[Unit]:
        """
        Create one discrete leaf per node of this layer.

        :param variables: The variables that the variable index of this layer refers to.
        :param result: The circuit that is passed to every created leaf.
        :param progress_bar: An optional progress bar that is advanced by the number of nodes.
        :return: The created leaves.
        """
        variable = variables[self.variable]
        domain = variable.domain_type()

        if progress_bar:
            progress_bar.set_postfix_str(f"Creating discrete distributions for variable {variable.name}")

        nodes = []
        for log_probabilities in self.normalized_log_probabilities:
            probabilities = MissingDict(float, {domain(state): value
                                                for state, value in enumerate(jnp.exp(log_probabilities))})
            nodes.append(UnivariateDiscreteLeaf(SymbolicDistribution(variable, probabilities), result))

        if progress_bar:
            progress_bar.update(self.number_of_nodes)
        return nodes







Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ def nx_classes(cls) -> Tuple[Type, ...]:
def validate(self):
assert self.location.shape == self.log_scale.shape, "The shapes of location and scale must match."
assert self.min_scale.shape == self.log_scale.shape, "The shapes of the min_scale and scale bounds must match."
assert jnp.all(self.min_scale >= 0), "The scale must be positive."
assert jnp.all(self.min_scale >= 0), "The minimum scale must be positive."

@property
def number_of_nodes(self) -> int:
return len(self.location.shape)
return self.location.shape[0]

@property
def scale(self) -> jnp.array:
Expand Down Expand Up @@ -80,20 +80,18 @@ def _from_json(cls, data: Dict[str, Any]) -> Self:
return cls(data["variable"], jnp.array(data["location"]), jnp.array(data["scale"]),
jnp.array(data["min_scale"]))

def to_nx(self, variables: SortedSet[Variable], progress_bar: Optional[tqdm.tqdm] = None) -> List[
Unit]:
def to_nx(self, variables: SortedSet[Variable], result: NXProbabilisticCircuit,
progress_bar: Optional[tqdm.tqdm] = None) -> List[Unit]:
variable = variables[self.variable]

if progress_bar:
progress_bar.set_postfix_str(f"Creating Gaussian distributions for variable {variable.name}")

nx_pc = NXProbabilisticCircuit()
nodes = [UnivariateContinuousLeaf(
GaussianDistribution(variable=variable, location=location.item(), scale=scale.item()))
GaussianDistribution(variable=variable, location=location.item(), scale=scale.item()),result)
for location, scale in zip(self.location, self.scale)]

if progress_bar:
progress_bar.update(self.number_of_nodes)

nx_pc.add_nodes_from(nodes)
return nodes
Loading

0 comments on commit bf02e39

Please sign in to comment.