import pythonNetica as pyn
import CV_tools as CVT
import numpy as np
import pickle
import gzip
import sys
import shutil
'''
CV_driver.py
a cross-validation driver for Netica
a m!ke@usgs joint
'''
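# Workflow, as implemented below: read an XML configuration file, learn the
# base Netica network from the base cas file, run and post-process base-case
# predictions (plus optional sensitivity analysis), then, if requested, repeat
# the learn/predict/test cycle for each of k cross-validation folds before
# pickling the results object to a gzipped file.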
############
# CONFIGURATION FILE NAME
try:
    parfile = sys.argv[1]
except:
    parfile = 'example.xml'
############
# initialize
cdat = pyn.pynetica()
# read in the problem parameters
cdat.probpars = CVT.input_parameters(parfile)
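# probpars attributes referenced by this driver (their full definitions live in
# CV_tools.input_parameters): pwdfile, baseCAS, originalNET, baseNET, rebin_flag,
# voodooPar, EMflag, numfolds, CVflag, report_sens, and scenario.name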
# Initialize a pynetica instance/env using password in a text file
cdat.pyt.start_environment(cdat.probpars.pwdfile)
cdat.pyt.LimitMemoryUsage(5.0e16) # --> crank up the memory available
# read in the data from a base cas file
cdat.read_cas_file(cdat.probpars.baseCAS)
# perform rebinning if requested
if cdat.probpars.rebin_flag:
    # copy over the originalNET neta file to be the baseNET for this work
    shutil.copyfile(cdat.probpars.originalNET, cdat.probpars.baseNET)
    # sets equiprobable bins for each node as requested
    cdat.UpdateNeticaBinThresholds()
# set up the experience node indexing
cdat.NodeParentIndexing(cdat.probpars.baseNET, cdat.probpars.baseCAS)
# create the folds desired
cdat.allfolds = CVT.all_folds()
cdat.allfolds.k_fold_maker(cdat.N, cdat.probpars.numfolds)
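# the all_folds object collects the per-fold bookkeeping used below
# (casfiles, caldata/calN, valdata/valN, calpred/calNODES, valpred/valNODES)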
# associate the casefile with the net
print '*'*5 + 'Learning base casefile in base net' + '*'*5 + '\n\n'
cdat.pyt.rebuild_net(cdat.probpars.baseNET,
                     cdat.probpars.baseCAS,
                     cdat.probpars.voodooPar,
                     cdat.probpars.baseNET,
                     cdat.probpars.EMflag)
# run the predictions using the base net -->
cdat.basepred, cdat.NETNODES = cdat.predictBayes(cdat.probpars.baseNET, cdat.N, cdat.casdata)
print '*'*5 + 'Making Base Case Testing using built-in Netica Functions' + '*'*5 + '\n\n'
# ############### Now run the Netica built-in testing stuff ################
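# (the -999 fold index and the None flag appear to mark this as the base,
# non-cross-validation case; the per-fold calls below pass the fold number
# and 'CAL'/'VAL' instead)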
cdat.PredictBayesNeticaCV(-999, cdat.probpars.baseNET, None)
print '*'*5 + 'Finished --> Base Case Testing using built-in Netica Functions' + '*'*5 + '\n\n'
# write the results to a post-processing world
cdat.PredictBayesPostProc(cdat.basepred,
                          cdat.probpars.scenario.name + '_base_stats.dat',
                          cdat.probpars.baseCAS,
                          cdat.BaseNeticaTests)
# also need to postprocess experience
cdat.ExperiencePostProc()
# optionally perform sensitivity analysis on the base case
if cdat.probpars.report_sens:
    cdat.SensitivityAnalysis()
# if requested, perform K-fold cross validation
if cdat.probpars.CVflag:
    print '\n' * 2 + '#'*20 + '\n Performing k-fold cross-validation for %d folds\n' %(cdat.probpars.numfolds) \
        + '#'*20 + '\n' * 2
    # set up for cross validation
    print '\nSetting up cas files and file pointers for cross validation'
    kfoldOFP_Val, kfoldOFP_Cal = cdat.cross_val_setup()
    # now build all the nets
    for cfold in np.arange(cdat.probpars.numfolds):
        print ' ' * 10 + '#' * 20 + '\n' + ' ' * 10 + '# F O L D --> {0:d} #\n'.format(cfold) \
            + ' ' * 10 + '#' * 20
        # rebuild the net
        cname = cdat.allfolds.casfiles[cfold]
        cdat.pyt.rebuild_net(cdat.probpars.baseNET,
                             cname,
                             cdat.probpars.voodooPar,
                             cname[:-4] + '.neta',
                             cdat.probpars.EMflag)
        # make predictions for both calibration and validation data sets
        print '*'*5 + 'Calibration predictions' + '*'*5
        cdat.allfolds.calpred[cfold], cdat.allfolds.calNODES[cfold] = (
            cdat.predictBayes(cname[:-4] + '.neta',
                              cdat.allfolds.calN[cfold],
                              cdat.allfolds.caldata[cfold]))
        print '*'*5 + 'End Calibration predictions' + '*'*5 + '\n\n'
        print '*'*5 + 'Making Calibration Testing using built-in Netica Functions' + '*'*5 + '\n\n'
        # ############### Now run the Netica built-in testing stuff ################
        cdat.PredictBayesNeticaCV(cfold, cname[:-4] + '.neta', 'CAL')
        print '*'*5 + 'Finished --> Calibration Testing using built-in Netica Functions' + '*'*5 + '\n\n'
        print '*'*5 + 'Start Validation predictions' + '*'*5
        cdat.allfolds.valpred[cfold], cdat.allfolds.valNODES[cfold] = (
            cdat.predictBayes(cname[:-4] + '.neta',
                              cdat.allfolds.valN[cfold],
                              cdat.allfolds.valdata[cfold]))
        print '*'*5 + 'End Validation predictions' + '*'*5 + '\n\n'
        print '*'*5 + 'Making Validation Testing using built-in Netica Functions' + '*'*5 + '\n\n'
        # ############### Now run the Netica built-in testing stuff ################
        cdat.PredictBayesNeticaCV(cfold, cname[:-4] + '.neta', 'VAL')
        print '*'*5 + 'Finished --> Validation Testing using built-in Netica Functions' + '*'*5 + '\n\n'
    print "write out validation"
    cdat.PredictBayesPostProcCV(cdat.allfolds.valpred,
                                cdat.probpars.numfolds,
                                kfoldOFP_Val,
                                'Validation',
                                cdat.NeticaTests['VAL'])
    print "write out calibration"
    cdat.PredictBayesPostProcCV(cdat.allfolds.calpred,
                                cdat.probpars.numfolds,
                                kfoldOFP_Cal,
                                'Calibration',
                                cdat.NeticaTests['CAL'])
    kfoldOFP_Cal.close()
    kfoldOFP_Val.close()
    # summarize over all the folds to make a consolidated text file
    cdat.SummarizePostProcCV()
# Done with Netica so shut it down
cdat.pyt.CloseNetica()
# first need to sanitize away any ctypes/Netica pointers
cdat.pyt.sanitize()
# now dump into a pickle file
outfilename = parfile[:-4] + '_cdat.pklz'
print 'Dumping cdat to pickle file --> {0:s}'.format(outfilename)
ofp = gzip.open(outfilename, 'wb')
pickle.dump(cdat, ofp)
ofp.close()
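# To inspect the results later, the pickled object can be reloaded with the
# standard library, e.g. (a minimal sketch; assumes pythonNetica and CV_tools
# are importable so the pynetica class can be unpickled, and that parfile was
# 'example.xml'):
#
#   import gzip
#   import pickle
#   ifp = gzip.open('example_cdat.pklz', 'rb')
#   cdat = pickle.load(ifp)
#   ifp.close()
#   # base-case predictions and node summaries are then available,
#   # e.g. cdat.basepred and cdat.NETNODES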