Bootstrapping Final Model #716

Open: wants to merge 18 commits into main
282 changes: 163 additions & 119 deletions econml/_ortho_learner.py
@@ -30,6 +30,7 @@ class in this module implements the general logic in a very versatile way
from abc import abstractmethod
import inspect
from collections import defaultdict
from joblib import Parallel, delayed
import re

import numpy as np
@@ -42,7 +43,7 @@ class in this module implements the general logic in a very versatile way
from ._cate_estimator import (BaseCateEstimator, LinearCateEstimator,
                              TreatmentExpansionMixin)
from .inference import BootstrapInference
from .utilities import (_deprecate_positional, check_input_arrays,
from .utilities import (convertArg, _deprecate_positional, check_input_arrays,
                        cross_product, filter_none_kwargs,
                        inverse_onehot, jacify_featurizer, ndim, reshape, shape, transpose)

@@ -195,7 +196,7 @@ def predict(self, X, y, W=None):

CachedValues = namedtuple('CachedValues', ['nuisances',
                                           'Y', 'T', 'X', 'W', 'Z', 'sample_weight', 'freq_weight',
                                           'sample_var', 'groups'])
                                           'sample_var', 'groups', 'output_T'])
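
CachedValues is an immutable namedtuple, so updated copies are produced with _replace, as _fit_compute_final_T does further down. A minimal illustration (t_raw and t_feat are hypothetical placeholder arrays, not part of this diff):

t_raw, t_feat = np.arange(4), np.arange(4) ** 2   # placeholder treatment arrays
cv = CachedValues(nuisances=None, Y=None, T=t_raw, X=None, W=None, Z=None,
                  sample_weight=None, freq_weight=None, sample_var=None,
                  groups=None, output_T=t_feat)
cv_final = cv._replace(T=t_feat)  # all other fields carry over unchanged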


class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
@@ -443,6 +444,7 @@ def __init__(self, *,
        self.categories = categories
        self.mc_iters = mc_iters
        self.mc_agg = mc_agg
        self._clone_model_finals = False
        super().__init__()

    @abstractmethod
@@ -547,9 +549,118 @@ def _prefit(self, Y, T, *args, only_final=False, **kwargs):
        if not only_final:
            # generate an instance of the nuisance model
            self._ortho_learner_model_nuisance = self._gen_ortho_learner_model_nuisance()

        super()._prefit(Y, T, *args, **kwargs)

    def _gen_cloned_ortho_learner_model_finals(self, num_clone_final_models):
        self._clone_model_finals = True
        self._cloned_final_models = [clone(self._ortho_learner_model_final, safe=False) for _ in range(num_clone_final_models)]
        self._current_cloned_index = 0
        self._current_cloned_final_model = self._cloned_final_models[self._current_cloned_index]

    def _set_current_cloned_ortho_learner_model_final(self, clone_index):
        self._current_cloned_index = clone_index
        self._current_cloned_final_model = self._cloned_final_models[self._current_cloned_index]
        self._ortho_learner_model_final = self._current_cloned_final_model
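
A rough driver for the two hooks above (est and the bootstrap count are stand-ins, not part of this diff): one final-model clone is created per bootstrap resample, and the estimator can then be pointed at any one of them before computing effects.

n_bootstrap = 100
est._gen_cloned_ortho_learner_model_finals(n_bootstrap)
for i in range(n_bootstrap):
    est._set_current_cloned_ortho_learner_model_final(i)
    # effect methods now read from clone i's fitted final model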

    def _fit_compute_final_T(self, cached_values):
        final_T = cached_values.T
        if self.transformer:
            if self.discrete_treatment:
                final_T = self.transformer.transform(final_T.reshape(-1, 1))
            else:  # treatment featurizer case
                final_T = cached_values.output_T
        return cached_values._replace(T=final_T)
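
For a discrete treatment, the stored transformer is a OneHotEncoder fitted with drop='first', so the final-stage T becomes baseline-dropped dummy columns. A standalone illustration with toy data (not from this diff; sparse=False mirrors the encoder settings used in this file and targets older scikit-learn versions):

import numpy as np
from sklearn.preprocessing import OneHotEncoder

T = np.array(['a', 'b', 'c', 'b']).reshape(-1, 1)
enc = OneHotEncoder(categories='auto', sparse=False, drop='first').fit(T)
print(enc.transform(T))  # shape (4, 2); category 'a' is the dropped baseline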

    def _fit_cached_values(self, Y, T, *, X=None, W=None, Z=None, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
                           cache_values=False, inference=None, only_final=False, check_input=True):
        if check_input:
            Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups = check_input_arrays(
                Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups)
            self._check_input_dims(Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups)
        output_T = None
        if self.discrete_treatment:
            categories = self.categories
            if categories != 'auto':
                categories = [categories]  # OneHotEncoder expects a 2D array with features per column
            self.transformer = OneHotEncoder(categories=categories, sparse=False, drop='first')
            self.transformer.fit(reshape(T, (-1, 1)))
            self._d_t = (len(self.transformer.categories_[0]) - 1,)
        elif self.treatment_featurizer:
            self._original_treatment_featurizer = clone(self.treatment_featurizer, safe=False)
            self.transformer = jacify_featurizer(self.treatment_featurizer)
            output_T = self.transformer.fit_transform(T)
            self._d_t = np.shape(output_T)[1:]
        else:
            self.transformer = None

        if self.discrete_instrument:
            self.z_transformer = OneHotEncoder(categories='auto', sparse=False, drop='first')
            self.z_transformer.fit(reshape(Z, (-1, 1)))
        else:
            self.z_transformer = None
        all_nuisances = []
        fitted_inds = None
        if sample_weight is None:
            if freq_weight is not None:
                sample_weight_nuisances = freq_weight
            else:
                sample_weight_nuisances = None
        else:
            if freq_weight is not None:
                sample_weight_nuisances = freq_weight * sample_weight
            else:
                sample_weight_nuisances = sample_weight

        self._models_nuisance = []
        for idx in range(self.mc_iters or 1):
            nuisances, fitted_models, new_inds, scores = self._fit_nuisances(
                Y, T, X, W, Z, sample_weight=sample_weight_nuisances, groups=groups)
            all_nuisances.append(nuisances)
            self._models_nuisance.append(fitted_models)
            if scores is None:
                self.nuisance_scores_ = None
            else:
                if idx == 0:
                    self.nuisance_scores_ = tuple([] for _ in scores)
                for ind, score in enumerate(scores):
                    self.nuisance_scores_[ind].append(score)
            if fitted_inds is None:
                fitted_inds = new_inds
            elif not np.array_equal(fitted_inds, new_inds):
                raise AttributeError("Different indices were fit by different folds, so they cannot be aggregated")

        if self.mc_iters is not None:
            if self.mc_agg == 'mean':
                nuisances = tuple(np.mean(nuisance_mc_variants, axis=0)
                                  for nuisance_mc_variants in zip(*all_nuisances))
            elif self.mc_agg == 'median':
                nuisances = tuple(np.median(nuisance_mc_variants, axis=0)
                                  for nuisance_mc_variants in zip(*all_nuisances))
            else:
                raise ValueError(
                    "Parameter `mc_agg` must be one of {'mean', 'median'}. Got {}".format(self.mc_agg))

        Y, T, X, W, Z, sample_weight, freq_weight, sample_var = (self._subinds_check_none(arr, fitted_inds)
                                                                 for arr in (Y, T, X, W, Z, sample_weight,
                                                                             freq_weight, sample_var))
        nuisances = tuple([self._subinds_check_none(nuis, fitted_inds) for nuis in nuisances])
        cached_values = CachedValues(nuisances=nuisances,
                                     Y=Y, T=T, X=X, W=W, Z=Z,
                                     sample_weight=sample_weight,
                                     freq_weight=freq_weight,
                                     sample_var=sample_var,
                                     groups=groups,
                                     output_T=output_T)
        return cached_values
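
The mc_agg step above reduces the nuisance tuples elementwise across Monte Carlo iterations. A small standalone illustration (toy numbers, not from this diff) of how 'mean' and 'median' differ:

import numpy as np

all_nuisances = [(np.array([1.0, 2.0]),),    # nuisance tuple from mc iteration 0
                 (np.array([3.0, 6.0]),),    # ... iteration 1
                 (np.array([11.0, 10.0]),)]  # ... iteration 2
mean_agg = tuple(np.mean(v, axis=0) for v in zip(*all_nuisances))      # (array([5., 6.]),)
median_agg = tuple(np.median(v, axis=0) for v in zip(*all_nuisances))  # (array([3., 6.]),)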

    def _fit_init(self, Y, T, *, X=None, W=None, Z=None, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
                  cache_values=False, inference=None, check_input=True):
        self._random_state = check_random_state(self.random_state)
        assert (freq_weight is None) == (
            sample_var is None), "Sample variances and frequency weights must be provided together!"
        assert not (self.discrete_treatment and self.treatment_featurizer), "Treatment featurization " \
            "is not supported when treatment is discrete"
    @BaseCateEstimator._wrap_fit
    def fit(self, Y, T, *, X=None, W=None, Z=None, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
            cache_values=False, inference=None, only_final=False, check_input=True):
@@ -599,93 +710,19 @@ def fit(self, Y, T, *, X=None, W=None, Z=None, sample_weight=None, freq_weight=N
        -------
        self : object
        """
        self._random_state = check_random_state(self.random_state)
        assert (freq_weight is None) == (
            sample_var is None), "Sample variances and frequency weights must be provided together!"
        assert not (self.discrete_treatment and self.treatment_featurizer), "Treatment featurization " \
            "is not supported when treatment is discrete"
        if check_input:
            Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups = check_input_arrays(
                Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups)
            self._check_input_dims(Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups)

        self._fit_init(Y=Y, T=T, X=X, W=W, Z=Z, sample_weight=sample_weight,
                       freq_weight=freq_weight, sample_var=sample_var, groups=groups,
                       cache_values=cache_values, inference=inference, check_input=check_input)
        cached_values = None
        if not only_final:

            if self.discrete_treatment:
                categories = self.categories
                if categories != 'auto':
                    categories = [categories]  # OneHotEncoder expects a 2D array with features per column
                self.transformer = OneHotEncoder(categories=categories, sparse=False, drop='first')
                self.transformer.fit(reshape(T, (-1, 1)))
                self._d_t = (len(self.transformer.categories_[0]) - 1,)
            elif self.treatment_featurizer:
                self._original_treatment_featurizer = clone(self.treatment_featurizer, safe=False)
                self.transformer = jacify_featurizer(self.treatment_featurizer)
                output_T = self.transformer.fit_transform(T)
                self._d_t = np.shape(output_T)[1:]
            else:
                self.transformer = None

            if self.discrete_instrument:
                self.z_transformer = OneHotEncoder(categories='auto', sparse=False, drop='first')
                self.z_transformer.fit(reshape(Z, (-1, 1)))
            else:
                self.z_transformer = None

            all_nuisances = []
            fitted_inds = None
            if sample_weight is None:
                if freq_weight is not None:
                    sample_weight_nuisances = freq_weight
                else:
                    sample_weight_nuisances = None
            else:
                if freq_weight is not None:
                    sample_weight_nuisances = freq_weight * sample_weight
                else:
                    sample_weight_nuisances = sample_weight

            self._models_nuisance = []
            for idx in range(self.mc_iters or 1):
                nuisances, fitted_models, new_inds, scores = self._fit_nuisances(
                    Y, T, X, W, Z, sample_weight=sample_weight_nuisances, groups=groups)
                all_nuisances.append(nuisances)
                self._models_nuisance.append(fitted_models)
                if scores is None:
                    self.nuisance_scores_ = None
                else:
                    if idx == 0:
                        self.nuisance_scores_ = tuple([] for _ in scores)
                    for ind, score in enumerate(scores):
                        self.nuisance_scores_[ind].append(score)
                if fitted_inds is None:
                    fitted_inds = new_inds
                elif not np.array_equal(fitted_inds, new_inds):
                    raise AttributeError("Different indices were fit by different folds, so they cannot be aggregated")

            if self.mc_iters is not None:
                if self.mc_agg == 'mean':
                    nuisances = tuple(np.mean(nuisance_mc_variants, axis=0)
                                      for nuisance_mc_variants in zip(*all_nuisances))
                elif self.mc_agg == 'median':
                    nuisances = tuple(np.median(nuisance_mc_variants, axis=0)
                                      for nuisance_mc_variants in zip(*all_nuisances))
                else:
                    raise ValueError(
                        "Parameter `mc_agg` must be one of {'mean', 'median'}. Got {}".format(self.mc_agg))

            Y, T, X, W, Z, sample_weight, freq_weight, sample_var = (self._subinds_check_none(arr, fitted_inds)
                                                                     for arr in (Y, T, X, W, Z, sample_weight,
                                                                                 freq_weight, sample_var))
            nuisances = tuple([self._subinds_check_none(nuis, fitted_inds) for nuis in nuisances])
            self._cached_values = CachedValues(nuisances=nuisances,
                                               Y=Y, T=T, X=X, W=W, Z=Z,
                                               sample_weight=sample_weight,
                                               freq_weight=freq_weight,
                                               sample_var=sample_var,
                                               groups=groups) if cache_values else None
            cached_values = self._fit_cached_values(Y=Y, T=T, X=X, W=W, Z=Z,
                                                    sample_weight=sample_weight,
                                                    freq_weight=freq_weight,
                                                    sample_var=sample_var,
                                                    groups=groups)
            self._cached_values = cached_values if cache_values else None
        else:
            nuisances = self._cached_values.nuisances
            cached_values = self._cached_values
            # _d_t is altered by the nuisance fitting relative to what _prefit sets, so we need to
            # perform the same alteration even when we only want to fit the final model.
            if self.transformer is not None:
@@ -694,23 +731,8 @@ def fit(self, Y, T, *, X=None, W=None, Z=None, sample_weight=None, freq_weight=N
                else:
                    output_T = self.transformer.fit_transform(T)
                    self._d_t = np.shape(output_T)[1:]

        final_T = T
        if self.transformer:
            if self.discrete_treatment:
                final_T = self.transformer.transform(final_T.reshape(-1, 1))
            else:  # treatment featurizer case
                final_T = output_T

        self._fit_final(Y=Y,
                        T=final_T,
                        X=X, W=W, Z=Z,
                        nuisances=nuisances,
                        sample_weight=sample_weight,
                        freq_weight=freq_weight,
                        sample_var=sample_var,
                        groups=groups)

        cached_values = self._fit_compute_final_T(cached_values)
        self._fit_final(cached_values)
        return self
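
A rough sketch of how the caching split can be exercised (est stands in for a fitted concrete subclass instance; reusing the internal only_final flag this way is an assumption, not documented API):

est.fit(Y, T, X=X, W=W, cache_values=True)   # fit nuisances and final model, keep cached values
est.fit(Y, T, X=X, W=W, only_final=True,     # skip nuisance fitting, reuse self._cached_values
        check_input=False)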

    @property
@@ -794,21 +816,43 @@ def _fit_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None, group
            Y, T, X=X, W=W, Z=Z,
            sample_weight=sample_weight, groups=groups)
        return nuisances, fitted_models, fitted_inds, scores

    def _fit_final(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None,
                   freq_weight=None, sample_var=None, groups=None):
        self._ortho_learner_model_final.fit(Y, T, **filter_none_kwargs(X=X, W=W, Z=Z,
                                                                       nuisances=nuisances,
                                                                       sample_weight=sample_weight,
                                                                       freq_weight=freq_weight,
                                                                       sample_var=sample_var,
                                                                       groups=groups))

    def _set_bootstrap_params(self, indices, n_jobs, verbose):
        self._bootstrap_indices = indices
        self._n_jobs = n_jobs
        self._verbose = verbose
    def _fit_final(self, cached_values, final_model=None):
        if final_model is None:
            final_model = self._ortho_learner_model_final
        if self._clone_model_finals:
            def fit(x, **kwargs):
                x.fit(**filter_none_kwargs(**kwargs))
                return x
            cached_values_dict = cached_values._asdict()
            del cached_values_dict["output_T"]

            Parallel(n_jobs=self._n_jobs, prefer='threads', verbose=self._verbose)(
                delayed(fit)(cloned_final_model,
                             **{arg: convertArg(cached_values_dict[arg], inds) for arg in cached_values_dict})
                for inds, cloned_final_model in zip(self._bootstrap_indices, self._cloned_final_models)
            )
        final_model.fit(cached_values.Y, cached_values.T, **filter_none_kwargs(X=cached_values.X,
                                                                               W=cached_values.W,
                                                                               Z=cached_values.Z,
                                                                               nuisances=cached_values.nuisances,
                                                                               sample_weight=cached_values.sample_weight,
                                                                               freq_weight=cached_values.freq_weight,
                                                                               sample_var=cached_values.sample_var,
                                                                               groups=cached_values.groups))
        self.score_ = None
        if hasattr(self._ortho_learner_model_final, 'score'):
            self.score_ = self._ortho_learner_model_final.score(Y, T, **filter_none_kwargs(X=X, W=W, Z=Z,
                                                                                           nuisances=nuisances,
                                                                                           sample_weight=sample_weight,
                                                                                           groups=groups))
        if hasattr(final_model, 'score'):
            self.score_ = final_model.score(cached_values.Y, cached_values.T,
                                            **filter_none_kwargs(X=cached_values.X,
                                                                 W=cached_values.W,
                                                                 Z=cached_values.Z,
                                                                 nuisances=cached_values.nuisances,
                                                                 sample_weight=cached_values.sample_weight,
                                                                 groups=cached_values.groups))
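
The newly imported convertArg presumably resamples each cached argument by the bootstrap indices while passing None arguments through untouched. A minimal stand-in under that assumption (not the actual utilities.convertArg):

import numpy as np

def convert_arg(arg, inds):
    # hypothetical equivalent of utilities.convertArg (assumed behavior):
    # index into the argument with the bootstrap indices, leave None as None
    return np.asarray(arg)[inds] if arg is not None else None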

    def const_marginal_effect(self, X=None):
        X, = check_input_arrays(X)