diff --git a/.gitignore b/.gitignore index 8b624483..e8aeb638 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,9 @@ old_results/ log/ tf_summaries/ time_lines/ - +lenart_internal/ +*.pdf +*.csv mars/ # Byte-compiled / optimized / DLL files diff --git a/experiments/data_provider.py b/experiments/data_provider.py index 0707b180..8babc249 100644 --- a/experiments/data_provider.py +++ b/experiments/data_provider.py @@ -37,9 +37,8 @@ 'obs_noise_std': 0.05, 'x_support_mode_train': 'full', 'param_mode': 'random', - 'num_cells': 5, - 'num_genes': 15, - 'sergio_dim': 5 * 15, + 'num_cells': 10, + 'num_genes': 200, } DEFAULTS_RACECAR = { @@ -77,21 +76,21 @@ }, 'Greenhouse': { - 'likelihood_std': {'value': [0.05 for _ in range(16)]}, + 'likelihood_std': {'value': [0.01 for _ in range(16)]}, 'num_samples_train': {'value': 20}, }, 'Greenhouse_hf': { - 'likelihood_std': {'value': [0.05 for _ in range(16)]}, + 'likelihood_std': {'value': [0.01 for _ in range(16)]}, 'num_samples_train': {'value': 20}, }, 'Sergio': { - 'likelihood_std': {'value': [0.05 for _ in range(DEFAULTS_SERGIO['sergio_dim'])]}, + 'likelihood_std': {'value': [0.05 for _ in range(2 * DEFAULTS_SERGIO['num_cells'])]}, 'num_samples_train': {'value': 20}, }, 'Sergio_hf': { - 'likelihood_std': {'value': [0.05 for _ in range(DEFAULTS_SERGIO['sergio_dim'])]}, + 'likelihood_std': {'value': [0.05 for _ in range(2 * DEFAULTS_SERGIO['num_cells'])]}, 'num_samples_train': {'value': 20}, }, diff --git a/experiments/lf_hf_transfer_exp/run_regression_exp.py b/experiments/lf_hf_transfer_exp/run_regression_exp.py index 6874080f..b8b10ca4 100644 --- a/experiments/lf_hf_transfer_exp/run_regression_exp.py +++ b/experiments/lf_hf_transfer_exp/run_regression_exp.py @@ -250,14 +250,14 @@ def main(args): # print(f"Setting added_gp_outputscale to data_source default value from DATASET_CONFIGS " # f"which is {exp_params['added_gp_outputscale']}") elif 'pendulum' in exp_params['data_source']: - exp_params['added_gp_outputscale'] = [factor * 0.05, 0.05, 0.5] + exp_params['added_gp_outputscale'] = [factor * 0.05, factor * 0.05, factor * 0.5] elif 'Sergio' in exp_params['data_source']: from experiments.data_provider import DEFAULTS_SERGIO - exp_params['added_gp_outputscale'] = [factor * 0.1 for _ in range(DEFAULTS_SERGIO['sergio_dim'])] + exp_params['added_gp_outputscale'] = [factor * 0.1 for _ in range(2 * DEFAULTS_SERGIO['num_cells'])] elif 'Greenhouse' in exp_params['data_source']: - exp_params['added_gp_outputscale'] = [factor * 0.1 for _ in range(16)] + exp_params['added_gp_outputscale'] = [factor * 0.05 for _ in range(16)] # We are quite confident about exogenous effects - exp_params['added_gp_outputscale'][-6:] = [0.0001 for _ in range(6)] + exp_params['added_gp_outputscale'][5:] = [0.005 for _ in range(11)] else: raise AssertionError('passed negative value for added_gp_outputscale') # set likelihood_std to default value if not specified @@ -332,16 +332,16 @@ def main(args): parser.add_argument('--use_wandb', type=bool, default=False) # data parameters - parser.add_argument('--data_source', type=str, default='Greenhouse_hf') - parser.add_argument('--pred_diff', type=int, default=1) - parser.add_argument('--num_samples_train', type=int, default=100) + parser.add_argument('--data_source', type=str, default='Sergio_hf') + parser.add_argument('--pred_diff', type=int, default=0) + parser.add_argument('--num_samples_train', type=int, default=6400) parser.add_argument('--data_seed', type=int, default=77698) # standard BNN parameters parser.add_argument('--model', type=str, default='BNN_FSVGD_SimPrior_gp') parser.add_argument('--model_seed', type=int, default=892616) parser.add_argument('--likelihood_std', type=float, default=None) - parser.add_argument('--learn_likelihood_std', type=int, default=0) + parser.add_argument('--learn_likelihood_std', type=int, default=1) parser.add_argument('--likelihood_reg', type=float, default=0.0) parser.add_argument('--data_batch_size', type=int, default=8) parser.add_argument('--min_train_steps', type=int, default=2500) diff --git a/sim_transfer/sims/dynamics_models.py b/sim_transfer/sims/dynamics_models.py index bf0c47b3..bdfac01d 100644 --- a/sim_transfer/sims/dynamics_models.py +++ b/sim_transfer/sims/dynamics_models.py @@ -96,9 +96,9 @@ class GreenHouseParams(NamedTuple): cg: Union[jax.Array, float] = jnp.array(32 * (10 ** 3)) # green_house_heat_capacity cp_w: Union[jax.Array, float] = jnp.array(4180.0) # specific_heat_water cs: Union[jax.Array, float] = jnp.array(120 * (10 ** 3)) # green_house_soil_heat_capacity - cp_a: Union[jax.Array, float] = 1010 # air_specific_heat_water + cp_a: Union[jax.Array, float] = jnp.array(1010) # air_specific_heat_water d1: Union[jax.Array, float] = jnp.array(2.1332 * (10 ** (-7))) # plant development rate 1 - d2: Union[jax.Array, float] = jnp.array(2.4664 * (10 ** (-1))) # plant development rate 2 + d2: Union[jax.Array, float] = jnp.array(2.4664 * (10 ** (-7))) # plant development rate 2 d3: Union[jax.Array, float] = jnp.array(20) # plant development rate 3 d4: Union[jax.Array, float] = jnp.array(7.4966 * (10 ** (-11))) # plant development rate 4 f: Union[jax.Array, float] = jnp.array(1.2) # fruit assimilate requirment @@ -663,7 +663,6 @@ def _ode(self, x, u, params: CarParams): class SergioDynamics(ABC): - l_b: float = 0 lam_lb: float = 0.2 lam_ub: float = 0.9 @@ -673,6 +672,7 @@ def __init__(self, n_genes: int, params: SergioParams = SergioParams(), dt_integration: float = 0.01, + state_ub: float = 500.0, ): super().__init__() self.dt = dt @@ -680,21 +680,34 @@ def __init__(self, self.n_genes = n_genes self.params = params self.x_dim = self.n_cells * self.n_genes - + self.state_ub = state_ub self.dt_integration = dt_integration assert dt >= dt_integration assert (dt / dt_integration - int(dt / dt_integration)) < 1e-4, 'dt must be multiple of dt_integration' self._num_steps_integrate = int(dt / dt_integration) def next_step(self, x: jax.Array, params: PyTree) -> jax.Array: + x = self.transform_state(x) + def body(carry, _): q = carry + self.dt_integration * self.ode(carry, params) - q = jnp.clip(q, a_min=self.l_b) + q = jnp.clip(q, a_min=0) return q, None next_state, _ = jax.lax.scan(body, x, xs=None, length=self._num_steps_integrate) + next_state = self.inv_transform_state(next_state) return next_state + def transform_state(self, x): + # x is between [0, 1] -> [0, state_ub] + x = x * self.state_ub + return x + + def inv_transform_state(self, x): + # [0, state_ub] -> [0, 1] + x = x / self.state_ub + return x + def ode(self, x: jax.Array, params) -> jax.Array: assert x.shape[-1] == self.x_dim return self._ode(x, params) @@ -809,7 +822,7 @@ class GreenHouseDynamics(DynamicsModel): ] ) - input_ub = jnp.array([60, 1.0, 1.0, 2.0]) + input_ub = jnp.array([80, 1.0, 1.0, 2.1]) input_lb = jnp.array([0, 0.0, 0.0, 0.0]) noise_to = 5 noise_td = 0.01 @@ -821,11 +834,11 @@ class GreenHouseDynamics(DynamicsModel): noise_std = jnp.array( [ 0.05, 0.1, 0.05, 0.01, 0.05, 0.05, 0.1, 0.1, 0.01, noise_to, - noise_td, noise_co, noise_vo, noise_w, noise_g, 2, + noise_td, noise_co, noise_vo, noise_w, noise_g, 0.1, ] ) - def __init__(self, use_hf: bool = False, dt: float = 60): + def __init__(self, use_hf: bool = False, dt: float = 300): self.use_hf = use_hf self.greenhouse_state_dim = 5 @@ -847,6 +860,7 @@ def __init__(self, use_hf: bool = False, dt: float = 60): def next_step(self, x: jnp.array, u: jnp.array, params: GreenHouseParams) -> jnp.array: x, u = self.transform_state(x), self.transform_action(u) + def body(carry, _): q = carry + self.dt_integration * self.ode(carry, u, params) q = jnp.clip(q, a_min=self.constraint_lb, a_max=self.constraint_ub) @@ -883,9 +897,9 @@ def buffer_switching_func(B, b1): return 1 - jnp.exp(-b1 * B) def get_respiration_param(self, x, u, params: GreenHouseParams): - ml = x[self.greenhouse_state_dim + 2] - l_lai = (ml / params.wr) ** (params.laim) / (1 + (ml / params.wr) ** (params.laim)) - R = - params.p1 - params.p5 * l_lai + # ml = x[self.greenhouse_state_dim + 2] + # l_lai = (ml / params.wr) ** (params.laim) / (1 + (ml / params.wr) ** (params.laim)) + R = - params.p1 - params.p5 return R def get_crop_photosynthesis(self, x, u, params: GreenHouseParams): @@ -893,7 +907,12 @@ def get_crop_photosynthesis(self, x, u, params: GreenHouseParams): ml = x[self.greenhouse_state_dim + 2] G = x[-2] i_par = params.eta * G * params.mp * params.pg - c_ppm = (10 ** 6) * params.rg / (params.patm * params.Mco2) * (t_g + params.T0) * c_i + # note patm is in kPa. c_i is in g/m^3 + # c_ppm units: m^3 Pa/(mol K) * K * g/m^3 /(kPa * kg/mol) + # c_ppm: units 10^-3 m^3 kPa/mol * 10^-3 kg/m^3 /(kPa * kg/mol) + # c_ppm: units 10^-6 [] -> need to multiply with 10^-6 to get right units + # c_ppm = 1/mol * g/kg + c_ppm = params.rg / (params.patm * params.Mco2) * (t_g + params.T0) * c_i l_lai = (ml / params.wr) ** (params.laim) / (1 + (ml / params.wr) ** (params.laim)) p_g = params.pm * l_lai * i_par / (params.p3 + i_par) * c_ppm / (params.p4 + c_ppm) return p_g @@ -902,7 +921,8 @@ def get_harvest_coefficient(self, x, u, params: GreenHouseParams): d_p = x[self.greenhouse_state_dim + 3] t = x[-1] t_g = x[0] - h = (d_p >= 1) * (params.d1 + params.d2 * jnp.log(t_g / params.d3) - params.d4 * t) + temp_ratio = jnp.clip(t_g / params.d3, a_min=1e-6) + h = (d_p >= 1) * (params.d1 + params.d2 * jnp.log(temp_ratio) - params.d4 * t) return h def transform_state(self, x): @@ -915,12 +935,12 @@ def transform_action(self, u): return u def inv_transform_state(self, x): - x = (x - self.state_lb)/(self.state_ub - self.state_lb) + x = (x - self.state_lb) / (self.state_ub - self.state_lb) return x def _greenhouse_dynamics_hf(self, x, u, params: GreenHouseParams): # C, C, C, m, g/m^3, kg/m^-3 - t_g, t_p, t_s, c_i, v_i = x[0], x[1], x[2], x[3], x[5] + t_g, t_p, t_s, c_i, v_i = x[0], x[1], x[2], x[3], x[4] # g/m^-2, g/m^-2, g/m^-2, [] mb, mf, ml, d_p = x[self.greenhouse_state_dim], x[self.greenhouse_state_dim + 1], \ x[self.greenhouse_state_dim + 2], x[self.greenhouse_state_dim + 3] @@ -956,10 +976,6 @@ def _greenhouse_dynamics_hf(self, x, u, params: GreenHouseParams): dt_g_dt = (k_v + params.kr) * (t_o - t_g) + alpha * (t_p - t_g) + params.ks * (t_s - t_g) \ + G * params.eta - l * E + l / (1 + params.epsilon) * Mc dt_g_dt = dt_g_dt / params.cg - # jax.debug.print('G {x}', x=G/params.cg) - # jax.debug.print('Mc{x}', x=Mc) - # jax.debug.print('t_g {x}', x=t_g) - # jax.debug.print('dt_g {x}', x=dt_g_dt) phi = params.phi rh = params.rh @@ -991,8 +1007,8 @@ def _greenhouse_dynamics_hf(self, x, u, params: GreenHouseParams): hf, hl = h * params.yf, h * params.yl dmf_dt = (b * g_f - (1 - b) * rf - hf) * mf dml_dt = (b * g_l - (1 - b) * rl - hl) * ml - dd_p_dt = params.d1 + params.d2 * jnp.log(t_g / params.d3) - params.d4 * t - h - + temp_ratio = jnp.clip(t_g / params.d3, a_min=1e-6) + dd_p_dt = params.d1 + params.d2 * jnp.log(temp_ratio) - params.d4 * t - h # Exogenous effects dt_o_dt = jnp.zeros_like(dt_g_dt) @@ -1013,8 +1029,7 @@ def _greenhouse_dynamics_hf(self, x, u, params: GreenHouseParams): return dx_dt def _greenhouse_dynamics_lf(self, x, u, params: GreenHouseParams): - - t_g, t_p, t_s, c_i, v_i = x[0], x[1], x[2], x[3], x[5] + t_g, t_p, t_s, c_i, v_i = x[0], x[1], x[2], x[3], x[4] # g/m^-2, g/m^-2, g/m^-2, [] mb, mf, ml, d_p = x[self.greenhouse_state_dim], x[self.greenhouse_state_dim + 1], \ x[self.greenhouse_state_dim + 2], x[self.greenhouse_state_dim + 3] @@ -1031,9 +1046,6 @@ def _greenhouse_dynamics_lf(self, x, u, params: GreenHouseParams): dt_g_dt = (k_v + params.kr) * (t_o - t_g) + params.ks * (t_s - t_g) \ + G * params.eta dt_g_dt = dt_g_dt / params.cg - # jax.debug.print('l {x}', x=l) - # jax.debug.print('p_g_star{x}', x=p_g_star) - # jax.debug.print('t_g {x}', x=t_g) dt_p_dt = jnp.zeros_like(dt_g_dt) @@ -1060,7 +1072,8 @@ def _greenhouse_dynamics_lf(self, x, u, params: GreenHouseParams): hf, hl = h * params.yf, h * params.yl dmf_dt = (b * g_f - (1 - b) * rf - hf) * mf dml_dt = (b * g_l - (1 - b) * rl - hl) * ml - dd_p_dt = params.d1 + params.d2 * jnp.log(t_g / params.d3) - params.d4 * t - h + temp_ratio = jnp.clip(t_g / params.d3, a_min=1e-6) + dd_p_dt = params.d1 + params.d2 * jnp.log(temp_ratio) - params.d4 * t - h # Exogenous effects diff --git a/sim_transfer/sims/simulators.py b/sim_transfer/sims/simulators.py index 77d5d3b3..8b8fc443 100644 --- a/sim_transfer/sims/simulators.py +++ b/sim_transfer/sims/simulators.py @@ -991,12 +991,15 @@ class SergioSim(FunctionSimulator): _dt: float = 1 / 10 # domain for generating data - state_lb: float = 0.0 - state_ub: float = 500 + # state_lb: float = 0.0 + state_ub: float = 500.0 + sample_x_max: float = 3 def __init__(self, n_genes: int = 20, n_cells: int = 20, use_hf: bool = False): - FunctionSimulator.__init__(self, input_size=n_genes * n_cells, output_size=n_genes * n_cells) - self.model = SergioDynamics(self._dt, n_genes, n_cells) + FunctionSimulator.__init__(self, input_size=2 * n_cells, output_size=2 * n_cells) + self.model = SergioDynamics(self._dt, n_genes, n_cells, state_ub=self.state_ub) + self.n_cells = n_cells + self.n_genes = n_genes self._setup_params() self.use_hf = use_hf if self.use_hf: @@ -1013,8 +1016,8 @@ def __init__(self, n_genes: int = 20, n_cells: int = 20, use_hf: bool = False): 'lower bounds have to be smaller than upper bounds' # setup domain - self.domain_lower = jnp.ones(shape=(n_genes * n_cells,)) * self.state_lb - self.domain_upper = jnp.ones(shape=(n_genes * n_cells,)) * self.state_ub + self.domain_lower = -self.sample_x_max * jnp.ones(shape=(2 * self.n_cells,)) + self.domain_upper = self.sample_x_max * jnp.ones(shape=(2 * self.n_cells,)) self._domain = HypercubeDomain(lower=self.domain_lower, upper=self.domain_upper) @property @@ -1022,7 +1025,7 @@ def domain(self) -> Domain: return self._domain def _setup_params(self): - self.lower_bound_param_hf = SergioParams(lam=jnp.array(0.79), + self.lower_bound_param_hf = SergioParams(lam=jnp.array(0.75), contribution_rates=jnp.array(-5.0), basal_rates=jnp.array(1.0), power=jnp.array(2.0), @@ -1035,7 +1038,7 @@ def _setup_params(self): self.default_param_hf = self.model.sample_single_params(jax.random.PRNGKey(0), self.lower_bound_param_hf, self.upper_bound_param_hf) - self.lower_bound_param_lf = SergioParams(lam=jnp.array(0.79), + self.lower_bound_param_lf = SergioParams(lam=jnp.array(0.4), contribution_rates=jnp.array(-5.0), basal_rates=jnp.array(1.0), power=jnp.array(1.0), @@ -1062,27 +1065,55 @@ def sample_params(self, rng_key: jax.random.PRNGKey): train_params = train_params._replace(power=0, graph=0) return params, train_params + def predict_next_state(self, x: jnp.array, params: NamedTuple, + key: jax.random.PRNGKey = jax.random.PRNGKey(0)) -> jnp.array: + assert x.ndim == 1 + mu, log_std = jnp.split(x, 2, axis=-1) + x = mu + jax.random.normal(key=key, shape=(self.n_genes, self.n_cells)) * jax.nn.softplus(log_std) + x = x.reshape(self.n_genes * self.n_cells) + # clip state to be between -3, 3 + x = jnp.clip(x, -self.sample_x_max, self.sample_x_max) + # scale it to be between [0, 1] + x = x / (2 * self.sample_x_max) + 0.5 + # sample next cells and genes from the sim + f = self.model.next_step(x, params) + # rescale it back to be between [-3, 3] + f = (f - 0.5) * (2 * self.sample_x_max) + # take the mean and std over genes + f = f.reshape(self.n_genes, self.n_cells) + mu_f, std_f = jnp.mean(f, axis=0), jnp.std(f, axis=0) + # clip std so that its positive and take log std + std_f = jnp.clip(std_f, 1e-6) + log_std_f = jnp.log(jnp.exp(std_f) - 1) + f = jnp.concatenate([mu_f, log_std_f], axis=-1) + return f + def sample_function_vals(self, x: jnp.ndarray, num_samples: int, rng_key: jax.random.PRNGKey) -> jnp.ndarray: assert x.ndim == 2 and x.shape[-1] == self.input_size + rng_key, gene_key = jax.random.split(rng_key, 2) params = self.model.sample_params_uniform(rng_key, sample_shape=num_samples, lower_bound=self._lower_bound_params, upper_bound=self._upper_bound_params) + gene_key = jax.random.split(gene_key, num_samples) - def batched_fun(z, params): - f = vmap(self.model.next_step, in_axes=(0, None))(z, params) + def batched_fun(z, params, key): + f = vmap(self.predict_next_state, in_axes=(0, None, None))(z, params, key) return f - f = vmap(batched_fun, in_axes=(None, 0))(x, params) + f = vmap(batched_fun, in_axes=(None, 0, 0))(x, params, gene_key) assert f.shape == (num_samples, x.shape[0], self.output_size) return f def sample_functions(self, num_samples: int, rng_key: jax.random.PRNGKey) -> Callable: + gene_key, rng_key = jax.random.split(rng_key, 2) params = self.model.sample_params_uniform(rng_key, sample_shape=(num_samples,), lower_bound=self._lower_bound_params, upper_bound=self._upper_bound_params) - def stacked_fun(z): - f = vmap(self.model.next_step, in_axes=(0, 0))(x, params) + gene_key = jax.random.split(gene_key, num_samples) + + def stacked_fun(x): + f = vmap(self.predict_next_state, in_axes=(0, 0, 0))(x, params, gene_key) return f return stacked_fun @@ -1094,18 +1125,18 @@ def domain(self) -> Domain: @property def normalization_stats(self) -> Dict[str, jnp.ndarray]: - stats = {'x_mean': jnp.ones(self.input_size) * (self.state_ub + self.state_lb) / 2, - 'x_std': jnp.ones(self.input_size) * (self.state_ub - self.state_lb) ** 2 / 12, - 'y_mean': jnp.ones(self.output_size) * (self.state_ub + self.state_lb) / 2, - 'y_std': jnp.ones(self.output_size) * (self.state_ub - self.state_lb) ** 2 / 12} + stats = {'x_mean': jnp.zeros(self.input_size), + 'x_std': (self.sample_x_max ** 2) * jnp.ones(self.input_size) / 3.0, + 'y_mean': jnp.zeros(self.output_size), + 'y_std': (self.sample_x_max ** 2) * jnp.ones(self.input_size) / 3.0} return stats def _typical_f(self, x: jnp.array) -> jnp.array: - f = jax.vmap(self.model.next_step, in_axes=(0, None))(x, self._typical_params) + f = jax.vmap(self.predict_next_state, in_axes=(0, None))(x, self._typical_params) return f def evaluate_sim(self, x: jnp.array, params: NamedTuple) -> jnp.array: - f = jax.vmap(self.model.next_step, in_axes=(0, None))(x, params) + f = jax.vmap(self.predict_next_state, in_axes=(0, None))(x, params) return f def _add_observation_noise(self, f_vals: jnp.ndarray, obs_noise_std: Union[jnp.ndarray, float], @@ -1125,7 +1156,7 @@ def _sample_x_data(self, rng_key: jax.random.PRNGKey, num_samples_train: int, nu class GreenHouseSim(FunctionSimulator): - param_ratio = 0.1 + param_ratio = 0.4 def __init__(self, use_hf: bool = False): self.model = GreenHouseDynamics(use_hf=use_hf) @@ -1242,10 +1273,6 @@ def stacked_fun(z): return stacked_fun - @property - def domain(self) -> Domain: - return self._domain - @property def normalization_stats(self) -> Dict[str, jnp.ndarray]: # x_u_b = jnp.concatenate([self.model.state_ub, self.model.input_ub], axis=0) @@ -1259,8 +1286,8 @@ def normalization_stats(self) -> Dict[str, jnp.ndarray]: 'y_mean': (y_u_b + y_l_b) / 2, 'y_std': (y_u_b - y_l_b) ** 2 / 12, } - # 'y_mean': (self.model.state_ub + self.model.state_lb) / 2, - # 'y_std': (self.model.state_ub - self.model.state_lb) ** 2 / 12} + # 'y_mean': (self.model.state_ub + self.model.state_lb) / 2, + # 'y_std': (self.model.state_ub - self.model.state_lb) ** 2 / 12} return stats def _typical_f(self, x: jnp.array) -> jnp.array: @@ -1281,6 +1308,8 @@ def get_eval_params(self, params: NamedTuple): def _add_observation_noise(self, f_vals: jnp.ndarray, obs_noise_std: Union[jnp.ndarray, float], rng_key: jax.random.PRNGKey) -> jnp.ndarray: + if isinstance(obs_noise_std, float): + obs_noise_std = jnp.ones_like(self.model.noise_std) * obs_noise_std obs_noise_std = jnp.clip(obs_noise_std, a_max=self.model.noise_std) y = f_vals + obs_noise_std * jax.random.normal(rng_key, shape=f_vals.shape) y = jnp.clip(y, a_min=self.model.constraint_lb) @@ -1434,28 +1463,30 @@ def evaluate_sim(self, x: jnp.array, params: NamedTuple) -> jnp.array: key1, key2 = jax.random.split(jax.random.PRNGKey(435349), 2) key_hf, key_lf = jax.random.split(key1, 2) - function_sim = PredictStateChangeWrapper(GreenHouseSim(use_hf=True)) - test_p, test_p_train = function_sim.sample_params(key1) - x, _ = function_sim._sample_x_data(key_hf, 64, 1) - param1 = function_sim._function_simulator._typical_params - f1 = function_sim.sample_function_vals(x, num_samples=4000, rng_key=key2) - f1 = function_sim._function_simulator.model.transform_state(f1) - import numpy as np - f2 = function_sim._typical_f(x) - print(jnp.isnan(f1).any()) - print(jnp.isnan(f2).any()) - - function_sim = PredictStateChangeWrapper(GreenHouseSim(use_hf=False)) - test_p, test_p_train = function_sim.sample_params(key1) - x, _ = function_sim._sample_x_data(key_lf, 64, 1) - param1 = function_sim._function_simulator._typical_params - f1 = function_sim.sample_function_vals(x, num_samples=4000, rng_key=key2) - import numpy as np - - f2 = function_sim._typical_f(x) - print(jnp.isnan(f1).any()) - print(jnp.isnan(f2).any()) - + # function_sim = GreenHouseSim(use_hf=True) + # test_p, test_p_train = function_sim.sample_params(key1) + # x, _ = function_sim._sample_x_data(key_hf, 64, 1) + # param1 = function_sim._typical_params + # f1 = function_sim.sample_function_vals(x, num_samples=4000, rng_key=key2) + # f1 = function_sim.model.transform_state(f1) + # import numpy as np + # + # f2 = function_sim._typical_f(x) + # f2 = function_sim.model.transform_state(f2) + # print(jnp.isnan(f1).any()) + # print(jnp.isnan(f2).any()) + # check = np.max(np.abs(np.asarray(f1 - function_sim.model.transform_state(x[..., : 16]))), axis=0) + # function_sim = GreenHouseSim(use_hf=False) + # test_p, test_p_train = function_sim.sample_params(key1) + # x, _ = function_sim._sample_x_data(key_lf, 64, 1) + # param1 = function_sim._typical_params + # f1 = function_sim.sample_function_vals(x, num_samples=4000, rng_key=key2) + # import numpy as np + # + # f2 = function_sim._typical_f(x) + # check = np.max(np.abs(np.asarray(f1 - function_sim.model.transform_state(x[..., : 16]))), axis=0) + # print(jnp.isnan(f1).any()) + # print(jnp.isnan(f2).any()) function_sim = SergioSim(5, 10, use_hf=False) function_sim.sample_params(key1)