Skip to content

Commit

Permalink
Started to work on relational SPNs
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Jan 22, 2025
1 parent 9736d4a commit d9c3ff0
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/probabilistic_model/probabilistic_circuit/nx/helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from random_events.product_algebra import Event, SimpleEvent
from random_events.variable import Continuous, Integer, Symbolic, Variable
from typing_extensions import Iterable
from typing_extensions import Iterable, Optional, Dict

from .distributions import UnivariateContinuousLeaf, UnivariateDiscreteLeaf
from .probabilistic_circuit import ProductUnit, SumUnit, ProbabilisticCircuit
Expand Down Expand Up @@ -70,7 +70,8 @@ def uniform_measure_of_simple_event(simple_event: SimpleEvent) -> ProbabilisticC
return uniform_model.probabilistic_circuit


def fully_factorized(variables: Iterable[Variable], means: dict, variances: dict) -> ProbabilisticCircuit:
def fully_factorized(variables: Iterable[Variable], means: Optional[Dict[Continuous, float]] = None,
variances: Optional[Dict[Continuous, float]] = None) -> ProbabilisticCircuit:
"""
Create a fully factorized distribution over a set of variables.
For symbolic variables, the distribution is uniform.
Expand All @@ -82,6 +83,11 @@ def fully_factorized(variables: Iterable[Variable], means: dict, variances: dict
:return: The circuit describing the fully factorized normal distribution
"""

if means is None:
means = {variable: 0. for variable in variables}
if variances is None:
variances = {variable: 1. for variable in variables}

# initialize the root of the circuit
root = ProductUnit()
for variable in variables:
Expand Down
Empty file.
154 changes: 154 additions & 0 deletions test/test_relational/test_rspn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
from __future__ import annotations
import unittest
from dataclasses import dataclass

from random_events.set import SetElement
from sqlalchemy import create_engine, select, ForeignKey, Column, Integer, UniqueConstraint, Engine, inspect, Table, \
MetaData
from sqlalchemy.orm import MappedAsDataclass, DeclarativeBase, mapped_column, Mapped, Session, relationship
from typing_extensions import List, Iterable, Type

from probabilistic_model.probabilistic_circuit.nx.helper import fully_factorized
from probabilistic_model.probabilistic_model import ProbabilisticModel


class Base(MappedAsDataclass, DeclarativeBase):

@classmethod
def attributes(cls):
return [column for column in cls.__table__.columns if column.primary_key is False
and not column.foreign_keys ]

@classmethod
def exchangeable_parts(cls):
for table in cls.metadata.tables.values():
for fk in table.foreign_keys:
if fk.column.table == cls.__table__:
yield table

@classmethod
def unique_parts(cls, engine: Engine):
inspector = inspect(engine)
table_name = cls.__tablename__
unique_constraints = inspector.get_unique_constraints(table_name)

table = Table(table_name, MetaData(), autoload_with=engine)
result = []
for constraint in unique_constraints:
constraint_columns = [table.c[col_name] for col_name in constraint['column_names']]
result.append((constraint['name'], constraint_columns))

return result
class Government(Base):
__tablename__ = "Government"

id: Mapped[int] = mapped_column(init=False, primary_key=True)
name: Mapped[str]
form: Mapped[str]

nation_id: Mapped[int] = mapped_column(ForeignKey('Nation.id'), init=False)
nation: Mapped[Nation] = relationship(back_populates="government", single_parent=True)
__table_args__ = (UniqueConstraint("nation_id"),)

class Person(Base):
__tablename__ = "Person"
id: Mapped[int] = mapped_column(init=False, primary_key=True)
name: Mapped[str]
age: Mapped[int]
nation_id: Mapped[int] = mapped_column(ForeignKey('Nation.id'), init=False)
nation: Mapped[Nation] = relationship("Nation", foreign_keys=[nation_id])

class Region(Base):
id: Mapped[int] = mapped_column(init=False, primary_key=True)
name: Mapped[str]
nations: Mapped[List[Nation]] = relationship("Nation", back_populates="region", init=False)
__tablename__ = "Region"

class Nation(Base):
__tablename__ = "Nation"
id: Mapped[int] = mapped_column(init=False, primary_key=True)
region_id: Mapped[int] = mapped_column(ForeignKey('Region.id'), init=False)
region: Mapped[Region] = relationship("Region", foreign_keys=[region_id], back_populates="nations")
government: Mapped[Government] = relationship(back_populates="nation", init=False)
high_gdp: Mapped[bool] # this is an attribute

class Adjacent(Base):
id: Mapped[int] = mapped_column(init=False, primary_key=True)
__tablename__ = "Adjacent"
nation_1_id: Mapped[int] = mapped_column(ForeignKey('Nation.id'), init=False)
nation_2_id: Mapped[int] = mapped_column(ForeignKey('Nation.id'), init=False)
nation_1: Mapped[Nation] = relationship("Nation", foreign_keys=[nation_1_id])
nation_2: Mapped[Nation] = relationship("Nation", foreign_keys=[nation_2_id])

class Conflict(Base):
id: Mapped[int] = mapped_column(init=False, primary_key=True)
__tablename__ = "Conflict"
nation_1_id: Mapped[int] = mapped_column(ForeignKey('Nation.id'), init=False)
nation_2_id: Mapped[int] = mapped_column(ForeignKey('Nation.id'), init=False)
nation_1: Mapped[Nation] = relationship("Nation", foreign_keys=[nation_1_id])
nation_2: Mapped[Nation] = relationship("Nation", foreign_keys=[nation_2_id])

class Supports(Base):
__tablename__ = "Supports"
person_id: Mapped[int] = mapped_column(ForeignKey('Person.id'), primary_key=True, init=False)
person: Mapped[Person] = relationship("Person", foreign_keys=[person_id])

nation_id: Mapped[int] = mapped_column(ForeignKey('Nation.id'), primary_key=True, init=False)
nation: Mapped[Nation] = relationship("Nation", foreign_keys=[nation_id])
value: Mapped[bool]

class RSPNClass:
attributes: set
unique_parts: set
exchangeable_parts: set
relations: set
model: ProbabilisticModel


class ExchangeableDistributionTemplate:
def ground(self, variables):
return fully_factorized(variables)


class RelationalSPN:

table: Type[Base]
session: Session

def __init__(self, table: Type[Base], session: Session):
self.table = table
self.session = session

def learn(self):
...

class RSPNTestCase(unittest.TestCase):

session: Session

@classmethod
def setUpClass(cls):
engine = create_engine('sqlite:///:memory:')
Base.metadata.create_all(engine)
cls.session = Session(engine)

def setUp(self):

na = Region("North America")
usa = Nation(na, True)
usa_gov = Government("Trump", "Republic", usa)
anna = Person("Anna", 20, usa)
bob = Person("Bob", 30, usa)
s1 = Supports(anna, usa, True)
s2 = Supports(bob, usa, False)

self.session.add_all([anna, bob, na, usa, s1, s2])
self.session.commit()

def test_learn(self):
model = RelationalSPN(Person, self.session)
print(*(Region.exchangeable_parts()), sep="\n")
print(*(Government.unique_parts(self.session.get_bind())), sep="\n")

if __name__ == '__main__':
unittest.main()

0 comments on commit d9c3ff0

Please sign in to comment.