From 7cf9e44dd8b6751afeaa9c8571b021fab22c0f65 Mon Sep 17 00:00:00 2001 From: Kevin P Murphy Date: Fri, 3 Nov 2023 17:31:40 -0700 Subject: [PATCH] fix documentation for pygmodels library --- notebooks/misc/pygmodels_doc.ipynb | 1035 ++++++++++++++++++++++++++++ 1 file changed, 1035 insertions(+) create mode 100644 notebooks/misc/pygmodels_doc.ipynb diff --git a/notebooks/misc/pygmodels_doc.ipynb b/notebooks/misc/pygmodels_doc.ipynb new file mode 100644 index 0000000000..f7079e2655 --- /dev/null +++ b/notebooks/misc/pygmodels_doc.ipynb @@ -0,0 +1,1035 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Probabilistic graphical models\n", + "\n", + "Update examples from\n", + "\n", + "https://d-k-e.github.io/graphical-models/html/d3/db0/md_graphical-models_docs_mdpages_usage.html\n", + "\n", + "First download\n", + "https://github.com/D-K-E/graphical-models/tree/master\n", + "Then `pip install .`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# PGModel\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import pygmodels\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.32 0.32\n" + ] + } + ], + "source": [ + "# Example adapted from Darwiche 2009, p. 140\n", + "# A -> B -> C\n", + "\n", + "\n", + "from pygmodels.pgmtype.pgmodel import PGModel\n", + "from pygmodels.gtype.edge import Edge, EdgeType\n", + "from pygmodels.pgmtype.factor import Factor\n", + "from pygmodels.pgmtype.randomvariable import NumCatRVariable\n", + "\n", + "# Example adapted from Darwiche 2009, p. 140\n", + "idata = {\n", + " \"a\": {\"outcome-values\": [True, False]},\n", + " \"b\": {\"outcome-values\": [True, False]},\n", + " \"c\": {\"outcome-values\": [True, False]},\n", + "}\n", + "a = NumCatRVariable(node_id=\"a\", input_data=idata[\"a\"], marginal_distribution=lambda x: 0.6 if x else 0.4)\n", + "b = NumCatRVariable(node_id=\"b\", input_data=idata[\"b\"], marginal_distribution=lambda x: 0.5 if x else 0.5)\n", + "c = NumCatRVariable(node_id=\"c\", input_data=idata[\"c\"], marginal_distribution=lambda x: 0.5 if x else 0.5)\n", + "ab = Edge(\n", + " edge_id=\"ab\",\n", + " edge_type=EdgeType.UNDIRECTED,\n", + " start_node=a,\n", + " end_node=b,\n", + ")\n", + "bc = Edge(\n", + " edge_id=\"bc\",\n", + " edge_type=EdgeType.UNDIRECTED,\n", + " start_node=b,\n", + " end_node=c,\n", + ")\n", + "\n", + "\n", + "def phi_ba(scope_product):\n", + " \"\"\"\"\"\"\n", + " ss = set(scope_product)\n", + " if ss == set([(\"a\", True), (\"b\", True)]):\n", + " return 0.9\n", + " elif ss == set([(\"a\", True), (\"b\", False)]):\n", + " return 0.1\n", + " elif ss == set([(\"a\", False), (\"b\", True)]):\n", + " return 0.2\n", + " elif ss == set([(\"a\", False), (\"b\", False)]):\n", + " return 0.8\n", + " else:\n", + " raise ValueError(\"product error\")\n", + "\n", + "\n", + "def phi_cb(scope_product):\n", + " \"\"\"\"\"\"\n", + " ss = set(scope_product)\n", + " if ss == set([(\"c\", True), (\"b\", True)]):\n", + " return 0.3\n", + " elif ss == set([(\"c\", True), (\"b\", False)]):\n", + " return 0.5\n", + " elif ss == set([(\"c\", False), (\"b\", True)]):\n", + " return 0.7\n", + " elif ss == set([(\"c\", False), (\"b\", False)]):\n", + " return 0.5\n", + " else:\n", + " raise ValueError(\"product error\")\n", + "\n", + "\n", + "def phi_a(scope_product):\n", + " s = set(scope_product)\n", + " if s == set([(\"a\", True)]):\n", + " return 0.6\n", + " elif s == set([(\"a\", False)]):\n", + " return 0.4\n", + " else:\n", + " raise ValueError(\"product error\")\n", + "\n", + "\n", + "ba_f = Factor(gid=\"ba\", scope_vars=set([b, a]), factor_fn=phi_ba)\n", + "cb_f = Factor(gid=\"cb\", scope_vars=set([c, b]), factor_fn=phi_cb)\n", + "a_f = Factor(gid=\"a\", scope_vars=set([a]), factor_fn=phi_a)\n", + "pgm = PGModel(\n", + " gid=\"pgm\",\n", + " nodes=set([a, b, c]),\n", + " edges=set([ab, bc]),\n", + " factors=set([ba_f, cb_f, a_f]),\n", + ")\n", + "evidences = set([(\"a\", True)])\n", + "queries = set([c])\n", + "product_factor, a = pgm.cond_prod_by_variable_elimination(queries, evidences)\n", + "val = round(product_factor.phi_normal(set([(\"c\", True)])), 4)\n", + "expected = 0.32\n", + "print(val, expected)\n", + "assert np.allclose(val, expected)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Bayesian network" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.774" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Bayesian network\n", + "\n", + "\n", + "from pygmodels.pgmodel.bayesian import BayesianNetwork\n", + "from pygmodels.gtype.edge import Edge, EdgeType\n", + "from pygmodels.pgmtype.factor import Factor\n", + "from pygmodels.pgmtype.randomvariable import NumCatRVariable\n", + "\n", + "# data and nodes\n", + "idata = {\"outcome-values\": [True, False]}\n", + "\n", + "C = NumCatRVariable(node_id=\"C\", input_data=idata, marginal_distribution=lambda x: 0.5)\n", + "E = NumCatRVariable(node_id=\"E\", input_data=idata, marginal_distribution=lambda x: 0.5)\n", + "F = NumCatRVariable(node_id=\"F\", input_data=idata, marginal_distribution=lambda x: 0.5)\n", + "D = NumCatRVariable(node_id=\"D\", input_data=idata, marginal_distribution=lambda x: 0.5)\n", + "\n", + "# edges\n", + "CE = Edge(\n", + " edge_id=\"CE\",\n", + " start_node=C,\n", + " end_node=E,\n", + " edge_type=EdgeType.DIRECTED,\n", + ")\n", + "ED = Edge(\n", + " edge_id=\"ED\",\n", + " start_node=E,\n", + " end_node=D,\n", + " edge_type=EdgeType.DIRECTED,\n", + ")\n", + "EF = Edge(\n", + " edge_id=\"EF\",\n", + " start_node=E,\n", + " end_node=F,\n", + " edge_type=EdgeType.DIRECTED,\n", + ")\n", + "\n", + "# define factor functions\n", + "\n", + "\n", + "def phi_c(scope_product):\n", + " ss = set(scope_product)\n", + " if ss == set([(\"C\", True)]):\n", + " return 0.8\n", + " elif ss == set([(\"C\", False)]):\n", + " return 0.2\n", + " else:\n", + " raise ValueError(\"scope product unknown\")\n", + "\n", + "\n", + "def phi_ec(scope_product):\n", + " ss = set(scope_product)\n", + " if ss == set([(\"C\", True), (\"E\", True)]):\n", + " return 0.9\n", + " elif ss == set([(\"C\", True), (\"E\", False)]):\n", + " return 0.1\n", + " elif ss == set([(\"C\", False), (\"E\", True)]):\n", + " return 0.7\n", + " elif ss == set([(\"C\", False), (\"E\", False)]):\n", + " return 0.3\n", + " else:\n", + " raise ValueError(\"scope product unknown\")\n", + "\n", + "\n", + "def phi_fe(scope_product):\n", + " ss = set(scope_product)\n", + " if ss == set([(\"E\", True), (\"F\", True)]):\n", + " return 0.9\n", + " elif ss == set([(\"E\", True), (\"F\", False)]):\n", + " return 0.1\n", + " elif ss == set([(\"E\", False), (\"F\", True)]):\n", + " return 0.5\n", + " elif ss == set([(\"E\", False), (\"F\", False)]):\n", + " return 0.5\n", + " else:\n", + " raise ValueError(\"scope product unknown\")\n", + "\n", + "\n", + "def phi_de(scope_product):\n", + " ss = set(scope_product)\n", + " if ss == set([(\"E\", True), (\"D\", True)]):\n", + " return 0.7\n", + " elif ss == set([(\"E\", True), (\"D\", False)]):\n", + " return 0.3\n", + " elif ss == set([(\"E\", False), (\"D\", True)]):\n", + " return 0.4\n", + " elif ss == set([(\"E\", False), (\"D\", False)]):\n", + " return 0.6\n", + " else:\n", + " raise ValueError(\"scope product unknown\")\n", + "\n", + "\n", + "# instantiate factors with given factor function and implied random variables\n", + "CE_f = Factor(gid=\"CE_f\", scope_vars=set([C, E]), factor_fn=phi_ec)\n", + "C_f = Factor(gid=\"C_f\", scope_vars=set([C]), factor_fn=phi_c)\n", + "FE_f = Factor(gid=\"FE_f\", scope_vars=set([F, E]), factor_fn=phi_fe)\n", + "DE_f = Factor(gid=\"DE_f\", scope_vars=set([D, E]), factor_fn=phi_de)\n", + "bayes_n = BayesianNetwork(\n", + " gid=\"ba\",\n", + " nodes=set([C, E, D, F]),\n", + " edges=set([EF, CE, ED]),\n", + " factors=set([C_f, DE_f, CE_f, FE_f]),\n", + ")\n", + "query_vars = set([E])\n", + "evidences = set([(\"F\", True)])\n", + "probs, alpha = bayes_n.cond_prod_by_variable_elimination(query_vars, evidences=evidences)\n", + "query_value = set([(\"E\", True)])\n", + "round(probs.phi(query_value), 4)\n", + "# 0.774" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Markov network" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# Markov network\n", + "# Example from Koller and Friedman, p104\n", + "# A - B\n", + "# | |\n", + "# C - D\n", + "\n", + "from pygmodels.pgmodel.markov import MarkovNetwork\n", + "from pygmodels.gtype.edge import Edge, EdgeType\n", + "from pygmodels.pgmtype.factor import Factor\n", + "from pygmodels.pgmtype.randomvariable import NumCatRVariable\n", + "\n", + "\n", + "def make_markov_net():\n", + "\n", + " # define data and random variable nodes\n", + " idata = {\n", + " \"A\": {\"outcome-values\": [True, False]},\n", + " \"B\": {\"outcome-values\": [True, False]},\n", + " \"C\": {\"outcome-values\": [True, False]},\n", + " \"D\": {\"outcome-values\": [True, False]},\n", + " }\n", + "\n", + " A = NumCatRVariable(node_id=\"A\", input_data=idata[\"A\"], marginal_distribution=lambda x: 0.5)\n", + " B = NumCatRVariable(node_id=\"B\", input_data=idata[\"B\"], marginal_distribution=lambda x: 0.5)\n", + " C = NumCatRVariable(node_id=\"C\", input_data=idata[\"C\"], marginal_distribution=lambda x: 0.5)\n", + " D = NumCatRVariable(node_id=\"D\", input_data=idata[\"D\"], marginal_distribution=lambda x: 0.5)\n", + "\n", + " # define edges\n", + " AB = Edge(\n", + " edge_id=\"AB\",\n", + " edge_type=EdgeType.UNDIRECTED,\n", + " start_node=A,\n", + " end_node=B,\n", + " )\n", + " AD = Edge(\n", + " edge_id=\"AD\",\n", + " edge_type=EdgeType.UNDIRECTED,\n", + " start_node=A,\n", + " end_node=D,\n", + " )\n", + " DC = Edge(\n", + " edge_id=\"DC\",\n", + " edge_type=EdgeType.UNDIRECTED,\n", + " start_node=D,\n", + " end_node=C,\n", + " )\n", + " BC = Edge(\n", + " edge_id=\"BC\",\n", + " edge_type=EdgeType.UNDIRECTED,\n", + " start_node=B,\n", + " end_node=C,\n", + " )\n", + "\n", + " # define factor functions\n", + "\n", + " def phi_AB(scope_product):\n", + " \"\"\"\"\"\"\n", + " ss = frozenset(scope_product)\n", + " if ss == frozenset([(\"A\", False), (\"B\", False)]):\n", + " return 30.0\n", + " elif ss == frozenset([(\"A\", False), (\"B\", True)]):\n", + " return 5.0\n", + " elif ss == frozenset([(\"A\", True), (\"B\", False)]):\n", + " return 1.0\n", + " elif ss == frozenset([(\"A\", True), (\"B\", True)]):\n", + " return 10.0\n", + " else:\n", + " raise ValueError(\"product error\")\n", + "\n", + " def phi_BC(scope_product):\n", + " \"\"\"\"\"\"\n", + " ss = frozenset(scope_product)\n", + " if ss == frozenset([(\"B\", False), (\"C\", False)]):\n", + " return 100.0\n", + " elif ss == frozenset([(\"B\", False), (\"C\", True)]):\n", + " return 1.0\n", + " elif ss == frozenset([(\"B\", True), (\"C\", False)]):\n", + " return 1.0\n", + " elif ss == frozenset([(\"B\", True), (\"C\", True)]):\n", + " return 100.0\n", + " else:\n", + " raise ValueError(\"product error\")\n", + "\n", + " def phi_CD(scope_product):\n", + " \"\"\"\"\"\"\n", + " ss = frozenset(scope_product)\n", + " if ss == frozenset([(\"C\", False), (\"D\", False)]):\n", + " return 1.0\n", + " elif ss == frozenset([(\"C\", False), (\"D\", True)]):\n", + " return 100.0\n", + " elif ss == frozenset([(\"C\", True), (\"D\", False)]):\n", + " return 100.0\n", + " elif ss == frozenset([(\"C\", True), (\"D\", True)]):\n", + " return 1.0\n", + " else:\n", + " raise ValueError(\"product error\")\n", + "\n", + " def phi_DA(scope_product):\n", + " \"\"\"\"\"\"\n", + " ss = frozenset(scope_product)\n", + " if ss == frozenset([(\"D\", False), (\"A\", False)]):\n", + " return 100.0\n", + " elif ss == frozenset([(\"D\", False), (\"A\", True)]):\n", + " return 1.0\n", + " elif ss == frozenset([(\"D\", True), (\"A\", False)]):\n", + " return 1.0\n", + " elif ss == frozenset([(\"D\", True), (\"A\", True)]):\n", + " return 100.0\n", + " else:\n", + " raise ValueError(\"product error\")\n", + "\n", + " # instantiate factors with factor functions and implied\n", + " # random variables in scope\n", + "\n", + " AB_f = Factor(gid=\"ab_f\", scope_vars=set([A, B]), factor_fn=phi_AB)\n", + " BC_f = Factor(gid=\"bc_f\", scope_vars=set([B, C]), factor_fn=phi_BC)\n", + " CD_f = Factor(gid=\"cd_f\", scope_vars=set([C, D]), factor_fn=phi_CD)\n", + " DA_f = Factor(gid=\"da_f\", scope_vars=set([D, A]), factor_fn=phi_DA)\n", + "\n", + " # instantiate markov network and make a query\n", + " mnetwork = MarkovNetwork(\n", + " gid=\"mnet\",\n", + " nodes=set([A, B, C, D]),\n", + " edges=set([AB, AD, BC, DC]),\n", + " factors=set([DA_f, CD_f, BC_f, AB_f]),\n", + " )\n", + "\n", + " return mnetwork, A, B, C, D" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.69" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mnetwork, A, B, C, D = make_markov_net()\n", + "\n", + "queries = set([A, B])\n", + "evidences = set()\n", + "prob, a = mnetwork.cond_prod_by_variable_elimination(queries, evidences)\n", + "q2 = set([(\"A\", False), (\"B\", True)])\n", + "round(prob.phi_normal(q2), 2)\n", + "# 0.69" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.74 0.74\n" + ] + } + ], + "source": [ + "mnetwork, A, B, C, D = make_markov_net()\n", + "\n", + "\n", + "# p(B=1)=0.26\n", + "queries = set([B])\n", + "evidences = set()\n", + "prob, a = mnetwork.cond_prod_by_variable_elimination(queries, evidences)\n", + "q2 = set([(\"B\", True)])\n", + "val = round(prob.phi_normal(q2), 2)\n", + "expected = 0.74 # 0.26 # Typo in book?\n", + "print(val, expected)\n", + "assert np.allclose(val, expected)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.06 0.06\n" + ] + } + ], + "source": [ + "mnetwork, A, B, C, D = make_markov_net()\n", + "\n", + "# p(B=1|C=0)=0.06\n", + "queries = set([B])\n", + "evidences = set([(\"C\", False)])\n", + "prob, a = mnetwork.cond_prod_by_variable_elimination(queries, evidences)\n", + "q2 = set([(\"B\", True)])\n", + "val = round(prob.phi_normal(q2), 2)\n", + "expected = 0.06\n", + "print(val, expected)\n", + "assert np.allclose(val, expected)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Conditional random field" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "frozenset({('X_2', False), ('X_1', True), ('X_3', False)})\n", + "1.0 1.0\n" + ] + } + ], + "source": [ + "from pygmodels.pgmtype.randomvariable import NumCatRVariable\n", + "from pygmodels.pgmodel.markov import ConditionalRandomField\n", + "from pygmodels.gtype.edge import Edge, EdgeType\n", + "from pygmodels.pgmtype.factor import Factor\n", + "import math\n", + "from random import choice\n", + "\n", + "# Example from Koller ad Friedman p144\n", + "# phi(xi, y) = exp(wi * ind(xii=1, y=1))\n", + "# phi(y) = exp(w0 * ind(y=1))\n", + "\n", + "\n", + "# define data and nodes\n", + "idata = {\"A\": {\"outcome-values\": [True, False]}}\n", + "\n", + "# from Koller, Friedman 2009, p. 144-145, example 4.20\n", + "X_1 = NumCatRVariable(node_id=\"X_1\", input_data=idata[\"A\"], marginal_distribution=lambda x: 0.5)\n", + "X_2 = NumCatRVariable(node_id=\"X_2\", input_data=idata[\"A\"], marginal_distribution=lambda x: 0.5)\n", + "X_3 = NumCatRVariable(node_id=\"X_3\", input_data=idata[\"A\"], marginal_distribution=lambda x: 0.5)\n", + "Y_1 = NumCatRVariable(node_id=\"Y_1\", input_data=idata[\"A\"], marginal_distribution=lambda x: 0.5)\n", + "\n", + "# define edges\n", + "\n", + "X1_Y1 = Edge(\n", + " edge_id=\"X1_Y1\",\n", + " edge_type=EdgeType.UNDIRECTED,\n", + " start_node=X_1,\n", + " end_node=Y_1,\n", + ")\n", + "X2_Y1 = Edge(\n", + " edge_id=\"X2_Y1\",\n", + " edge_type=EdgeType.UNDIRECTED,\n", + " start_node=X_2,\n", + " end_node=Y_1,\n", + ")\n", + "X3_Y1 = Edge(\n", + " edge_id=\"X3_Y1\",\n", + " edge_type=EdgeType.UNDIRECTED,\n", + " start_node=X_3,\n", + " end_node=Y_1,\n", + ")\n", + "\n", + "# define factor functions\n", + "\n", + "\n", + "def phi_X1_Y1(scope_product):\n", + " \"\"\"\"\"\"\n", + " w = 0.5\n", + " ss = frozenset(scope_product)\n", + " if ss == frozenset([(\"X_1\", True), (\"Y_1\", True)]):\n", + " return math.exp(1.0 * w)\n", + " else:\n", + " return math.exp(0.0)\n", + "\n", + "\n", + "def phi_X2_Y1(scope_product):\n", + " \"\"\"\"\"\"\n", + " w = 5.0\n", + " ss = frozenset(scope_product)\n", + " if ss == frozenset([(\"X_2\", True), (\"Y_1\", True)]):\n", + " return math.exp(1.0 * w)\n", + " else:\n", + " return math.exp(0.0)\n", + "\n", + "\n", + "def phi_X3_Y1(scope_product):\n", + " \"\"\"\"\"\"\n", + " w = 9.4\n", + " ss = frozenset(scope_product)\n", + " if ss == frozenset([(\"X_3\", True), (\"Y_1\", True)]):\n", + " return math.exp(1.0 * w)\n", + " else:\n", + " return math.exp(0.0)\n", + "\n", + "\n", + "def phi_Y1(scope_product):\n", + " \"\"\"\"\"\"\n", + " w = 0.6\n", + " ss = frozenset(scope_product)\n", + " if ss == frozenset([(\"Y_1\", True)]):\n", + " return math.exp(1.0 * w)\n", + " else:\n", + " return math.exp(0.0)\n", + "\n", + "\n", + "# instantiate factors with factor functions and implied random variables\n", + "X1_Y1_f = Factor(gid=\"x1_y1_f\", scope_vars=set([X_1, Y_1]), factor_fn=phi_X1_Y1)\n", + "X2_Y1_f = Factor(gid=\"x2_y1_f\", scope_vars=set([X_2, Y_1]), factor_fn=phi_X2_Y1)\n", + "X3_Y1_f = Factor(gid=\"x3_y1_f\", scope_vars=set([X_3, Y_1]), factor_fn=phi_X3_Y1)\n", + "Y1_f = Factor(gid=\"y1_f\", scope_vars=set([Y_1]), factor_fn=phi_Y1)\n", + "\n", + "\n", + "# Instantiate conditional random field and make a query\n", + "crf_koller = ConditionalRandomField(\n", + " \"crf\",\n", + " observed_vars=set([X_1, X_2, X_3]),\n", + " target_vars=set([Y_1]),\n", + " edges=set([X1_Y1, X2_Y1, X3_Y1]),\n", + " factors=set([X1_Y1_f, X2_Y1_f, X3_Y1_f, Y1_f]),\n", + ")\n", + "\n", + "\n", + "evidence = set([(\"Y_1\", False)])\n", + "query_vars = set([X_1, X_2, X_3])\n", + "query = frozenset(\n", + " [\n", + " (\"X_1\", choice([False, True])),\n", + " (\"X_2\", choice([False, True])),\n", + " (\"X_3\", choice([False, True])),\n", + " ]\n", + ")\n", + "print(query)\n", + "out, a1 = crf_koller.cond_prod_by_variable_elimination(queries=query_vars, evidences=evidence)\n", + "val = out.phi(query)\n", + "expected = 1.0\n", + "print(val, expected)\n", + "assert np.allclose(val, expected)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Chain graph" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.6 0.6000049070121203\n" + ] + } + ], + "source": [ + "# Chain graph\n", + "# Based on Cowell et al, 1999\n", + "# \"Probabilistic Networks and Expert Systems\"\n", + "# https://link.springer.com/book/10.1007/b97670\n", + "# p110\n", + "# For a diagram, see https://gist.github.com/murphyk/e157531793f87f3d1eb9542234e6bc7d\n", + "\n", + "# Parameters are same as directed Asia network on p21\n", + "# except for the chain component p(C,D | B,E) = phi(C,D,B) phi(C,B,E) 1/Z(B,E)\n", + "# where Z(B,E) is derived by normalizing the numerator and C=cough.\n", + "# In the code below, these variables are renamed as p(H,I|B,D) = phi(H,I,B) phi(H,B,D) phi(B,D)\n", + "# where H=Cough, I=Dysponea, B=Bronchitis, D=Either.\n", + "\n", + "# In the book, the evidence is (a, b, xbar) which in the notation of this code is\n", + "# (E=True, A=True, G=False). The query variable is B.\n", + "\n", + "from pygmodels.pgmodel.lwfchain import LWFChainGraph\n", + "from pygmodels.gtype.edge import Edge, EdgeType\n", + "from pygmodels.pgmtype.factor import Factor\n", + "from pygmodels.pgmtype.randomvariable import NumCatRVariable\n", + "\n", + "\n", + "# define data and nodes\n", + "idata = {\"outcome-values\": [True, False]}\n", + "A = NumCatRVariable(node_id=\"A\", input_data=idata, marginal_distribution=lambda x: 0.5)\n", + "B = NumCatRVariable(node_id=\"B\", input_data=idata, marginal_distribution=lambda x: 0.5)\n", + "C = NumCatRVariable(node_id=\"C\", input_data=idata, marginal_distribution=lambda x: 0.5)\n", + "D = NumCatRVariable(node_id=\"D\", input_data=idata, marginal_distribution=lambda x: 0.5)\n", + "E = NumCatRVariable(node_id=\"E\", input_data=idata, marginal_distribution=lambda x: 0.5)\n", + "F = NumCatRVariable(node_id=\"F\", input_data=idata, marginal_distribution=lambda x: 0.5)\n", + "G = NumCatRVariable(node_id=\"G\", input_data=idata, marginal_distribution=lambda x: 0.5)\n", + "H = NumCatRVariable(node_id=\"H\", input_data=idata, marginal_distribution=lambda x: 0.5)\n", + "I = NumCatRVariable(node_id=\"I\", input_data=idata, marginal_distribution=lambda x: 0.5)\n", + "K = NumCatRVariable(node_id=\"K\", input_data=idata, marginal_distribution=lambda x: 0.5)\n", + "L = NumCatRVariable(node_id=\"L\", input_data=idata, marginal_distribution=lambda x: 0.5)\n", + "\n", + "\n", + "AB_c = Edge(\n", + " edge_id=\"AB\",\n", + " start_node=A,\n", + " end_node=B,\n", + " edge_type=EdgeType.DIRECTED,\n", + ")\n", + "AC_c = Edge(\n", + " edge_id=\"AC\",\n", + " start_node=A,\n", + " end_node=C,\n", + " edge_type=EdgeType.DIRECTED,\n", + ")\n", + "CD_c = Edge(\n", + " edge_id=\"CD\",\n", + " start_node=C,\n", + " end_node=D,\n", + " edge_type=EdgeType.DIRECTED,\n", + ")\n", + "EF_c = Edge(\n", + " edge_id=\"EF\",\n", + " start_node=E,\n", + " end_node=F,\n", + " edge_type=EdgeType.DIRECTED,\n", + ")\n", + "FD_c = Edge(\n", + " edge_id=\"FD\",\n", + " start_node=F,\n", + " end_node=D,\n", + " edge_type=EdgeType.DIRECTED,\n", + ")\n", + "DG_c = Edge(\n", + " edge_id=\"DG\",\n", + " start_node=D,\n", + " end_node=G,\n", + " edge_type=EdgeType.DIRECTED,\n", + ")\n", + "DH_c = Edge(\n", + " edge_id=\"DH\",\n", + " start_node=D,\n", + " end_node=H,\n", + " edge_type=EdgeType.DIRECTED,\n", + ")\n", + "BH_c = Edge(\n", + " edge_id=\"BH\",\n", + " start_node=B,\n", + " end_node=H,\n", + " edge_type=EdgeType.DIRECTED,\n", + ")\n", + "BI_c = Edge(\n", + " edge_id=\"BI\",\n", + " start_node=B,\n", + " end_node=I,\n", + " edge_type=EdgeType.DIRECTED,\n", + ")\n", + "HI_c = Edge(\n", + " edge_id=\"HI\",\n", + " start_node=H,\n", + " end_node=I,\n", + " edge_type=EdgeType.UNDIRECTED,\n", + ")\n", + "\n", + "# define factor functions\n", + "\n", + "\n", + "def phi_e(scope_product):\n", + " \"Visit to Asia factor p(a)\"\n", + " ss = set(scope_product)\n", + " if ss == set([(\"E\", True)]):\n", + " return 0.01\n", + " elif ss == set([(\"E\", False)]):\n", + " return 0.99\n", + " else:\n", + " raise ValueError(\"Unknown scope product\")\n", + "\n", + "\n", + "def phi_fe(scope_product):\n", + " \"Tuberculosis | Visit to Asia factor p(t,a)\"\n", + " ss = set(scope_product)\n", + " if ss == set([(\"F\", True), (\"E\", True)]):\n", + " return 0.05\n", + " elif ss == set([(\"F\", False), (\"E\", True)]):\n", + " return 0.95\n", + " elif ss == set([(\"F\", True), (\"E\", False)]):\n", + " return 0.01\n", + " elif ss == set([(\"F\", False), (\"E\", False)]):\n", + " return 0.99\n", + " else:\n", + " raise ValueError(\"Unknown scope product\")\n", + "\n", + "\n", + "def phi_dg(scope_product):\n", + " \"either tuberculosis or lung cancer | x ray p(e,x)\"\n", + " ss = set(scope_product)\n", + " if ss == set([(\"D\", True), (\"G\", True)]):\n", + " return 0.98\n", + " elif ss == set([(\"D\", False), (\"G\", True)]):\n", + " return 0.05\n", + " elif ss == set([(\"D\", True), (\"G\", False)]):\n", + " return 0.02\n", + " elif ss == set([(\"D\", False), (\"G\", False)]):\n", + " return 0.95\n", + " else:\n", + " raise ValueError(\"Unknown scope product\")\n", + "\n", + "\n", + "def phi_a(scope_product):\n", + " \"smoke factor p(s)\"\n", + " ss = set(scope_product)\n", + " if ss == set([(\"A\", True)]):\n", + " return 0.5\n", + " elif ss == set([(\"A\", False)]):\n", + " return 0.5\n", + " else:\n", + " raise ValueError(\"Unknown scope product\")\n", + "\n", + "\n", + "def phi_ab(scope_product):\n", + " \"smoke given bronchitis p(s,b)\"\n", + " ss = set(scope_product)\n", + " if ss == set([(\"A\", True), (\"B\", True)]):\n", + " return 0.6\n", + " elif ss == set([(\"A\", False), (\"B\", True)]):\n", + " return 0.3\n", + " elif ss == set([(\"A\", True), (\"B\", False)]):\n", + " return 0.4\n", + " elif ss == set([(\"A\", False), (\"B\", False)]):\n", + " return 0.7\n", + " else:\n", + " raise ValueError(\"Unknown scope product\")\n", + "\n", + "\n", + "def phi_ac(scope_product):\n", + " \"lung cancer given smoke p(s,l)\"\n", + " ss = set(scope_product)\n", + " if ss == set([(\"A\", True), (\"C\", True)]):\n", + " return 0.1\n", + " elif ss == set([(\"A\", False), (\"C\", True)]):\n", + " return 0.01\n", + " elif ss == set([(\"A\", True), (\"C\", False)]):\n", + " return 0.9\n", + " elif ss == set([(\"A\", False), (\"C\", False)]):\n", + " return 0.99\n", + " else:\n", + " raise ValueError(\"Unknown scope product\")\n", + "\n", + "\n", + "def phi_cdf(scope_product):\n", + " \"either tuberculosis or lung given lung cancer and tuberculosis p(e, l, t)\"\n", + " ss = set(scope_product)\n", + " if ss == set([(\"C\", True), (\"D\", True), (\"F\", True)]):\n", + " return 1\n", + " elif ss == set([(\"C\", True), (\"D\", False), (\"F\", True)]):\n", + " return 0\n", + " elif ss == set([(\"C\", False), (\"D\", True), (\"F\", True)]):\n", + " return 1\n", + " elif ss == set([(\"C\", False), (\"D\", False), (\"F\", True)]):\n", + " return 0\n", + " elif ss == set([(\"C\", True), (\"D\", True), (\"F\", False)]):\n", + " return 1\n", + " elif ss == set([(\"C\", True), (\"D\", False), (\"F\", False)]):\n", + " return 0\n", + " elif ss == set([(\"C\", False), (\"D\", True), (\"F\", False)]):\n", + " return 0\n", + " elif ss == set([(\"C\", False), (\"D\", False), (\"F\", False)]):\n", + " return 1\n", + " else:\n", + " raise ValueError(\"Unknown scope product\")\n", + "\n", + "\n", + "def phi_ihb(scope_product):\n", + " \"cough, dyspnoea, bronchitis I, H, B p(c,d,b)\"\n", + " # In book this is phi(C,D,B) in table 6.4\n", + " ss = set(scope_product)\n", + " if ss == set([(\"H\", True), (\"I\", True), (\"B\", True)]):\n", + " return 16\n", + " elif ss == set([(\"H\", True), (\"I\", False), (\"B\", True)]):\n", + " return 1\n", + " elif ss == set([(\"H\", False), (\"I\", True), (\"B\", True)]):\n", + " return 4\n", + " elif ss == set([(\"H\", False), (\"I\", False), (\"B\", True)]):\n", + " return 1\n", + " elif ss == set([(\"H\", True), (\"I\", True), (\"B\", False)]):\n", + " return 2\n", + " elif ss == set([(\"H\", True), (\"I\", False), (\"B\", False)]):\n", + " return 1\n", + " elif ss == set([(\"H\", False), (\"I\", True), (\"B\", False)]):\n", + " return 1\n", + " elif ss == set([(\"H\", False), (\"I\", False), (\"B\", False)]):\n", + " return 1\n", + " else:\n", + " raise ValueError(\"Unknown scope product\")\n", + "\n", + "\n", + "def phi_hbd(scope_product):\n", + " \"cough, either tuberculosis or lung cancer, bronchitis D, H, B p(c,b,e)\"\n", + " # In book this is phi(C,B,E) in table 6.4\n", + " ss = set(scope_product)\n", + " if ss == set([(\"H\", True), (\"D\", True), (\"B\", True)]):\n", + " return 5\n", + " elif ss == set([(\"H\", True), (\"D\", False), (\"B\", True)]):\n", + " return 2\n", + " elif ss == set([(\"H\", False), (\"D\", True), (\"B\", True)]):\n", + " return 1\n", + " elif ss == set([(\"H\", False), (\"D\", False), (\"B\", True)]):\n", + " return 1\n", + " elif ss == set([(\"H\", True), (\"D\", True), (\"B\", False)]):\n", + " return 3\n", + " elif ss == set([(\"H\", True), (\"D\", False), (\"B\", False)]):\n", + " return 1\n", + " elif ss == set([(\"H\", False), (\"D\", True), (\"B\", False)]):\n", + " return 1\n", + " elif ss == set([(\"H\", False), (\"D\", False), (\"B\", False)]):\n", + " return 1\n", + " else:\n", + " raise ValueError(\"Unknown scope product\")\n", + "\n", + "\n", + "def phi_bd(scope_product):\n", + " \"bronchitis, either tuberculosis or lung cancer B, D p(b,e)\"\n", + " # In book this is phi(B,E) in table 6.4\n", + " ss = set(scope_product)\n", + " if ss == set([(\"B\", True), (\"D\", True)]):\n", + " return 1 / 90\n", + " elif ss == set([(\"B\", False), (\"D\", True)]):\n", + " return 1 / 11\n", + " elif ss == set([(\"B\", True), (\"D\", False)]):\n", + " return 1 / 39\n", + " elif ss == set([(\"B\", False), (\"D\", False)]):\n", + " return 1 / 5\n", + " else:\n", + " raise ValueError(\"Unknown scope product\")\n", + "\n", + "\n", + "# instantiate factors with factor functions and implied random\n", + "# variables in scope\n", + "\n", + "E_cf = Factor(gid=\"E_cf\", scope_vars=set([E]), factor_fn=phi_e)\n", + "EF_cf = Factor(gid=\"EF_cf\", scope_vars=set([E, F]), factor_fn=phi_fe)\n", + "DG_cf = Factor(gid=\"DG_cf\", scope_vars=set([D, G]), factor_fn=phi_dg)\n", + "A_cf = Factor(gid=\"A_cf\", scope_vars=set([A]), factor_fn=phi_a)\n", + "AB_cf = Factor(gid=\"AB_cf\", scope_vars=set([A, B]), factor_fn=phi_ab)\n", + "AC_cf = Factor(gid=\"AC_cf\", scope_vars=set([A, C]), factor_fn=phi_ac)\n", + "CDF_cf = Factor(gid=\"CDF_cf\", scope_vars=set([D, C, F]), factor_fn=phi_cdf)\n", + "\n", + "IHB_cf = Factor(gid=\"IHB_cf\", scope_vars=set([H, I, B]), factor_fn=phi_ihb)\n", + "\n", + "HBD_cf = Factor(gid=\"HBD_cf\", scope_vars=set([H, D, B]), factor_fn=phi_hbd)\n", + "BD_cf = Factor(gid=\"BD_cf\", scope_vars=set([D, B]), factor_fn=phi_bd)\n", + "\n", + "\n", + "# instantiate lwf chain graph and make a query\n", + "cowell = LWFChainGraph(\n", + " gid=\"cowell\",\n", + " nodes=set([A, B, C, D, E, F, G, H, I]),\n", + " edges=set([AB_c, AC_c, CD_c, EF_c, FD_c, DG_c, DH_c, BH_c, BI_c, HI_c]),\n", + " factors=set([E_cf, EF_cf, DG_cf, A_cf, AB_cf, AC_cf, CDF_cf, IHB_cf, HBD_cf, BD_cf]),\n", + ")\n", + "evidences = set([(\"E\", True), (\"A\", True), (\"G\", False)])\n", + "\n", + "final_factor, a = cowell.cond_prod_by_variable_elimination(set([B]), evidences)\n", + "\n", + "val = round(final_factor.phi_normal(set([(\"B\", True)])), 4) # 0.60\n", + "\n", + "pots = np.array([2.4455, 1.6303]) # from table 6.10\n", + "expected = pots[0] / np.sum(pots)\n", + "\n", + "assert np.allclose(val, expected)\n", + "print(val, expected)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Factor: d69e2a81-63a9-4cd7-a573-c9eba4ef3b5f\n", + "Scope variables: {}Factor function: .psi at 0x7fa168a00790>\n", + "\n" + ] + } + ], + "source": [ + "print(a)\n", + "print(type(product_factor))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.6000049070121203\n" + ] + } + ], + "source": [ + "pots = np.array([2.4455, 1.6303])\n", + "print(pots[0] / np.sum(pots))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +}