Skip to content

Commit

Permalink
rm pass-through method
Browse files Browse the repository at this point in the history
  • Loading branch information
dominikandreasseitz committed Apr 8, 2024
1 parent 42d37fd commit d98789e
Showing 1 changed file with 3 additions and 11 deletions.
14 changes: 3 additions & 11 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,11 @@ class Circuit:
self.observable: list[Primitive] = [Z(0)]

@partial(vmap, in_axes=(None, None, 0))
def forward(self, param_values: Array, x: Array) -> Array:
def __call__(self, param_values: Array, x: Array) -> Array:
state = zero_state(self.n_qubits)
param_dict = {name: val for name, val in zip(self.param_names, param_values)}
return adjoint_expectation(state, self.feature_map + self.ansatz, self.observable, {**param_dict, **{'phi': x}})

def __call__(self, param_values: Array, x: Array) -> Array:
return self.forward(param_values, x)

@property
def n_vparams(self) -> int:
Expand Down Expand Up @@ -271,11 +269,8 @@ class TotalMagnetization:
def __post_init__(self) -> None:
self.paulis = [Z(i) for i in range(self.n_qubits)]

def forward(self, state: Array, values: dict) -> Array:
return reduce(add, [apply_gate(state, pauli, values) for pauli in self.paulis])

def __call__(self, state: Array, values: dict) -> Array:
return self.forward(state, values)
return reduce(add, [apply_gate(state, pauli, values) for pauli in self.paulis])


@dataclass
Expand All @@ -290,7 +285,7 @@ class Circuit:
self.ansatz, self.param_names = ansatz_w_params(self.n_qubits, self.n_layers)
self.observable = TotalMagnetization(self.n_qubits)

def forward(self, param_vals: Array, x: Array, y: Array) -> Array:
def __call__(self, param_vals: Array, x: Array, y: Array) -> Array:
state = zero_state(self.n_qubits)
param_dict = {name: val for name, val in zip(self.param_names, param_vals)}
out_state = apply_gate(
Expand All @@ -299,9 +294,6 @@ class Circuit:
projected_state = self.observable(state, param_dict)
return jnp.real(inner(out_state, projected_state))

def __call__(self, param_vals: Array, x: Array, y: Array) -> Array:
return self.forward(param_vals, x, y)

@property
def n_vparams(self) -> int:
return len(self.param_names)
Expand Down

0 comments on commit d98789e

Please sign in to comment.