Skip to content

Commit

Permalink
First version of LearnRSPN
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Jan 28, 2025
1 parent b2ce154 commit 56aaa0b
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,20 @@
from collections import deque
from dataclasses import dataclass
from enum import Enum
from functools import cached_property
from inspect import isclass

import networkx as nx
import pandas as pd
import sqlalchemy
from sqlalchemy import Column, Table, select
from sqlalchemy.orm import Session, DeclarativeBase
from typing_extensions import Type, Iterable, Dict, Tuple, Any, Self, List
from typing_extensions import Type, Iterable, Dict, Tuple, Any, Self, List, Optional

from .helper import fully_factorized
from .probabilistic_circuit import ProbabilisticCircuit
from .probabilistic_circuit import ProbabilisticCircuit, ProductUnit, Unit
from ...learning.jpt.jpt import JPT
from ...learning.jpt.variables import infer_variables_from_dataframe


class EdgeType(str, Enum):
Expand Down Expand Up @@ -154,6 +158,10 @@ def __init__(self, base_table: Type[PartDecompositionBaseMixin], **attr):
super().__init__(**attr)
self.base_table = base_table

@property
def roots(self):
return [node for node in self.nodes if len(list(self.predecessors(node))) == 0]

def all_wrapped_classes(self) -> Iterable[WrappedTable]:
"""
:return: List of all classes (tables) in the database wrapped into the WrappedTable class.
Expand Down Expand Up @@ -240,24 +248,28 @@ def make_graph(self):
return self

def plot(self):
roots = [node for node in self.nodes if len(list(self.predecessors(node))) == 0]

roots = self.roots
pos = nx.bfs_layout(self, roots[0])
edge_labels = {edge: label.value for edge, label in self.edge_labels().items()}
nx.draw(self, pos=pos, with_labels=True)
nx.draw_networkx_edge_labels(self, pos=pos, edge_labels=edge_labels)


class ExchangeableDistributionTemplate:
class ExchangeableDistributionTemplate(Unit):
"""
A distribution template that is exchangeable.
Exchangeable means that it is permutation invariant, e.g. P(X, Y, Z) = P(Y, X, Z) = P(Y, Z, X) = ...
"""

template_distribution: ProbabilisticCircuit
template_model: ProbabilisticCircuit

def ground(self, variables):
return fully_factorized(variables)
def __init__(self, template_model: ProbabilisticCircuit,
probabilistic_circuit: Optional[ProbabilisticCircuit] = None):
super().__init__(probabilistic_circuit)
self.template_model = template_model

def ground(self, instances: List[PartDecompositionBaseMixin]):
...


class RelationalProbabilisticCircuit(ProbabilisticCircuit):
Expand All @@ -275,16 +287,56 @@ def __init__(self, base_table: Type[PartDecompositionBaseMixin], session: Sessio
def ground(self, session):
...


def learn(self):
tasks = deque()
roots = self.part_decomposition.roots
assert len(roots) == 1, "I think that this must be 1"

initial_instances = self.session.scalars(select(roots[0].table)).all()
self.fitting_step(initial_instances)


def fitting_step(self, instances: Iterable[PartDecompositionBaseMixin]) -> ProbabilisticCircuit:

# infer current class (table) that is handled
table: Type[PartDecompositionBaseMixin] = instances[0].__class__

# construct dataframe
attribute_column_names = [column.name for column in table.attributes()]
aggregated_column_names = [column.attrname for column in table._aggregated_columns.values()]
columns_names = attribute_column_names + aggregated_column_names
df = pd.DataFrame(columns=columns_names,
data=[[getattr(instance, column_name) for column_name in columns_names]
for instance in instances])

# fit jpt
variables = infer_variables_from_dataframe(df)
class_model = JPT(variables, min_samples_leaf=20)
class_model.keep_sample_indices = True
class_model.fit(df)

# replace aggregated columns with EDT
for relationship, prop in table._aggregated_columns.items():

relationship_attribute_name: str = relationship.class_attribute.key

for product in class_model.root.subcircuits:
product: ProductUnit

univariate_model = [subcircuit for subcircuit in product.subcircuits if
subcircuit.variables[0].name == prop.attrname][0]

assert len(univariate_model.variables) == 1, "I think that this must be 1"

instances_of_relationship = [instance for index in product.sample_indices for instance in
getattr(instances[index], relationship_attribute_name)]

template_model = self.fitting_step(instances_of_relationship)
edt = ExchangeableDistributionTemplate(template_model, class_model)

class_model.remove_node(univariate_model)
product.add_subcircuit(edt, mount=False)

def gather_data(self, cls: Type[PartDecompositionBaseMixin]):
node = WrappedTable(cls)
result = self.session.scalars(select(cls)).all()
return class_model

attributes = cls.attributes()

for element in result:
for attribute in attributes:
print(attribute.name, element.__dict__[attribute.name])
print("-----------------")
11 changes: 4 additions & 7 deletions test/test_relational/test_mutagenesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from matplotlib import pyplot as plt
from sqlalchemy import create_engine, select, ForeignKey
from sqlalchemy.orm import Mapped, MappedAsDataclass, mapped_column, relationship, Session
from sqlalchemy.orm import Mapped, MappedAsDataclass, mapped_column, relationship, Session, column_property

from probabilistic_model.probabilistic_circuit.nx.relational_probabilistic_circuit import PartDecompositionBaseMixin, \
AssociationMixin, PartDecomposition, RelationalProbabilisticCircuit
Expand Down Expand Up @@ -80,12 +80,9 @@ def test_data_getting(self):
self.assertGreater(len(r), 0)

def test_aggregation_statistics(self):
print(Molecule._aggregated_columns)
print(Atom._aggregated_columns)
exit()
for m in self.session.scalars(select(Molecule)).all():
print(m.mean_charge_of_atoms)
exit()
for k, v in m._aggregated_columns.items():
self.assertIsInstance(m.__getattribute__(v.attrname), float)

def test_pd(self):
pd = PartDecomposition(Base).make_graph()
Expand All @@ -94,4 +91,4 @@ def test_pd(self):

def test_rspn(self):
model = RelationalProbabilisticCircuit(Base, self.session)
model.gather_data(Atom)
model.learn()

0 comments on commit 56aaa0b

Please sign in to comment.