Skip to content

Commit

Permalink
make unitary private
Browse files Browse the repository at this point in the history
  • Loading branch information
Charles MOUSSA committed Feb 10, 2025
1 parent eafdb21 commit c66b6fd
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 15 deletions.
2 changes: 1 addition & 1 deletion horqrux/analog.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class _HamiltonianEvolution(Primitive):
target: QubitSupport
control: QubitSupport

def unitary(self, values: dict[str, Array] = dict()) -> Array:
def _unitary(self, values: dict[str, Array] = dict()) -> Array:
return expm(values["hamiltonian"] * (-1j * values["time_evolution"]))


Expand Down
4 changes: 2 additions & 2 deletions horqrux/parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __iter__(self) -> Iterable:
def tree_unflatten(cls, aux_data: Any, children: Any) -> Any:
return cls(*children, *aux_data)

def unitary(self, values: dict[str, float] = dict()) -> Array:
def _unitary(self, values: dict[str, float] = dict()) -> Array:
return _unitary(OPERATIONS_DICT[self.generator_name], self.parse_values(values))

def jacobian(self, values: dict[str, float] = dict()) -> Array:
Expand Down Expand Up @@ -141,7 +141,7 @@ def RZ(


class _PHASE(Parametric):
def unitary(self, values: dict[str, float] = dict()) -> Array:
def _unitary(self, values: dict[str, float] = dict()) -> Array:
u = jnp.eye(2, 2, dtype=default_dtype)
u = u.at[(1, 1)].set(jnp.exp(1.0j * self.parse_values(values)))
return u
Expand Down
6 changes: 3 additions & 3 deletions horqrux/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def tree_flatten(self) -> tuple[tuple, tuple[str, TargetQubits, ControlQubits, N
def tree_unflatten(cls, aux_data: Any, children: Any) -> Any:
return cls(*children, *aux_data)

def unitary(self, values: dict[str, float] = dict()) -> Array:
def _unitary(self, values: dict[str, float] = dict()) -> Array:
"""Obtain the base unitary from `generator_name`.
Args:
Expand All @@ -81,7 +81,7 @@ def dagger(self, values: dict[str, float] = dict()) -> Array:
Returns:
Array: The base unitary daggered from `generator_name`.
"""
return _dagger(self.unitary(values))
return _dagger(self._unitary(values))

def tensor(self, values: dict[str, float] = dict()) -> Array:
"""Obtain the unitary taking into account the qubit support for controlled operations.
Expand All @@ -92,7 +92,7 @@ def tensor(self, values: dict[str, float] = dict()) -> Array:
Returns:
Array: Unitary representation taking into account the qubit support.
"""
base_unitary = self.unitary(values)
base_unitary = self._unitary(values)
if is_controlled(self.control):
return controlled(base_unitary, self.target, self.control)
return base_unitary
Expand Down
2 changes: 1 addition & 1 deletion horqrux/shots.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def to_matrix(
observable.control == observable.parse_idx(none_like(observable.target)),
"Controlled gates cannot be promoted from observables to operations on the whole state vector",
)
unitary = observable.unitary(values=values)
unitary = observable._unitary(values=values)
target = observable.target[0][0]
identity = jnp.eye(2, dtype=unitary.dtype)
ops = [identity for _ in range(n_qubits)]
Expand Down
2 changes: 1 addition & 1 deletion horqrux/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def list(cls) -> list[str]:


class OperationType(StrEnum):
UNITARY = "unitary"
UNITARY = "_unitary"
DAGGER = "dagger"
JACOBIAN = "jacobian"

Expand Down
24 changes: 17 additions & 7 deletions tests/test_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from horqrux.apply import apply_gate, apply_operator
from horqrux.parametric import PHASE, RX, RY, RZ
from horqrux.primitive import NOT, SWAP, H, I, S, T, X, Y, Z
from horqrux.utils import density_mat, equivalent_state, product_state, random_state
from horqrux.utils import OperationType, density_mat, equivalent_state, product_state, random_state

MAX_QUBITS = 7
PARAMETRIC_GATES = (RX, RY, RZ, PHASE)
Expand All @@ -31,7 +31,7 @@ def test_primitive(gate_fn: Callable) -> None:
# test density matrix is similar to pure state
dm = apply_operator(
density_mat(orig_state),
gate.unitary(),
gate._unitary(),
gate.target[0],
gate.control[0],
)
Expand All @@ -54,7 +54,7 @@ def test_controlled_primitive(gate_fn: Callable) -> None:
# test density matrix is similar to pure state
dm = apply_operator(
density_mat(orig_state),
gate.unitary(),
gate._unitary(),
gate.target[0],
gate.control[0],
)
Expand All @@ -75,7 +75,7 @@ def test_parametric(gate_fn: Callable) -> None:
# test density matrix is similar to pure state
dm = apply_operator(
density_mat(orig_state),
gate.unitary(values),
gate._unitary(values),
gate.target[0],
gate.control[0],
)
Expand All @@ -99,7 +99,7 @@ def test_controlled_parametric(gate_fn: Callable) -> None:
# test density matrix is similar to pure state
dm = apply_operator(
density_mat(orig_state),
gate.unitary(values),
gate._unitary(values),
gate.target[0],
gate.control[0],
)
Expand Down Expand Up @@ -149,10 +149,20 @@ def test_merge_gates() -> None:
"c": np.random.uniform(0.1, 2 * np.pi),
}
state_grouped = apply_gate(
product_state("0000"), gates, values, "unitary", group_gates=True, merge_ops=True
product_state("0000"),
gates,
values,
OperationType.UNITARY,
group_gates=True,
merge_ops=True,
)
state = apply_gate(
product_state("0000"), gates, values, "unitary", group_gates=False, merge_ops=False
product_state("0000"),
gates,
values,
OperationType.UNITARY,
group_gates=False,
merge_ops=False,
)
assert jnp.allclose(state_grouped, state)

Expand Down

0 comments on commit c66b6fd

Please sign in to comment.