Skip to content

Commit

Permalink
fixes to rates only fitting
Browse files Browse the repository at this point in the history
  • Loading branch information
Jhsmit committed Jul 17, 2020
1 parent 90a1b0c commit 204acb0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
7 changes: 3 additions & 4 deletions pyhdx/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ def global_fit(self, initial_result, k_int=None, learning_rate=0.01, l1=1e2, l2=
funcs = []
inputs = []
losses = []
gen = self._prepare_global_fit_gen(initial_result, k_int=k_int, learning_rate=learning_rate, l1=l1, l2=l2)
gen = self._prepare_global_fit_gen(initial_result, k_int=k_int, l1=l1, l2=l2)

early_stop = ftf.EarlyStopping(monitor='loss', min_delta=0.1, patience=50)
callbacks = [early_stop] if callbacks is None else callbacks
Expand All @@ -767,7 +767,7 @@ def global_fit(self, initial_result, k_int=None, learning_rate=0.01, l1=1e2, l2=
cb = ftf.LossHistory()
#early_stop = ftf.EarlyStopping(monitor='loss', min_delta=0.1, patience=50)

model.compile(loss='mse', optimizer=ftf.Adagrad(learning_rate=0.01))
model.compile(loss='mse', optimizer=ftf.Adagrad(learning_rate=learning_rate))
result = model.fit(input_data, output_data, verbose=0, epochs=epochs, callbacks=callbacks + [cb])
losses.append(result.history['loss'])
print('number of epochs', len(result.history['loss']))
Expand All @@ -778,7 +778,7 @@ def global_fit(self, initial_result, k_int=None, learning_rate=0.01, l1=1e2, l2=

return tf_fitresult

def _prepare_global_fit_gen(self, initial_result, k_int=None, learning_rate=0.01, l1=1e2, l2=0.):
def _prepare_global_fit_gen(self, initial_result, k_int=None, l1=1e2, l2=0.):
"""
Parameters
Expand All @@ -803,7 +803,6 @@ def _prepare_global_fit_gen(self, initial_result, k_int=None, learning_rate=0.01

regularizer = ftf.L1L2Differential(l1, l2)
if k_int is not None:

indices = np.searchsorted(k_int['r_number'], section.cov.r_number)
if not len(indices) == len(np.unique(indices)):
raise ValueError('Invalid match between section r number and k_int r number')
Expand Down
19 changes: 13 additions & 6 deletions pyhdx/fitting_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def __call__(self, inputs, **parameters):
uptake = 1 - tf.exp(-tf.matmul((inputs[1]/(1 + pfact)), self.timepoints))
return 100*tf.matmul(inputs[0], uptake)

def compute_output_shape(self, input_shape):
return input_shape[0], len(self.timepoints)
# def compute_output_shape(self, input_shape):
# return input_shape[0], len(self.timepoints)

@staticmethod
def output(weights):
Expand All @@ -134,16 +134,23 @@ def output(weights):
class AssociationRateFunc(object):
parameter_name = 'log_k'

"""
Function passed to CurveFit layer to calculate forward propagation through the network
"""
def __init__(self, timepoints):

self.timepoints = tf.dtypes.cast(tf.expand_dims(timepoints, 0), tf.float32)

def __call__(self, X, **parameters):
def __call__(self, inputs, **parameters):
k = 10**parameters[self.parameter_name]
uptake = 1 - tf.exp(-tf.matmul(k, self.timepoints))
return 100*tf.matmul(X, uptake)
return 100*tf.matmul(inputs[0], uptake)

def compute_output_shape(self, input_shape):
return input_shape[0], len(self.timepoints)
# def compute_output_shape(self, input_shape):
# return input_shape[0], len(self.timepoints)


class TFFitResult(object):
Expand Down

0 comments on commit 204acb0

Please sign in to comment.