Speedup sample and allow specifying compile_kwargs (several major changes related to step samplers) #7578

Merged: 11 commits, Nov 29, 2024
7 changes: 5 additions & 2 deletions pymc/backends/__init__.py
@@ -72,6 +72,7 @@
 from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
 from pymc.backends.base import BaseTrace, IBaseTrace
 from pymc.backends.ndarray import NDArray
+from pymc.blocking import PointType
 from pymc.model import Model
 from pymc.step_methods.compound import BlockedStep, CompoundStep
 
@@ -100,11 +101,12 @@ def _init_trace(
     trace: BaseTrace | None,
     model: Model,
     trace_vars: list[TensorVariable] | None = None,
+    initial_point: PointType | None = None,
 ) -> BaseTrace:
     """Initialize a trace backend for a chain."""
     strace: BaseTrace
     if trace is None:
-        strace = NDArray(model=model, vars=trace_vars)
+        strace = NDArray(model=model, vars=trace_vars, test_point=initial_point)
     elif isinstance(trace, BaseTrace):
         if len(trace) > 0:
             raise ValueError("Continuation of traces is no longer supported.")
@@ -122,7 +124,7 @@ def init_traces(
     chains: int,
     expected_length: int,
     step: BlockedStep | CompoundStep,
-    initial_point: Mapping[str, np.ndarray],
+    initial_point: PointType,
     model: Model,
     trace_vars: list[TensorVariable] | None = None,
 ) -> tuple[RunType | None, Sequence[IBaseTrace]]:
@@ -145,6 +147,7 @@
             trace=backend,
             model=model,
             trace_vars=trace_vars,
+            initial_point=initial_point,
         )
         for chain_number in range(chains)
     ]
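A note on what threading `initial_point` through buys: `model.initial_point()` compiles and evaluates the initial-point function, and previously every `NDArray` trace recomputed it internally. A minimal sketch of the compute-once pattern this enables (the toy model and loop are illustrative, not PR code):

```python
import numpy as np
import pymc as pm

with pm.Model() as model:
    x = pm.Normal("x")

# Compute the initial point once and hand it to every chain's trace,
# instead of letting each backend call model.initial_point() again.
initial_point = model.initial_point()

for chain in range(4):
    # Each backend can size its buffers from the same dict.
    shapes = {name: np.shape(value) for name, value in initial_point.items()}
```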
53 changes: 37 additions & 16 deletions pymc/backends/base.py
@@ -30,9 +30,11 @@
 )
 
 import numpy as np
+import pytensor
 
 from pymc.backends.report import SamplerReport
 from pymc.model import modelcontext
+from pymc.pytensorf import compile_pymc
 from pymc.util import get_var_name
 
 logger = logging.getLogger(__name__)
@@ -147,32 +149,51 @@ class BaseTrace(IBaseTrace):
     use different test point that might be with changed variables shapes
     """
 
-    def __init__(self, name, model=None, vars=None, test_point=None):
-        self.name = name
-
+    def __init__(
+        self,
+        name=None,
+        model=None,
+        vars=None,
+        test_point=None,
+        *,
+        fn=None,
+        var_shapes=None,
+        var_dtypes=None,
+    ):
         model = modelcontext(model)
-        self.model = model
 
         if vars is None:
             vars = model.unobserved_value_vars
 
         unnamed_vars = {var for var in vars if var.name is None}
         if unnamed_vars:
             raise Exception(f"Can't trace unnamed variables: {unnamed_vars}")
-        self.vars = vars
-        self.varnames = [var.name for var in vars]
-        self.fn = model.compile_fn(vars, inputs=model.value_vars, on_unused_input="ignore")
+
+        if fn is None:
+            # borrow=True avoids deepcopy when inputs=output which is the case for untransformed value variables
+            fn = compile_pymc(
+                inputs=[pytensor.In(v, borrow=True) for v in model.value_vars],
+                outputs=[pytensor.Out(v, borrow=True) for v in vars],
+                on_unused_input="ignore",
+            )
+            fn.trust_input = True
 
         # Get variable shapes. Most backends will need this
         # information.
-        if test_point is None:
-            test_point = model.initial_point()
-        else:
-            test_point_ = model.initial_point().copy()
-            test_point_.update(test_point)
-            test_point = test_point_
-        var_values = list(zip(self.varnames, self.fn(test_point)))
-        self.var_shapes = {var: value.shape for var, value in var_values}
-        self.var_dtypes = {var: value.dtype for var, value in var_values}
+        if var_shapes is None or var_dtypes is None:
+            if test_point is None:
+                test_point = model.initial_point()
+            var_values = tuple(zip(vars, fn(**test_point)))
+            var_shapes = {var.name: value.shape for var, value in var_values}
+            var_dtypes = {var.name: value.dtype for var, value in var_values}
+
+        self.name = name
+        self.model = model
+        self.fn = fn
+        self.vars = vars
+        self.varnames = [var.name for var in vars]
+        self.var_shapes = var_shapes
+        self.var_dtypes = var_dtypes
         self.chain = None
         self._is_base_setup = False
         self.sampler_vars = None
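For readers unfamiliar with these knobs: `trust_input=True` skips PyTensor's per-call input validation and conversion, and `borrow=True` on `In`/`Out` avoids defensive copies, which matters when an output is the input itself (untransformed value variables). A minimal standalone sketch of both, on a toy graph:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
fn = pytensor.function(
    [pytensor.In(x, borrow=True)],       # do not copy the input buffer
    [pytensor.Out(2 * x, borrow=True)],  # output may alias internal storage
)
fn.trust_input = True  # caller guarantees dtype/ndim; validation is skipped

# With trust_input, inputs must already match the expected dtype exactly.
out, = fn(np.arange(3, dtype=pytensor.config.floatX))
```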
29 changes: 18 additions & 11 deletions pymc/backends/ndarray.py
@@ -40,8 +40,8 @@ class NDArray(base.BaseTrace):
     `model.unobserved_RVs` is used.
     """
 
-    def __init__(self, name=None, model=None, vars=None, test_point=None):
-        super().__init__(name, model, vars, test_point)
+    def __init__(self, name=None, model=None, vars=None, test_point=None, **kwargs):
+        super().__init__(name, model, vars, test_point, **kwargs)
         self.draw_idx = 0
         self.draws = None
         self.samples = {}
@@ -76,7 +76,7 @@ def setup(self, draws, chain, sampler_vars=None) -> None:
         else:  # Otherwise, make array of zeros for each variable.
             self.draws = draws
             for varname, shape in self.var_shapes.items():
-                self.samples[varname] = np.zeros((draws, *shape), dtype=self.var_dtypes[varname])
+                self.samples[varname] = np.empty((draws, *shape), dtype=self.var_dtypes[varname])
 
         if sampler_vars is None:
             return
@@ -105,17 +105,18 @@ def record(self, point, sampler_stats=None) -> None:
         point: dict
             Values mapped to variable names
         """
-        for varname, value in zip(self.varnames, self.fn(point)):
-            self.samples[varname][self.draw_idx] = value
+        samples = self.samples
+        draw_idx = self.draw_idx
+        for varname, value in zip(self.varnames, self.fn(*point.values())):
+            samples[varname][draw_idx] = value
 
-        if self._stats is not None and sampler_stats is None:
-            raise ValueError("Expected sampler_stats")
         if self._stats is None and sampler_stats is not None:
             raise ValueError("Unknown sampler_stats")
         if sampler_stats is not None:
             for data, vars in zip(self._stats, sampler_stats):
                 for key, val in vars.items():
-                    data[key][self.draw_idx] = val
+                    data[key][draw_idx] = val
+        elif self._stats is not None:
+            raise ValueError("Expected sampler_stats")
 
         self.draw_idx += 1
 
     def _get_sampler_stats(
@@ -166,7 +167,13 @@ def _slice(self, idx: slice):
         # Only the first `draw_idx` value are valid because of preallocation
         idx = slice(*idx.indices(len(self)))
 
-        sliced = NDArray(model=self.model, vars=self.vars)
+        sliced = type(self)(
+            model=self.model,
+            vars=self.vars,
+            fn=self.fn,
+            var_shapes=self.var_shapes,
+            var_dtypes=self.var_dtypes,
+        )
         sliced.chain = self.chain
         sliced.samples = {varname: values[idx] for varname, values in self.samples.items()}
         sliced.sampler_vars = self.sampler_vars
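Context for the `np.zeros` → `np.empty` switch and the new `_slice` comment: preallocated buffers are only meaningful up to `draw_idx`, so slices must be clamped to the recorded length. A small self-contained illustration of that invariant:

```python
import numpy as np

draws, shape = 1000, (3,)
samples = np.empty((draws, *shape))  # uninitialized memory, cheaper than zeros

draw_idx = 0
for _ in range(10):  # record 10 draws
    samples[draw_idx] = np.random.normal(size=shape)
    draw_idx += 1

# Only samples[:draw_idx] holds valid values; everything beyond is garbage,
# which is why _slice above restricts indices to len(self).
valid = samples[:draw_idx]
```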
25 changes: 10 additions & 15 deletions pymc/blocking.py
@@ -39,11 +39,11 @@
 StatShape: TypeAlias = Sequence[int | None] | None
 
 
-# `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for
+# `point_map_info` is a tuple of tuples containing `(name, shape, size, dtype)` for
 # each of the raveled variables.
 class RaveledVars(NamedTuple):
     data: np.ndarray
-    point_map_info: tuple[tuple[str, tuple[int, ...], np.dtype], ...]
+    point_map_info: tuple[tuple[str, tuple[int, ...], int, np.dtype], ...]
 
 
 class Compose(Generic[T]):
@@ -67,10 +67,9 @@ class DictToArrayBijection:
     @staticmethod
     def map(var_dict: PointType) -> RaveledVars:
         """Map a dictionary of names and variables to a concatenated 1D array space."""
-        vars_info = tuple((v, k, v.shape, v.dtype) for k, v in var_dict.items())
-        raveled_vars = [v[0].ravel() for v in vars_info]
-        if raveled_vars:
-            result = np.concatenate(raveled_vars)
+        vars_info = tuple((v, k, v.shape, v.size, v.dtype) for k, v in var_dict.items())
+        if vars_info:
+            result = np.concatenate(tuple(v[0].ravel() for v in vars_info))
         else:
             result = np.array([])
         return RaveledVars(result, tuple(v[1:] for v in vars_info))
@@ -91,19 +90,15 @@ def rmap(

"""
if start_point:
result = dict(start_point)
result = start_point.copy()
else:
result = {}

if not isinstance(array, RaveledVars):
raise TypeError("`array` must be a `RaveledVars` type")

last_idx = 0
for name, shape, dtype in array.point_map_info:
arr_len = np.prod(shape, dtype=int)
var = array.data[last_idx : last_idx + arr_len].reshape(shape).astype(dtype)
result[name] = var
last_idx += arr_len
for name, shape, size, dtype in array.point_map_info:
end = last_idx + size
result[name] = array.data[last_idx:end].reshape(shape).astype(dtype)
last_idx = end

return result

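The `size` now carried in `point_map_info` is just `np.prod(shape)` precomputed once at `map` time, so `rmap` (a hot path called on every draw by step samplers) no longer recomputes it per variable. A round-trip sketch of the scheme, outside of PyMC:

```python
import numpy as np

point = {"mu": np.zeros((2, 3)), "sigma_log__": np.ones(4)}

# map: flatten, and record (name, shape, size, dtype) per variable
info = tuple((k, v.shape, v.size, v.dtype) for k, v in point.items())
data = np.concatenate([v.ravel() for v in point.values()])

# rmap: slice by the stored size instead of recomputing np.prod(shape)
out, last = {}, 0
for name, shape, size, dtype in info:
    out[name] = data[last : last + size].reshape(shape).astype(dtype)
    last += size

assert all(np.array_equal(out[k], point[k]) for k in point)
```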
82 changes: 53 additions & 29 deletions pymc/model/core.py
@@ -48,7 +48,7 @@
     ShapeError,
     ShapeWarning,
 )
-from pymc.initial_point import make_initial_point_fn
+from pymc.initial_point import PointType, make_initial_point_fn
 from pymc.logprob.basic import transformed_conditional_logp
 from pymc.logprob.transforms import Transform
 from pymc.logprob.utils import ParameterValueError, replace_rvs_by_values
@@ -61,6 +61,7 @@
     gradient,
     hessian,
     inputvars,
+    join_nonshared_inputs,
     rewrite_pregrad,
 )
 from pymc.util import (
@@ -172,6 +173,9 @@
         dtype=None,
         casting="no",
         compute_grads=True,
+        model=None,
+        initial_point: PointType | None = None,
+        ravel_inputs: bool | None = None,
         **kwargs,
     ):
         if extra_vars_and_values is None:
@@ -219,9 +223,7 @@
         givens = []
         self._extra_vars_shared = {}
         for var, value in extra_vars_and_values.items():
-            shared = pytensor.shared(
-                value, var.name + "_shared__", shape=[1 if s == 1 else None for s in value.shape]
-            )
+            shared = pytensor.shared(value, var.name + "_shared__", shape=value.shape)
             self._extra_vars_shared[var.name] = shared
             givens.append((var, shared))
 
@@ -231,13 +233,28 @@
             grads = pytensor.grad(cost, grad_vars, disconnected_inputs="ignore")
             for grad_wrt, var in zip(grads, grad_vars):
                 grad_wrt.name = f"{var.name}_grad"
-            outputs = [cost, *grads]
+            grads = pt.join(0, *[pt.atleast_1d(grad.ravel()) for grad in grads])
+            outputs = [cost, grads]
         else:
             outputs = [cost]
 
-        inputs = grad_vars
+        if ravel_inputs:
+            if initial_point is None:
+                initial_point = modelcontext(model).initial_point()
+            outputs, raveled_grad_vars = join_nonshared_inputs(
+                point=initial_point, inputs=grad_vars, outputs=outputs, make_inputs_shared=False
+            )
+            inputs = [raveled_grad_vars]
+        else:
+            if ravel_inputs is None:
+                warnings.warn(
+                    "ValueGradFunction will become a function of raveled inputs.\n"
+                    "Specify `ravel_inputs` to suppress this warning. Note that setting `ravel_inputs=False` will be forbidden in a future release."
+                )
+            inputs = grad_vars
 
         self._pytensor_function = compile_pymc(inputs, outputs, givens=givens, **kwargs)
+        self._raveled_inputs = ravel_inputs
 
     def set_weights(self, values):
         if values.shape != (self._n_costs - 1,):
@@ -247,38 +264,29 @@
     def set_extra_values(self, extra_vars):
         self._extra_are_set = True
         for var in self._extra_vars:
-            self._extra_vars_shared[var.name].set_value(extra_vars[var.name])
+            self._extra_vars_shared[var.name].set_value(extra_vars[var.name], borrow=True)
 
     def get_extra_values(self):
         if not self._extra_are_set:
             raise ValueError("Extra values are not set.")
 
         return {var.name: self._extra_vars_shared[var.name].get_value() for var in self._extra_vars}
 
-    def __call__(self, grad_vars, grad_out=None, extra_vars=None):
+    def __call__(self, grad_vars, *, extra_vars=None):
         if extra_vars is not None:
             self.set_extra_values(extra_vars)
-
-        if not self._extra_are_set:
+        elif not self._extra_are_set:
             raise ValueError("Extra values are not set.")
 
         if isinstance(grad_vars, RaveledVars):
-            grad_vars = list(DictToArrayBijection.rmap(grad_vars).values())
-
-        cost, *grads = self._pytensor_function(*grad_vars)
-
-        if grads:
-            grads_raveled = DictToArrayBijection.map(
-                {v.name: gv for v, gv in zip(self._grad_vars, grads)}
-            )
-
-            if grad_out is None:
-                return cost, grads_raveled.data
+            if self._raveled_inputs:
+                grad_vars = (grad_vars.data,)
             else:
-                np.copyto(grad_out, grads_raveled.data)
-                return cost
-        else:
-            return cost
+                grad_vars = DictToArrayBijection.rmap(grad_vars).values()
+        elif self._raveled_inputs and not isinstance(grad_vars, Sequence):
+            grad_vars = (grad_vars,)
+
+        return self._pytensor_function(*grad_vars)
 
     @property
     def profile(self):
@@ -521,7 +529,14 @@
     def isroot(self):
         return self.parent is None
 
-    def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
+    def logp_dlogp_function(
+        self,
+        grad_vars=None,
+        tempered=False,
+        initial_point: PointType | None = None,
+        ravel_inputs: bool | None = None,
+        **kwargs,
+    ):
         """Compile a PyTensor function that computes logp and gradient.
 
         Parameters
@@ -547,13 +562,22 @@
         costs = [self.logp()]
 
         input_vars = {i for i in graph_inputs(costs) if not isinstance(i, Constant)}
-        ip = self.initial_point(0)
+        if initial_point is None:
+            initial_point = self.initial_point(0)
         extra_vars_and_values = {
-            var: ip[var.name]
+            var: initial_point[var.name]
             for var in self.value_vars
             if var in input_vars and var not in grad_vars
         }
-        return ValueGradFunction(costs, grad_vars, extra_vars_and_values, **kwargs)
+        return ValueGradFunction(
+            costs,
+            grad_vars,
+            extra_vars_and_values,
+            model=self,
+            initial_point=initial_point,
+            ravel_inputs=ravel_inputs,
+            **kwargs,
+        )
 
     def compile_logp(
         self,