Skip to content

Commit

Permalink
example notebook for extending pgmpy
Browse files Browse the repository at this point in the history
  • Loading branch information
ankurankan committed Oct 15, 2016
1 parent d705c4e commit 16feab9
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 35 deletions.
169 changes: 169 additions & 0 deletions examples/Extending pgmpy.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"It's really easy to extend pgmpy to quickly prototype your ideas. pgmpy has a base abstract class for most of main functionalities like: `BaseInference` for inference, `BaseFactor` for model parameters, `BaseEstimators` for parameter and model learning. For adding a new feature to pgmpy we just need to implement a new class inheriting one of these base classes and then we can use the new class with other functionality of pgmpy.\n",
"\n",
"In this example we will see how to write a new inference algorithm. We will take the example of a very simple algorithm in which we will multiply all the factors/CPD of the network and marginalize over variable to get the desired query."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# A simple Exact inference algorithm\n",
"\n",
"import itertools\n",
"\n",
"from pgmpy.inference.base import Inference\n",
"from pgmpy.factors.discrete import factor_product\n",
"\n",
"class SimpleInference(Inference):\n",
" # By inheriting Inference we can use self.model, self.factors and self.cardinality in our class\n",
" def query(self, var, evidence):\n",
" # self.factors is a dict of the form of {node: [factors_involving_node]}\n",
" factors_list = set(itertools.chain(*self.factors.values()))\n",
" product = factor_product(*factors_list)\n",
" reduced_prod = product.reduce(evidence, inplace=False)\n",
" reduced_prod.normalize()\n",
" var_to_marg = set(self.model.nodes()) - set(var) - set([state[0] for state in evidence])\n",
" marg_prod = reduced_prod.marginalize(var_to_marg, inplace=False)\n",
" return marg_prod"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Defining a model\n",
"\n",
"from pgmpy.models import BayesianModel\n",
"from pgmpy.factors.discrete import TabularCPD\n",
"\n",
"model = BayesianModel([('A', 'J'), ('R', 'J'), ('J', 'Q'), ('J', 'L'), ('G', 'L')])\n",
"cpd_a = TabularCPD('A', 2, values=[[0.2], [0.8]])\n",
"cpd_r = TabularCPD('R', 2, values=[[0.4], [0.6]])\n",
"cpd_j = TabularCPD('J', 2, values=[[0.9, 0.6, 0.7, 0.1],\n",
" [0.1, 0.4, 0.3, 0.9]],\n",
" evidence=['A', 'R'], evidence_card=[2, 2])\n",
"cpd_q = TabularCPD('Q', 2, values=[[0.9, 0.2], [0.1, 0.8]],\n",
" evidence=['J'], evidence_card=[2])\n",
"cpd_l = TabularCPD('L', 2, values=[[0.9, 0.45, 0.8, 0.1],\n",
" [0.1, 0.55, 0.2, 0.9]],\n",
" evidence=['J', 'G'], evidence_card=[2, 2])\n",
"cpd_g = TabularCPD('G', 2, values=[[0.6], [0.4]])\n",
"\n",
"model.add_cpds(cpd_a, cpd_r, cpd_j, cpd_q, cpd_l, cpd_g)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Doing inference with our SimpleInference\n",
"\n",
"infer = SimpleInference(model)\n",
"a = infer.query(var=['A'], evidence=[('J', 0), ('R', 1)])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----+----------+\n",
"| A | phi(A) |\n",
"|-----+----------|\n",
"| A_0 | 0.6000 |\n",
"| A_1 | 0.4000 |\n",
"+-----+----------+\n"
]
}
],
"source": [
"print(a)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----+----------+\n",
"| A | phi(A) |\n",
"|-----+----------|\n",
"| A_0 | 0.6000 |\n",
"| A_1 | 0.4000 |\n",
"+-----+----------+\n"
]
}
],
"source": [
"# Comparing the results with Variable Elimination Algorithm\n",
"\n",
"from pgmpy.inference import VariableElimination\n",
"\n",
"infer = VariableElimination(model)\n",
"result = infer.query(['A'], evidence={'J': 0, 'R': 1})\n",
"print(result['A'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Similarly we can also create new classes for Factor or CPDs and add them to networks and do inference over it or can write a new estimator class."
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [conda env:ccns]",
"language": "python",
"name": "conda-env-ccns-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
35 changes: 0 additions & 35 deletions examples/Writing an inference algorithm.ipynb

This file was deleted.

0 comments on commit 16feab9

Please sign in to comment.