Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transforms can be captured with PLxPR #6633

Merged
merged 20 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@

<h4>Capturing and representing hybrid programs</h4>

* PennyLane transforms can now be captured as primitived with experimental program capture enabled.
[(#6633)](https://github.com/PennyLaneAI/pennylane/pull/6633)
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved

* `jax.vmap` can be captured with `qml.capture.make_plxpr` and is compatible with quantum circuits.
[(#6349)](https://github.com/PennyLaneAI/pennylane/pull/6349)
[(#6422)](https://github.com/PennyLaneAI/pennylane/pull/6422)
Expand Down
3 changes: 3 additions & 0 deletions pennylane/transforms/core/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@
from .transform_dispatcher import TransformDispatcher, TransformError


def transform(

Check notice on line 23 in pennylane/transforms/core/transform.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/transforms/core/transform.py#L23

Too many positional arguments (7/5) (too-many-positional-arguments)
quantum_transform,
expand_transform=None,
classical_cotransform=None,
is_informative=False,
final_transform=False,
use_argnum_in_expand=False,
plxpr_transform=None,
): # pylint: disable=too-many-arguments
"""Generalizes a function that transforms tapes to work with additional circuit-like objects such as a
:class:`~.QNode`.
Expand Down Expand Up @@ -59,6 +60,7 @@
of the transform program. ``is_informative`` supersedes ``final_transform``.
use_argnum_in_expand=False (bool): Whether or not to use ``argnum`` of the tape to determine trainable parameters
during the expansion transform process.
plxpr_transform=None (Optional[Callable]): Function for processing primitives when transforming PLxPR.
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved

Returns:

Expand Down Expand Up @@ -240,4 +242,5 @@
is_informative=is_informative,
final_transform=final_transform,
use_argnum_in_expand=use_argnum_in_expand,
plxpr_transform=plxpr_transform,
)
135 changes: 121 additions & 14 deletions pennylane/transforms/core/transform_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@ class TransformError(Exception):
"""Raised when there is an error with the transform logic."""


class TransformDispatcher:
def _default_plxpr_transform(transform_name): # pylint: disable=missing-function-docstring
def wrapper(*_, **__):
raise TransformError(f"{transform_name} cannot be used to transform PLxPR.")

return wrapper


class TransformDispatcher: # pylint: disable=too-many-instance-attributes
r"""Converts a transform that has the signature ``(tape -> Sequence(tape), fn)`` to a transform dispatcher
that can act on :class:`pennylane.tape.QuantumTape`, quantum function, :class:`pennylane.QNode`,
:class:`pennylane.devices.Device`.
Expand Down Expand Up @@ -62,7 +69,7 @@ def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument

return super().__new__(cls)

# pylint: disable=too-many-arguments
# pylint: disable=too-many-arguments,too-many-positional-arguments
def __init__(
self,
transform,
Expand All @@ -71,6 +78,7 @@ def __init__(
is_informative=False,
final_transform=False,
use_argnum_in_expand=False,
plxpr_transform=None,
): # pylint:disable=redefined-outer-name
self._transform = transform
self._expand_transform = expand_transform
Expand All @@ -79,7 +87,12 @@ def __init__(
# is_informative supersedes final_transform
self._final_transform = is_informative or final_transform
self._qnode_transform = self.default_qnode_transform
self._plxpr_transform = plxpr_transform or _default_plxpr_transform(
self._transform.__name__
)

self._use_argnum_in_expand = use_argnum_in_expand
self._primitive = _create_transform_primitive(self._transform.__name__)
functools.update_wrapper(self, transform)

def __call__(self, *targs, **tkwargs): # pylint: disable=too-many-return-statements
Expand All @@ -94,19 +107,19 @@ def __call__(self, *targs, **tkwargs): # pylint: disable=too-many-return-statem
if self._expand_transform:
expanded_tapes, expand_processing = self._expand_transform(obj, *targs, **tkwargs)
transformed_tapes = []
processing_and_sclices = []
processing_and_slices = []
start = 0
for tape in expanded_tapes:
intermediate_tapes, post_processing_fn = self._transform(
tape, *targs, **tkwargs
)
transformed_tapes.extend(intermediate_tapes)
end = start + len(intermediate_tapes)
processing_and_sclices.append(tuple([post_processing_fn, slice(start, end)]))
processing_and_slices.append(tuple([post_processing_fn, slice(start, end)]))
start = end

def processing_fn(results):
processed_results = [fn(results[slice]) for fn, slice in processing_and_sclices]
processed_results = [fn(results[slice]) for fn, slice in processing_and_slices]
return expand_processing(processed_results)

else:
Expand All @@ -117,17 +130,23 @@ def processing_fn(results):
return transformed_tapes, processing_fn

if isinstance(obj, qml.QNode):
if qml.capture.enabled():
return self._capture_callable_transform(obj, targs, tkwargs)
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved
return self._qnode_transform(obj, targs, tkwargs)

if isinstance(obj, qml.devices.Device):
return self._device_transform(obj, targs, tkwargs)

if obj.__class__.__name__ == "QJIT":
raise TransformError(
"Functions that are wrapped / decorated with qjit cannot subsequently be"
f" transformed with a PennyLane transform (attempted {self})."
f" For the desired affect, ensure that qjit is applied after {self}."
)

if callable(obj):
return self._qfunc_transform(obj, targs, tkwargs)

if isinstance(obj, Sequence) and all(isinstance(q, qml.tape.QuantumScript) for q in obj):
return self._batch_transform(obj, targs, tkwargs)

Expand Down Expand Up @@ -214,23 +233,79 @@ def default_qnode_transform(self, qnode, targs, tkwargs):
if self.expand_transform:
qnode.add_transform(
TransformContainer(
self._expand_transform, targs, tkwargs, use_argnum=self._use_argnum_in_expand
self._expand_transform,
args=targs,
kwargs=tkwargs,
use_argnum=self._use_argnum_in_expand,
)
)
qnode.add_transform(
TransformContainer(
self._transform,
targs,
tkwargs,
self._classical_cotransform,
self._is_informative,
self._final_transform,
args=targs,
kwargs=tkwargs,
classical_cotransform=self._classical_cotransform,
plxpr_transform=self._plxpr_transform,
is_informative=self._is_informative,
final_transform=self._final_transform,
)
)
return qnode

def plxpr_transform(self, primitive, tracers, params, targs, tkwargs, state):
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved
"""Function for processing primitives to transform PLxPR.

Args:
primitive (jax.core.Primitive): Primitive to transform
tracers (Sequence[jax.core.Tracer]): Input tracers to the primitive
params (dict): Dictionary containing keyword arguments/metadata for the primitive
targs (Sequence[Any]): Arguments for the transform
tkwargs (dict): Keyword arguments for the transform
state (dict): Dictionary containing auxiliary information about the environment/state
needed to apply the transform

Returns:
Any: The results of the transformed primitive
"""
# Implemented this way rather than using a property so that the correct docstring is used
return self._plxpr_transform(primitive, tracers, params, targs, tkwargs, state)

def _capture_callable_transform(self, qfunc, targs, tkwargs):
"""Apply the transform on a quantum function when program capture is enabled"""

@functools.wraps(qfunc)
def qfunc_transformed(*args, **kwargs):
import jax # pylint: disable=import-outside-toplevel

flat_qfunc = qml.capture.flatfn.FlatFn(qfunc)
jaxpr = jax.make_jaxpr(functools.partial(flat_qfunc, **kwargs))(*args)

n_args = len(args)
n_consts = len(jaxpr.consts)
args_slice = slice(0, n_args)
consts_slice = slice(n_args, n_args + n_consts)
targs_slice = slice(n_args + n_consts, None)

results = self._primitive.bind(
*args,
*jaxpr.consts,
*targs,
inner_jaxpr=jaxpr.jaxpr,
args_slice=args_slice,
consts_slice=consts_slice,
targs_slice=targs_slice,
tkwargs=tkwargs,
)

assert flat_qfunc.out_tree is not None
return jax.tree_util.tree_unflatten(flat_qfunc.out_tree, results)

return qfunc_transformed

def _qfunc_transform(self, qfunc, targs, tkwargs):
"""Apply the transform on a quantum function."""
if qml.capture.enabled():
return self._capture_callable_transform(qfunc, targs, tkwargs)

@functools.wraps(qfunc)
def qfunc_transformed(*args, **kwargs):
Expand Down Expand Up @@ -298,7 +373,7 @@ def preprocess(
):
"""This function updates the original device transform program to be applied."""
program, config = self.original_device.preprocess(execution_config)
program.push_back(TransformContainer(self.transform, targs, tkwargs))
program.push_back(TransformContainer(self.transform, args=targs, kwargs=tkwargs))
return program, config

@property
Expand Down Expand Up @@ -351,7 +426,7 @@ def processing_fn(res: ResultBatch) -> ResultBatch:
return tuple(execution_tapes), processing_fn


class TransformContainer:
class TransformContainer: # pylint: disable=too-many-instance-attributes, too-many-positional-arguments
"""Class to store a quantum transform with its ``args``, ``kwargs`` and classical co-transforms. Use
:func:`~.pennylane.transform`.

Expand All @@ -370,14 +445,16 @@ def __init__(
args=None,
kwargs=None,
classical_cotransform=None,
plxpr_transform=None,
is_informative=False,
final_transform=False,
use_argnum=False,
): # pylint:disable=redefined-outer-name,too-many-arguments
): # pylint:disable=redefined-outer-name,too-many-arguments,too-many-positional-arguments
self._transform = transform
self._args = args or []
self._kwargs = kwargs or {}
self._classical_cotransform = classical_cotransform
self._plxpr_transform = plxpr_transform
self._is_informative = is_informative
self._final_transform = is_informative or final_transform
self._use_argnum = use_argnum
Expand All @@ -392,6 +469,7 @@ def __iter__(self):
self._args,
self._kwargs,
self._classical_cotransform,
self._plxpr_transform,
self._is_informative,
self.final_transform,
)
Expand Down Expand Up @@ -429,6 +507,11 @@ def classical_cotransform(self):
"""The stored quantum transform's classical co-transform."""
return self._classical_cotransform

@property
def plxpr_transform(self):
"""The stored quantum transform's PLxPR transform."""
return self._plxpr_transform

@property
def is_informative(self):
"""``True`` if the transform is informative."""
Expand All @@ -438,3 +521,27 @@ def is_informative(self):
def final_transform(self):
"""``True`` if the transform needs to be executed"""
return self._final_transform


@functools.lru_cache
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved
def _create_transform_primitive(name):
try:
# pylint: disable=import-outside-toplevel
import jax
except ImportError:
return None

transform_prim = jax.core.Primitive(name)
transform_prim.multiple_results = True

@transform_prim.def_impl
def _(
*all_args, inner_jaxpr, args_slice, consts_slice, targs_slice, tkwargs
): # pylint: disable=unused-argument
raise NotImplementedError

@transform_prim.def_abstract_eval
def _(*_, inner_jaxpr, **__):
return [out.aval for out in inner_jaxpr.outvars]

return transform_prim
24 changes: 13 additions & 11 deletions pennylane/transforms/core/transform_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,11 +349,12 @@ def add_transform(self, transform: TransformDispatcher, *targs, **tkwargs):
self.push_back(
TransformContainer(
transform.transform,
targs,
tkwargs,
transform.classical_cotransform,
transform.is_informative,
transform.final_transform,
args=targs,
kwargs=tkwargs,
classical_cotransform=transform.classical_cotransform,
plxpr_transform=transform.plxpr_transform,
is_informative=transform.is_informative,
final_transform=transform.final_transform,
)
)

Expand All @@ -376,11 +377,12 @@ def insert_front_transform(self, transform: TransformDispatcher, *targs, **tkwar
self.insert_front(
TransformContainer(
transform.transform,
targs,
tkwargs,
transform.classical_cotransform,
transform.is_informative,
transform.final_transform,
args=targs,
kwargs=tkwargs,
classical_cotransform=transform.classical_cotransform,
plxpr_transform=transform.plxpr_transform,
is_informative=transform.is_informative,
final_transform=transform.final_transform,
)
)

Expand Down Expand Up @@ -528,7 +530,7 @@ def __call__(
processing_fns_stack = []

for i, transform_container in enumerate(self):
transform, targs, tkwargs, cotransform, _, _ = transform_container
transform, targs, tkwargs, cotransform, _, _, _ = transform_container

execution_tapes = []
fns = []
Expand Down
Loading