-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoptimized_query.py
81 lines (69 loc) · 3.9 KB
/
optimized_query.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""
Author : Zach Seiss
Email : [email protected]
Written : June 4, 2022
Last Update : June 18, 2022
"""
import sys
import pandas as pd
import copy
from tqdm import tqdm
from pgmpy.inference.ExactInference import VariableElimination
from pgmpy.inference.ExactInference import BeliefPropagation
def environment_map(data_frame, universe):
    """ FUNCTION environment_map
    __________________________________________________________________________________________
    Build a nested dictionary keyed by environment variable. Each value is the dictionary
    produced by state_mapping for the distinct values found in that variable's column of
    data_frame, i.e. a mapping from each observed state to an integer index.

        data_frame      -   A DataFrame object which contains the environment variables
                            that we are interested in.

        universe        -   An iterable containing the names of the environment variables
                            to map. Every name is used as a column label of data_frame,
                            so it should be a subset of data_frame.columns.
    __________________________________________________________________________________________"""
    mapping = {}
    for variable in universe:
        # Only the states actually present in the data are enumerated.
        observed_states = data_frame[variable].unique()
        mapping[variable] = state_mapping(observed_states)
    return mapping
def state_mapping(state_space):
    """ FUNCTION state_mapping
    __________________________________________________________________________________________
    Given the states of an environment variable this function returns a dictionary mapping
    every distinct state to its position in the sorted order of those states. The caller
    uses that position as the state's index in the conditional probability table.

        state_space     -   The iterable containing all possible states of an environment
                            variable. States must be hashable and mutually comparable.
                            Repeated states are collapsed so the indexes stay dense
                            (0 .. number_of_distinct_states - 1).

    Returns a dict {state: index}.
    _________________________________________________________________________________________"""
    # De-duplicate before sorting: a repeated state would otherwise consume an
    # index and push later states past the end of the table.
    distinct_states = sorted(set(state_space))
    return {state: index for index, state in enumerate(distinct_states)}
def fast_query(bns: list, test_grp_indexes, environment_variables: list, data_frame: pd.DataFrame, target: str):
    """ FUNCTION fast_query
    __________________________________________________________________________________________
    For every network in 'bns', query the target once per distinct combination of
    environment-variable states found in that network's test rows, and collect the results
    in one lookup table per network.

        bns                     -   List of models; each is wrapped in VariableElimination.

        test_grp_indexes        -   Indexable by group number; test_grp_indexes[i] is passed
                                    to data_frame.iloc to select the test rows for bns[i].

        environment_variables   -   List of column names used as evidence. All but the last
                                    are used as group-by keys, so at least two names are
                                    expected.

        data_frame              -   DataFrame holding the environment variable columns.

        target                  -   Name of the variable being queried.

    Returns a 4-tuple:
        quick_lookup_tables     -   One DataFrame per network, produced by merging a
                                    row-number frame indexed by the state combinations
                                    with the table of query results.
        num_queries             -   Total number of rows over all lookup tables.
        inference type          -   The class of the inference objects used.
        error_count             -   Number of queries that raised IndexError or ValueError.
    __________________________________________________________________________________________"""
    inferences = [VariableElimination(bn) for bn in bns]
    env_map = environment_map(data_frame, environment_variables)
    quick_lookup_tables = []
    error_count = 0
    num_groups = len(inferences)

    for i, group_inference in enumerate(inferences):
        df = data_frame.iloc[test_grp_indexes[i]]
        grouped = df.groupby(environment_variables[:-1])[environment_variables[-1]]
        # Each entry of the index is one distinct combination of states.
        multi_index = grouped.value_counts().index
        # Single column (label 0) that initially holds the state tuples and is
        # overwritten in place with the query result for each row.
        query_evidence_table = pd.DataFrame(multi_index)

        for j in tqdm(range(len(query_evidence_table)),
                      desc=f'Testing group {i+1} of {num_groups} : ',
                      colour='GREEN'):
            states = query_evidence_table.iat[j, 0]
            # States equal to 'N' are left out of the evidence.
            query_evidence = {v: env_map[v][s]
                              for v, s in zip(environment_variables, states)
                              if s != 'N'}
            try:
                # Query a fresh copy so the shared inference object is never
                # touched by a query (presumably query() mutates internal
                # state -- TODO confirm against the pgmpy version in use).
                inference = copy.deepcopy(group_inference)
                result = inference.query([target], query_evidence, show_progress=False)
                probability = result.values[1]
            except IndexError as e:
                # For the time being if this happens we will predict 'satisfied.'
                probability = 1.0
                error_count += 1
                print(e)
            except ValueError as e:
                # NOTE(review): the row keeps its state tuple instead of a
                # probability in this case, exactly as before -- confirm that
                # downstream consumers expect that.
                print(f'query_evidence : {query_evidence}')
                error_count += 1
                print(e)
                continue
            # Direct positional assignment: chained `.loc[j][0] = ...` writes
            # to a temporary and is a no-op under pandas Copy-on-Write.
            query_evidence_table.iat[j, 0] = probability

        mymap = pd.DataFrame(range(len(multi_index)), index=multi_index)
        quick_lookup = pd.merge(mymap, query_evidence_table, left_on=mymap.columns[0], right_index=True)
        quick_lookup_tables.append(quick_lookup)

    num_queries = sum(len(table) for table in quick_lookup_tables)
    # With no networks there is no instance to inspect; report the class that
    # would have been used instead of raising IndexError.
    inference_type = type(inferences[0]) if inferences else VariableElimination
    return quick_lookup_tables, num_queries, inference_type, error_count