diff --git a/popt/loop/ensemble.py b/popt/loop/ensemble.py index 42aea92..a96e4e2 100644 --- a/popt/loop/ensemble.py +++ b/popt/loop/ensemble.py @@ -119,6 +119,9 @@ def __set__variable(var_name=None, defalut=None): # Inflation factor used in SmcOpt self.inflation_factor = None + self.survival_factor = None + self.particles = np.empty((self.cov.shape[0],0)) + self.particle_values = np.empty((0)) # Initialize variables for bias correction if 'bias_file' in self.sim.input_dict: # use bias correction @@ -363,7 +366,7 @@ def calc_ensemble_weights(self, x, *args): Control vector, shape (number of controls, ) args : tuple - Inflation factor and covarice (:math:`C_x`), shape (number of controls, number of controls) + Inflation factor, covariance (:math:`C_x`, shape (number of controls, number of controls)) and survival factor Returns ------- @@ -377,14 +380,18 @@ def calc_ensemble_weights(self, x, *args): # Set the inflation factor and covariance equal to the input self.inflation_factor = args[0] self.cov = args[1] - + self.survival_factor = args[2] + # If bias correction is used we need to temporarily store the initial state initial_state = None if self.bias_file is not None and self.bias_factors is None: # first iteration initial_state = deepcopy(self.state) # store this to update current objective values # Generate ensemble of states - self.ne = self.num_samples + if self.particles.shape[1] == 0: + self.ne = self.num_samples + else: + self.ne = int(np.round(self.num_samples*self.survival_factor)) self._aux_input() self.state = self._gen_state_ensemble() @@ -393,6 +400,9 @@ def calc_ensemble_weights(self, x, *args): self._scale_state() # scale back to [0, 1] self.ens_func_values = self.obj_func(self.pred_data, self.sim.input_dict, self.sim.true_order) self.ens_func_values = np.array(self.ens_func_values) + state_ens = at.aug_state(self.state, list(self.state.keys())) + self.particles = np.hstack((self.particles, state_ens)) + self.particle_values = np.hstack((self.particle_values,self.ens_func_values)) # If bias correction is used we need to calculate the bias factors, J(u_j,m_j)/J(u_j,m) if self.bias_file is not None: # use bias corrections @@ -400,17 +410,21 @@ def calc_ensemble_weights(self, x, *args): # Calculate the weights and ensemble sensitivity matrix warnings.filterwarnings('ignore') # suppress warnings - weights = np.zeros(self.ne) - for i in np.arange(self.ne): - weights[i] = np.exp(-self.ens_func_values[i]*self.inflation_factor) + weights = np.zeros(self.num_samples) + for i in np.arange(self.num_samples): + weights[i] = np.exp(-(self.particle_values[i]-np.min(self.particle_values))*self.inflation_factor) + weights = weights + 0.000001 weights = weights/np.sum(weights) # TODO: Sjekke at disse er riktig - state_ens = at.aug_state(self.state, list(self.state.keys())) - sens_matrix = state_ens @ weights - index = np.argmin(self.ens_func_values) - best_ens = state_ens[:, index] - best_func = self.ens_func_values[index] + sens_matrix = self.particles @ weights + index = np.argmin(self.particle_values) + best_ens = self.particles[:, index] + best_func = self.particle_values[index] + resample_index = np.random.choice(self.num_samples,int(np.round(self.num_samples-self.num_samples*self.survival_factor)), + replace=True,p=weights) + self.particles = self.particles[:, resample_index] + self.particle_values = self.particle_values[resample_index] return sens_matrix, best_ens, best_func def _gen_state_ensemble(self): diff --git a/popt/update_schemes/optimizers.py b/popt/update_schemes/optimizers.py index 3fb9146..2c5ba22 100644 --- a/popt/update_schemes/optimizers.py +++ b/popt/update_schemes/optimizers.py @@ -117,7 +117,7 @@ def apply_smc_update(self, control, gradient, **kwargs): alpha = self._step_size # apply update - new_control = alpha * control + (1-alpha) * gradient + new_control = (1-alpha) * control + alpha * gradient return new_control def apply_backtracking(self): diff --git a/popt/update_schemes/smcopt.py b/popt/update_schemes/smcopt.py index 63646ea..5db01a2 100644 --- a/popt/update_schemes/smcopt.py +++ b/popt/update_schemes/smcopt.py @@ -42,6 +42,7 @@ def __init__(self, fun, x, args, sens, bounds=None, **options): - resample: number indicating how many times resampling is tried if no improvement is found - cov_factor: factor used to shrink the covariance for each resampling trial (defalut 0.5) - inflation_factor: term used to weight down prior influence (defalult 1) + - survival_factor: fraction of surviving samples - savedata: specify which class variables to save to the result files (state, objective function value, iteration number, number of function evaluations, and number of gradient evaluations, are always saved) @@ -72,6 +73,7 @@ def __set__variable(var_name=None, defalut=None): self.max_resample = __set__variable('resample', 0) self.cov_factor = __set__variable('cov_factor', 0.5) self.inflation_factor = __set__variable('inflation_factor', 1) + self.survival_factor = __set__variable('survival_factor', 0) # Calculate objective function of startpoint if not self.restart: @@ -110,7 +112,8 @@ def calc_update(self,): shrink = self.cov_factor ** resampling_iter # Calc sensitivity - (sens_matrix, self.best_state, best_func_tmp) = self.sens(self.mean_state, inflate, shrink*self.cov) + (sens_matrix, self.best_state, best_func_tmp) = self.sens(self.mean_state, inflate, + shrink*self.cov, self.survival_factor) self.njev += 1 # Initialize for this step