forked from pgmpy/pgmpy
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
example notebook for extending pgmpy
- Loading branch information
1 parent
d705c4e
commit 16feab9
Showing
2 changed files
with
169 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"source": [ | ||
"It's really easy to extend pgmpy to quickly prototype your ideas. pgmpy has a base abstract class for most of main functionalities like: `BaseInference` for inference, `BaseFactor` for model parameters, `BaseEstimators` for parameter and model learning. For adding a new feature to pgmpy we just need to implement a new class inheriting one of these base classes and then we can use the new class with other functionality of pgmpy.\n", | ||
"\n", | ||
"In this example we will see how to write a new inference algorithm. We will take the example of a very simple algorithm in which we will multiply all the factors/CPD of the network and marginalize over variable to get the desired query." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# A simple Exact inference algorithm\n", | ||
"\n", | ||
"import itertools\n", | ||
"\n", | ||
"from pgmpy.inference.base import Inference\n", | ||
"from pgmpy.factors.discrete import factor_product\n", | ||
"\n", | ||
"class SimpleInference(Inference):\n", | ||
" # By inheriting Inference we can use self.model, self.factors and self.cardinality in our class\n", | ||
" def query(self, var, evidence):\n", | ||
" # self.factors is a dict of the form of {node: [factors_involving_node]}\n", | ||
" factors_list = set(itertools.chain(*self.factors.values()))\n", | ||
" product = factor_product(*factors_list)\n", | ||
" reduced_prod = product.reduce(evidence, inplace=False)\n", | ||
" reduced_prod.normalize()\n", | ||
" var_to_marg = set(self.model.nodes()) - set(var) - set([state[0] for state in evidence])\n", | ||
" marg_prod = reduced_prod.marginalize(var_to_marg, inplace=False)\n", | ||
" return marg_prod" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Defining a model\n", | ||
"\n", | ||
"from pgmpy.models import BayesianModel\n", | ||
"from pgmpy.factors.discrete import TabularCPD\n", | ||
"\n", | ||
"model = BayesianModel([('A', 'J'), ('R', 'J'), ('J', 'Q'), ('J', 'L'), ('G', 'L')])\n", | ||
"cpd_a = TabularCPD('A', 2, values=[[0.2], [0.8]])\n", | ||
"cpd_r = TabularCPD('R', 2, values=[[0.4], [0.6]])\n", | ||
"cpd_j = TabularCPD('J', 2, values=[[0.9, 0.6, 0.7, 0.1],\n", | ||
" [0.1, 0.4, 0.3, 0.9]],\n", | ||
" evidence=['A', 'R'], evidence_card=[2, 2])\n", | ||
"cpd_q = TabularCPD('Q', 2, values=[[0.9, 0.2], [0.1, 0.8]],\n", | ||
" evidence=['J'], evidence_card=[2])\n", | ||
"cpd_l = TabularCPD('L', 2, values=[[0.9, 0.45, 0.8, 0.1],\n", | ||
" [0.1, 0.55, 0.2, 0.9]],\n", | ||
" evidence=['J', 'G'], evidence_card=[2, 2])\n", | ||
"cpd_g = TabularCPD('G', 2, values=[[0.6], [0.4]])\n", | ||
"\n", | ||
"model.add_cpds(cpd_a, cpd_r, cpd_j, cpd_q, cpd_l, cpd_g)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Doing inference with our SimpleInference\n", | ||
"\n", | ||
"infer = SimpleInference(model)\n", | ||
"a = infer.query(var=['A'], evidence=[('J', 0), ('R', 1)])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"+-----+----------+\n", | ||
"| A | phi(A) |\n", | ||
"|-----+----------|\n", | ||
"| A_0 | 0.6000 |\n", | ||
"| A_1 | 0.4000 |\n", | ||
"+-----+----------+\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"print(a)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"+-----+----------+\n", | ||
"| A | phi(A) |\n", | ||
"|-----+----------|\n", | ||
"| A_0 | 0.6000 |\n", | ||
"| A_1 | 0.4000 |\n", | ||
"+-----+----------+\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Comparing the results with Variable Elimination Algorithm\n", | ||
"\n", | ||
"from pgmpy.inference import VariableElimination\n", | ||
"\n", | ||
"infer = VariableElimination(model)\n", | ||
"result = infer.query(['A'], evidence={'J': 0, 'R': 1})\n", | ||
"print(result['A'])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Similarly we can also create new classes for Factor or CPDs and add them to networks and do inference over it or can write a new estimator class." | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"anaconda-cloud": {}, | ||
"kernelspec": { | ||
"display_name": "Python [conda env:ccns]", | ||
"language": "python", | ||
"name": "conda-env-ccns-py" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.12" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 1 | ||
} |
This file was deleted.
Oops, something went wrong.