Skip to content

Commit

Permalink
Implemented resampling in SMCopt (#58)
Browse files Browse the repository at this point in the history
Co-authored-by: Andreas Størksen Stordal <[email protected]>
  • Loading branch information
Ninjahh83 and Andreas Størksen Stordal authored Mar 11, 2024
1 parent b426856 commit fb59553
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 13 deletions.
36 changes: 25 additions & 11 deletions popt/loop/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ def __set__variable(var_name=None, defalut=None):

# Inflation factor used in SmcOpt
self.inflation_factor = None
self.survival_factor = None
self.particles = np.empty((self.cov.shape[0],0))
self.particle_values = np.empty((0))

# Initialize variables for bias correction
if 'bias_file' in self.sim.input_dict: # use bias correction
Expand Down Expand Up @@ -363,7 +366,7 @@ def calc_ensemble_weights(self, x, *args):
Control vector, shape (number of controls, )
args : tuple
Inflation factor and covarice (:math:`C_x`), shape (number of controls, number of controls)
Inflation factor, covariance (:math:`C_x`, shape (number of controls, number of controls)) and survival factor
Returns
-------
Expand All @@ -377,14 +380,18 @@ def calc_ensemble_weights(self, x, *args):
# Set the inflation factor and covariance equal to the input
self.inflation_factor = args[0]
self.cov = args[1]

self.survival_factor = args[2]

# If bias correction is used we need to temporarily store the initial state
initial_state = None
if self.bias_file is not None and self.bias_factors is None: # first iteration
initial_state = deepcopy(self.state) # store this to update current objective values

# Generate ensemble of states
self.ne = self.num_samples
if self.particles.shape[1] == 0:
self.ne = self.num_samples
else:
self.ne = int(np.round(self.num_samples*self.survival_factor))
self._aux_input()
self.state = self._gen_state_ensemble()

Expand All @@ -393,24 +400,31 @@ def calc_ensemble_weights(self, x, *args):
self._scale_state() # scale back to [0, 1]
self.ens_func_values = self.obj_func(self.pred_data, self.sim.input_dict, self.sim.true_order)
self.ens_func_values = np.array(self.ens_func_values)
state_ens = at.aug_state(self.state, list(self.state.keys()))
self.particles = np.hstack((self.particles, state_ens))
self.particle_values = np.hstack((self.particle_values,self.ens_func_values))

# If bias correction is used we need to calculate the bias factors, J(u_j,m_j)/J(u_j,m)
if self.bias_file is not None: # use bias corrections
self._bias_factors(self.ens_func_values, initial_state)

# Calculate the weights and ensemble sensitivity matrix
warnings.filterwarnings('ignore') # suppress warnings
weights = np.zeros(self.ne)
for i in np.arange(self.ne):
weights[i] = np.exp(-self.ens_func_values[i]*self.inflation_factor)
weights = np.zeros(self.num_samples)
for i in np.arange(self.num_samples):
weights[i] = np.exp(-(self.particle_values[i]-np.min(self.particle_values))*self.inflation_factor)

weights = weights + 0.000001
weights = weights/np.sum(weights) # TODO: Sjekke at disse er riktig
state_ens = at.aug_state(self.state, list(self.state.keys()))
sens_matrix = state_ens @ weights
index = np.argmin(self.ens_func_values)
best_ens = state_ens[:, index]
best_func = self.ens_func_values[index]

sens_matrix = self.particles @ weights
index = np.argmin(self.particle_values)
best_ens = self.particles[:, index]
best_func = self.particle_values[index]
resample_index = np.random.choice(self.num_samples,int(np.round(self.num_samples-self.num_samples*self.survival_factor)),
replace=True,p=weights)
self.particles = self.particles[:, resample_index]
self.particle_values = self.particle_values[resample_index]
return sens_matrix, best_ens, best_func

def _gen_state_ensemble(self):
Expand Down
2 changes: 1 addition & 1 deletion popt/update_schemes/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def apply_smc_update(self, control, gradient, **kwargs):
alpha = self._step_size

# apply update
new_control = alpha * control + (1-alpha) * gradient
new_control = (1-alpha) * control + alpha * gradient
return new_control

def apply_backtracking(self):
Expand Down
5 changes: 4 additions & 1 deletion popt/update_schemes/smcopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(self, fun, x, args, sens, bounds=None, **options):
- resample: number indicating how many times resampling is tried if no improvement is found
- cov_factor: factor used to shrink the covariance for each resampling trial (defalut 0.5)
- inflation_factor: term used to weight down prior influence (defalult 1)
- survival_factor: fraction of surviving samples
- savedata: specify which class variables to save to the result files (state, objective function
value, iteration number, number of function evaluations, and number of gradient
evaluations, are always saved)
Expand Down Expand Up @@ -72,6 +73,7 @@ def __set__variable(var_name=None, defalut=None):
self.max_resample = __set__variable('resample', 0)
self.cov_factor = __set__variable('cov_factor', 0.5)
self.inflation_factor = __set__variable('inflation_factor', 1)
self.survival_factor = __set__variable('survival_factor', 0)

# Calculate objective function of startpoint
if not self.restart:
Expand Down Expand Up @@ -110,7 +112,8 @@ def calc_update(self,):
shrink = self.cov_factor ** resampling_iter

# Calc sensitivity
(sens_matrix, self.best_state, best_func_tmp) = self.sens(self.mean_state, inflate, shrink*self.cov)
(sens_matrix, self.best_state, best_func_tmp) = self.sens(self.mean_state, inflate,
shrink*self.cov, self.survival_factor)
self.njev += 1

# Initialize for this step
Expand Down

0 comments on commit fb59553

Please sign in to comment.