Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add autograph support for for loops #6426

Merged
merged 96 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
96 commits
Select commit Hold shift + click to select a range
5dfc38f
add user decompose function
andrijapau Oct 3, 2024
f1d9eee
code factor
andrijapau Oct 3, 2024
b24c97b
update changelog
andrijapau Oct 3, 2024
0d3bfb6
improve decomposition function
andrijapau Oct 4, 2024
0d66209
codefactor fix
andrijapau Oct 4, 2024
a25d6ba
added modified generator
andrijapau Oct 4, 2024
2ad96ac
copy code from catalyst
lillian542 Oct 4, 2024
c471c0b
remove catalyst from __init__ and transformer files
lillian542 Oct 4, 2024
49c0fac
replace many catalyst dependencies/mentions
lillian542 Oct 7, 2024
5dd1d42
copy tests from catalyst
lillian542 Oct 7, 2024
c94da87
update imports
lillian542 Oct 7, 2024
ab8ae7a
[ci skip]
lillian542 Oct 8, 2024
ac7d53e
Merge branch 'master' into add_autograph
lillian542 Oct 8, 2024
4a6bdf3
import autograph module in capture
lillian542 Oct 8, 2024
2c2f1dd
remove final catalyst dependencies from ag_primitives
lillian542 Oct 8, 2024
9b18b97
record queueing in FlatFn so we can access program_length in ag_primi…
lillian542 Oct 8, 2024
cf92f2a
most conditional tests work
lillian542 Oct 9, 2024
7a21764
initial removal of queueing dependency
lillian542 Oct 10, 2024
72b3e42
the current state of the tests
lillian542 Oct 10, 2024
fee16ff
remove index setting and logical operators
lillian542 Oct 11, 2024
3ccd0e0
formatting
lillian542 Oct 11, 2024
1ab03f3
remove fallback
lillian542 Oct 12, 2024
2bfaf72
remove disable_autograph
lillian542 Oct 12, 2024
42ce1e8
remove autograph_include
lillian542 Oct 12, 2024
d3a807c
Merge branch 'master' into autograph_ctrl_flow
lillian542 Oct 12, 2024
fa0b415
remove strict_conversion and ignore_fallbacks
lillian542 Oct 12, 2024
54b3036
Tidying up
lillian542 Oct 12, 2024
26cbde4
Tidying up more
lillian542 Oct 12, 2024
211bd81
fix more tests
lillian542 Oct 16, 2024
6be1ff5
Merge branch 'autograph_ctrl_flow' of github.com:PennyLaneAI/pennylan…
lillian542 Oct 16, 2024
75b0c3c
only cond in ag_primitives
lillian542 Oct 16, 2024
5bfecb4
clean up transformer and utils
lillian542 Oct 16, 2024
ac51fad
re-organize tests
lillian542 Oct 16, 2024
d4d7220
remove utils and clean up docstrings
lillian542 Oct 16, 2024
207ecc6
update changelog
lillian542 Oct 16, 2024
d2f6e34
add malt as dependency of PL
lillian542 Oct 16, 2024
da0cd1d
fix failing test
lillian542 Oct 16, 2024
cea27f6
Merge branch 'master' into autograph1
lillian542 Oct 16, 2024
9da4afa
package name is diastatic-malt
lillian542 Oct 16, 2024
fbff208
Merge branch 'autograph1' of github.com:PennyLaneAI/pennylane into au…
lillian542 Oct 16, 2024
e443173
fix decorator test
lillian542 Oct 16, 2024
2ee18ca
rename test file to avoid CI confusion
lillian542 Oct 16, 2024
3ab7c9c
add while_loop implementation
lillian542 Oct 16, 2024
900cfe1
add test file
lillian542 Oct 17, 2024
fc59744
a few more tests and docstrings updates
lillian542 Oct 17, 2024
a8946af
one more test for code coverage
lillian542 Oct 17, 2024
908d6b1
Merge branch 'master' into autograph1
lillian542 Oct 17, 2024
60cb38b
Update pennylane/capture/autograph/ag_primitives.py
lillian542 Oct 17, 2024
9c8c74d
add initial tests
lillian542 Oct 17, 2024
bfc39b2
Merge branch 'autograph1' into autograph_while_loop
lillian542 Oct 17, 2024
3972b89
some small test changes
lillian542 Oct 17, 2024
ef2d503
update changelog
lillian542 Oct 21, 2024
19130e9
xfail test that includes for loop
lillian542 Oct 21, 2024
4b73551
add for loop support and tests
lillian542 Oct 21, 2024
d5eaa8d
Merge branch 'master' into autograph1
lillian542 Oct 28, 2024
7da5662
Merge branch 'autograph1' into autograph_while_loop
lillian542 Oct 28, 2024
a03acf8
Merge branch 'master' into autograph1
lillian542 Nov 21, 2024
1a01dea
Apply suggestions from code review
lillian542 Nov 21, 2024
42c91f9
use inner_args to avoid taken arguments
lillian542 Nov 21, 2024
78307b5
update copyright year
lillian542 Nov 21, 2024
e36fa65
use functools.wraps
lillian542 Nov 21, 2024
998ce1a
replace qjit example with pl example
lillian542 Nov 21, 2024
83a934e
add import path for run_autograph, autograph_source
lillian542 Nov 21, 2024
36267d8
change import structure + update example
lillian542 Nov 21, 2024
c06e768
fix a couple docstring mistakes
lillian542 Nov 21, 2024
d6fae5d
Merge branch 'autograph1' into autograph_while_loop
lillian542 Nov 21, 2024
a63009d
Apply suggestions from code review
lillian542 Nov 21, 2024
793e7f4
remove unneeded check
lillian542 Nov 26, 2024
1f1f097
small test fixes
lillian542 Nov 26, 2024
705eb35
Merge branch 'autograph_while_loop' of github.com:PennyLaneAI/pennyla…
lillian542 Nov 26, 2024
c387988
Merge branch 'master' into autograph1
lillian542 Nov 26, 2024
84ca81b
Merge branch 'autograph1' into autograph_while_loop
lillian542 Nov 26, 2024
e8f23ab
Merge branch 'autograph_while_loop' into autograph_for_loop
lillian542 Nov 26, 2024
ed9ea11
remove source_info function
lillian542 Nov 26, 2024
0faba0e
reoraganize and update tests
lillian542 Nov 26, 2024
4c5d969
clean up error msgs and docstrings
lillian542 Nov 26, 2024
b981ade
update tests
lillian542 Nov 26, 2024
6432807
update changelog
lillian542 Nov 26, 2024
3d6d40e
pylint complaint
lillian542 Nov 26, 2024
f7dbeda
Apply suggestions from code review
lillian542 Nov 27, 2024
9692d15
Update doc/releases/changelog-dev.md
lillian542 Nov 27, 2024
9bf9813
Apply suggestions from code review
lillian542 Nov 27, 2024
2df11e8
Apply suggestions from code review
lillian542 Nov 27, 2024
fd10c80
Update pennylane/capture/autograph/ag_primitives.py
lillian542 Nov 28, 2024
897d11f
Merge branch 'master' into autograph_while_loop
lillian542 Nov 30, 2024
c1660ed
add test that for_stmt redirects to pennylane implementation
lillian542 Nov 30, 2024
94b6b90
Merge branch 'autograph_while_loop' into autograph_for_loop
lillian542 Nov 30, 2024
3f4ace3
Merge branch 'master' into autograph_for_loop
lillian542 Nov 30, 2024
af9e5ae
black formatting
lillian542 Nov 30, 2024
8eec7f8
pylint
lillian542 Nov 30, 2024
e5d76fa
add space between no and cover in pragma: no cover
lillian542 Dec 3, 2024
d5158d8
Merge branch 'master' into autograph_for_loop
lillian542 Dec 3, 2024
b95e4af
trigger ci
lillian542 Dec 3, 2024
3c58be3
Merge branch 'master' into autograph_for_loop
lillian542 Dec 3, 2024
80b3a6e
also change the one I missed
lillian542 Dec 3, 2024
12e47b8
Merge branch 'autograph_for_loop' of github.com:PennyLaneAI/pennylane…
lillian542 Dec 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,10 @@ added `binary_mapping()` function to map `BoseWord` and `BoseSentence` to qubit
* Added submodule `devices.qubit_mixed.measure` as a necessary step for the new API, featuring a `measure` function for measuring qubits in mixed-state devices.
[(#6637)](https://github.com/PennyLaneAI/pennylane/pull/6637)

* Support is added for `if`/`else` statements and `while` loops in circuits executed with `qml.capture.enabled`, via `autograph`.
* Support is added for `if`/`else` statements and `for` and `while` loops in circuits executed with `qml.capture.enabled`, via `autograph`
[(#6406)](https://github.com/PennyLaneAI/pennylane/pull/6406)
[(#6413)](https://github.com/PennyLaneAI/pennylane/pull/6413)
[(#6426)](https://github.com/PennyLaneAI/pennylane/pull/6426)

* Added `christiansen_mapping()` function to map `BoseWord` and `BoseSentence` to qubit operators, using christiansen mapping.
[(#6623)](https://github.com/PennyLaneAI/pennylane/pull/6623)
Expand Down
245 changes: 242 additions & 3 deletions pennylane/capture/autograph/ag_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,27 @@
"""
import copy
import functools
from typing import Any, Callable, Tuple
from typing import Any, Callable, Iterator, SupportsIndex, Tuple, Union

from malt.core import config as ag_config
from malt.impl import api as ag_api
from malt.impl.api import converted_call as ag_converted_call
from malt.operators import py_builtins as ag_py_builtins
from malt.operators.variables import Undefined

import pennylane as qml

has_jax = True
try:
import jax
import jax.numpy as jnp
except ImportError: # pragma: no cover
has_jax = False


__all__ = [
"if_stmt",
"for_stmt",
"while_stmt",
"converted_call",
]
Expand All @@ -58,8 +61,6 @@ def _assert_results(results, var_names):


# pylint: disable=too-many-arguments, too-many-positional-arguments


def if_stmt(
pred: bool,
true_fn: Callable[[], Any],
Expand Down Expand Up @@ -137,6 +138,157 @@ def _assert_iteration_inputs(inputs, symbol_names):
) from e


def _assert_iteration_results(inputs, outputs, symbol_names):
"""The results of a for loop should have the identical type as the inputs since they are
"passed" as inputs to the next iteration. A mismatch here may indicate that a loop-carried
variable was initialized with the wrong type.
"""

for i, (inp, out) in enumerate(zip(inputs, outputs)):
inp_t, out_t = jax.api_util.shaped_abstractify(inp), jax.api_util.shaped_abstractify(out)
if inp_t.dtype != out_t.dtype or inp_t.shape != out_t.shape:
raise AutoGraphError(
f"The variable '{symbol_names[i]}' was initialized with the wrong type, or you may "
f"be trying to change its type from one iteration to the next. "
f"Expected: {out_t}, Got: {inp_t}"
)


# pylint: disable=too-many-positional-arguments
def _call_pennylane_for(
start,
stop,
step,
body_fn,
get_state,
set_state,
symbol_names,
enum_start=None,
array_iterable=None,
):
"""Dispatch to a PennyLane implementation of for loops."""

# Ensure iteration arguments are properly initialized. We cannot process uninitialized
# loop carried values as we need their type information for tracing.
init_iter_args = get_state()
_assert_iteration_inputs(init_iter_args, symbol_names)

@qml.for_loop(start, stop, step)
def functional_for(i, *iter_args):
# Assign tracers to the iteration variables identified by AutoGraph (iter_args in mlir).
set_state(iter_args)

# The iteration index/element (for <...> in) is already handled by the body function, e.g.:
# def body_fn(itr):
# i, x = itr
# ...
if enum_start is None and array_iterable is None:
# for i in range(..)
body_fn(i)
elif enum_start is None:
# for x in array
body_fn(array_iterable[i])
else:
# for (i, x) in enumerate(array)
body_fn((i + enum_start, array_iterable[i]))

return get_state()

final_iter_args = functional_for(*init_iter_args)
_assert_iteration_results(init_iter_args, final_iter_args, symbol_names)
return final_iter_args


# pylint: disable=too-many-statements
def for_stmt(
iteration_target: Any,
_extra_test: Union[Callable[[], bool], None],
body_fn: Callable[[int], None],
get_state: Callable[[], Tuple],
set_state: Callable[[Tuple], None],
symbol_names: Tuple[str],
_opts: dict,
):
"""An implementation of the AutoGraph 'for .. in ..' statement. The interface is defined by
AutoGraph, here we merely provide an implementation of it in terms of PennyLane primitives."""

assert _extra_test is None

# The general approach is to convert as much code as possible into a graph-based form:
# - For loops over iterables will attempt a conversion of the iterable to array
# - For loops over a Python range will be converted to a native PennyLane for_loop. The now
# dynamic iteration variable can cause issues in downstream user code that raise an error.
# - For loops over a Python enumeration use a combination of the above, providing a dynamic
# iteration variable and conversion of the iterable to array.

# Any of these could fail depending on the compatibility of the user code. A failure could
# also occur because an exception is raised during the tracing of the loop body after conversion
# (for example because the user forgot to use a list instead of an array)
# The PennyLane autograph implementation does not currently fall back to a Python loop in this case,
# but this has been implemented in Catalyst and could be extended to this. It does, however, require an
# active qeueing context.

exception_raised = None
init_state = get_state()
assert len(init_state) == len(symbol_names)

if isinstance(iteration_target, PRange):
start, stop, step = iteration_target.get_raw_range()
enum_start = None
iteration_array = None
elif isinstance(iteration_target, PEnumerate):
start, stop, step = 0, len(iteration_target.iteration_target), 1
enum_start = iteration_target.start_idx
try:
iteration_array = jnp.asarray(iteration_target.iteration_target)
except Exception as e: # pylint: disable=bare-except, broad-exception-caught, broad-except
exception_raised = e
else:
start, stop, step = 0, len(iteration_target), 1
enum_start = None
try:
iteration_array = jnp.asarray(iteration_target)
except Exception as e: # pylint: disable=bare-except, broad-exception-caught, broad-except
exception_raised = e

if exception_raised:

raise AutoGraphError(
f"Could not convert the iteration target {iteration_target} to array while processing "
f"a for-loop with AutoGraph."
) from exception_raised

try:
set_state(init_state)
results = _call_pennylane_for(
start,
stop,
step,
body_fn,
get_state,
set_state,
symbol_names,
enum_start,
iteration_array,
)
except Exception as e: # pylint: disable=broad-exception-caught
# pylint: disable=import-outside-toplevel
import textwrap

raise AutoGraphError(
f"Tracing of an AutoGraph converted for loop failed with an exception:\n"
f" {type(e).__name__}:{textwrap.indent(str(e), ' ')}\n"
f"\n"
f"Make sure that loop variables are not used in tracing-incompatible ways, for instance "
f"by indexing a Python list with it (rather than a JAX array). Also ensure all variables "
f"are initialized before the loop begins, and that they don't change type across iterations.\n"
f"To understand different types of JAX tracing errors, please refer to the guide at: "
f"https://jax.readthedocs.io/en/latest/errors.html"
) from e

set_state(results)


def _call_pennylane_while(loop_test, loop_body, get_state, set_state, symbol_names):
"""Dispatch to a PennyLane implementation of while loops."""

Expand Down Expand Up @@ -226,6 +378,7 @@ def converted_call(fn, args, kwargs, caller_fn_scope=None, options=None):
with Patcher(
(ag_api, "_TRANSPILER", qml.capture.autograph.transformer.TRANSFORMER),
(ag_config, "CONVERSION_RULES", module_allowlist),
(ag_py_builtins, "BUILTIN_FUNCTIONS_MAP", py_builtins_map),
):
# HOTFIX: pass through calls of known PennyLane wrapper functions
if fn in (
Expand Down Expand Up @@ -264,3 +417,89 @@ def qnode_call_wrapper():
return new_qnode()

return ag_converted_call(fn, args, kwargs, caller_fn_scope, options)


class PRange:
"""PennyLane range object. This class re-implements the built-in range class
(which can't be inherited from). The only change is saving and accessing the
inputs directly, to circumvent some JAX-unfriendly code in the Python range.
"""

def __init__(self, start_stop, stop=None, step=None):
self._py_range = None
self._start = start_stop if stop is not None else 0
self._stop = stop if stop is not None else start_stop
self._step = step if step is not None else 1

def get_raw_range(self):
"""Get the raw values defining this range: start, stop, step."""
return self._start, self._stop, self._step

@property
def py_range(self):
"""Access the underlying Python range object. If it doesn't exist, create one."""
if self._py_range is None:
self._py_range = range(self._start, self._stop, self._step)
return self._py_range

# Interface of the Python range class.
# pylint: disable=missing-function-docstring

@property
def start(self) -> int: # pragma: no cover
return self.py_range.start

@property
def stop(self) -> int: # pragma: no cover
return self.py_range.stop

@property
def step(self) -> int: # pragma: no cover
return self.py_range.step

def count(self, __value: int) -> int: # pragma: no cover
return self.py_range.count(__value)

def index(self, __value: int) -> int: # pragma: no cover
return self.py_range.index(__value)

def __len__(self) -> int: # pragma: no cover
return self.py_range.__len__()

def __eq__(self, __value: object) -> bool: # pragma: no cover
return self.py_range.__eq__(__value)

def __hash__(self) -> int: # pragma: no cover
return self.py_range.__hash__()

def __contains__(self, __key: object) -> bool: # pragma: no cover
return self.py_range.__contains__(__key)

def __iter__(self) -> Iterator[int]: # pragma: no cover
return self.py_range.__iter__()

def __getitem__(
self, __key: Union[SupportsIndex, slice]
) -> Union[int, range]: # pragma: no cover
return self.py_range.__getitem__(__key)

def __reversed__(self) -> Iterator[int]: # pragma: no cover
return self.py_range.__reversed__()


# pylint: disable=too-few-public-methods, super-init-not-called
class PEnumerate(enumerate):
"""PennyLane enumeration object. Inherits from Python ``enumerate``, but adds storing the
input iteration_target and start_idx, which are used by the for-loop conversion.
"""

def __init__(self, iterable, start=0):
self.iteration_target = iterable
self.start_idx = start


py_builtins_map = {
**ag_py_builtins.BUILTIN_FUNCTIONS_MAP,
"range": PRange,
"enumerate": PEnumerate,
}
1 change: 1 addition & 0 deletions tests/capture/autograph/test_autograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def test_get_extra_locals(self):

assert ag_fn_dict["if_stmt"].__module__ == "pennylane.capture.autograph.ag_primitives"
assert ag_fn_dict["while_stmt"].__module__ == "pennylane.capture.autograph.ag_primitives"
assert ag_fn_dict["for_stmt"].__module__ == "pennylane.capture.autograph.ag_primitives"
assert (
ag_fn_dict["converted_call"].__module__ == "pennylane.capture.autograph.ag_primitives"
)
Expand Down
Loading
Loading