Fix bit pack of mwpf and fusion blossom decoders under multiple logical observable (#873)

This PR fixes two bugs in the MWPF decoder; the bit-packing fix also applies to the fusion blossom decoder.

## 1. Supporting decomposed detector error models

While MWPF expects a decoding hypergraph, the detector error model (DEM) that sinter passes in is decomposed by default. A decomposed DEM may list the same detector or logical observable multiple times across the decomposed parts of a single error, which the previous implementation did not account for.
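
For illustration, the snippet below constructs a hypothetical decomposed error of this shape (not taken from the PR); `L0` appears in both decomposed parts, so the combined hyperedge toggles `L0` an even number of times, i.e. not at all:
```python
import stim

# Hypothetical example of a decomposed error whose two parts both mention L0.
# Interpreted as a single hyperedge, it triggers D0 and D1 and toggles L0
# twice, so it does NOT flip L0 overall.
dem = stim.DetectorErrorModel("""
    error(0.125) D0 L0 ^ D1 L0
""")
```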

The previous implementation assumed that each detector and logical
observable appears at most once, so I used
```python
frames: List[int] = []
...
frames.append(t.val)
```

However, this no longer works when the same frame appears in multiple
decomposed parts. In that case, the DEM actually means "the hyperedge
contributes to the logical observable iff count(frame) % 2 == 1". This
is fixed by
```python
frames: set[int] = set()
...
frames ^= { t.val }
```
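
As a minimal standalone sketch (not the PR's exact code), the set-XOR accumulation keeps a frame only when it appears an odd number of times:
```python
# Minimal sketch of the parity fix, assuming the observable indices have
# already been extracted from the decomposed error's targets.
def collect_frames(frame_ids: list[int]) -> list[int]:
    frames: set[int] = set()
    for val in frame_ids:
        frames ^= {val}  # toggle membership: kept iff count(val) % 2 == 1
    return sorted(frames)

assert collect_frames([0, 0]) == []      # L0 appears twice -> even -> no flip
assert collect_frames([0, 2, 0]) == [2]  # L0 cancels out, L2 is kept
```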

## 2. Supporting multiple logical observables

Although a previous [PR #864](#864) fixed the panic when multiple
logical observables are encountered, the returned value was still
problematic and caused a significantly higher logical error rate.

The previous implementation converted an `int`-typed bitmask to a
bit-packed value using `np.packbits(prediction, bitorder="little")`.
However, this does not work for more than one logical observable,
because `packbits` treats its input as an array of individual bits
rather than as an integer bitmask. For example, if I define an
observable using `OBSERVABLE_INCLUDE(2) ...`, the bit-packed value
should be `[4]` because $1<<2 = 4$. However,
`np.packbits(4, bitorder="little") = [1]`, which is incorrect.

The correct procedure is to first generate the binary representation
with `self.num_obs` bits using `np.binary_repr(prediction,
width=self.num_obs)`, in this case `'100'`, then reverse the order of
the bits to `['0', '0', '1']` so that index $i$ holds observable $i$,
and then run `np.packbits`, which gives us the correct value `[4]`.

The full code is below:
```python
predictions[shot] = np.packbits(
    np.array(list(np.binary_repr(prediction, width=self.num_obs))[::-1],dtype=np.uint8),
    bitorder="little",
)
```
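
As a quick sanity check of the procedure (using the `OBSERVABLE_INCLUDE(2)` example above, i.e. `prediction = 4` and `num_obs = 3`):
```python
import numpy as np

prediction, num_obs = 4, 3  # values from the OBSERVABLE_INCLUDE(2) example above

bits = np.array(list(np.binary_repr(prediction, width=num_obs))[::-1], dtype=np.uint8)
# bits == [0, 0, 1]: index i holds (prediction >> i) & 1
packed = np.packbits(bits, bitorder="little")
print(packed)  # [4], the correct bit-packed value for observable 2
```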
yuewuo authored Jan 31, 2025
1 parent 6afad14 commit 9811a6a
Showing 2 changed files with 39 additions and 54 deletions.
glue/sample/src/sinter/_decoding/_decoding_fusion_blossom.py (4 additions & 1 deletion)
@@ -31,7 +31,10 @@ def decode_shots_bit_packed(
syndrome = fusion_blossom.SyndromePattern(syndrome_vertices=dets_sparse)
self.solver.solve(syndrome)
prediction = int(np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()]))
predictions[shot] = np.packbits(prediction, bitorder='little')
predictions[shot] = np.packbits(
np.array(list(np.binary_repr(prediction, width=self.num_obs))[::-1],dtype=np.uint8),
bitorder="little",
)
self.solver.clear()
return predictions

glue/sample/src/sinter/_decoding/_decoding_mwpf.py (35 additions & 53 deletions)
@@ -38,7 +38,9 @@ def decode_shots_bit_packed(
bit_packed_detection_event_data: "np.ndarray",
) -> "np.ndarray":
num_shots = bit_packed_detection_event_data.shape[0]
predictions = np.zeros(shape=(num_shots, (self.num_obs + 7) // 8), dtype=np.uint8)
predictions = np.zeros(
shape=(num_shots, (self.num_obs + 7) // 8), dtype=np.uint8
)
import mwpf

for shot in range(num_shots):
@@ -58,29 +60,42 @@
np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()])
)
self.solver.clear()
predictions[shot] = np.packbits(prediction, bitorder="little")
predictions[shot] = np.packbits(
np.array(
list(np.binary_repr(prediction, width=self.num_obs))[::-1],
dtype=np.uint8,
),
bitorder="little",
)
return predictions


class MwpfDecoder(Decoder):
"""Use MWPF to predict observables from detection events."""

def compile_decoder_for_dem(
def __init__(
self,
*,
dem: "stim.DetectorErrorModel",
decoder_cls: Any = None, # decoder class used to construct the MWPF decoder.
# in the Rust implementation, all of them inherits from the class of `SolverSerialPlugins`
# but just provide different plugins for optimizing the primal and/or dual solutions.
# For example, `SolverSerialUnionFind` is the most basic solver without any plugin: it only
# grows the clusters until the first valid solution appears; some more optimized solvers uses
# one or more plugins to further optimize the solution, which requires longer decoding time.
cluster_node_limit: int = 50, # The maximum number of nodes in a cluster.
cluster_node_limit: int = 50, # The maximum number of nodes in a cluster,
):
self.decoder_cls = decoder_cls
self.cluster_node_limit = cluster_node_limit
super().__init__()

def compile_decoder_for_dem(
self,
*,
dem: "stim.DetectorErrorModel",
) -> CompiledDecoder:
solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
dem,
decoder_cls=decoder_cls,
cluster_node_limit=cluster_node_limit,
decoder_cls=self.decoder_cls,
cluster_node_limit=self.cluster_node_limit,
)
return MwpfCompiledDecoder(
solver,
@@ -99,13 +114,14 @@ def decode_via_files(
dets_b8_in_path: pathlib.Path,
obs_predictions_b8_out_path: pathlib.Path,
tmp_dir: pathlib.Path,
decoder_cls: Any = None,
) -> None:
import mwpf

error_model = stim.DetectorErrorModel.from_file(dem_path)
solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
error_model, decoder_cls=decoder_cls
error_model,
decoder_cls=self.decoder_cls,
cluster_node_limit=self.cluster_node_limit,
)
num_det_bytes = math.ceil(num_dets / 8)
with open(dets_b8_in_path, "rb") as dets_in_f:
@@ -136,44 +152,8 @@ def decode_via_files(


class HyperUFDecoder(MwpfDecoder):
def compile_decoder_for_dem(
self, *, dem: "stim.DetectorErrorModel"
) -> CompiledDecoder:
try:
import mwpf
except ImportError as ex:
raise mwpf_import_error() from ex

return super().compile_decoder_for_dem(
dem=dem, decoder_cls=mwpf.SolverSerialUnionFind
)

def decode_via_files(
self,
*,
num_shots: int,
num_dets: int,
num_obs: int,
dem_path: pathlib.Path,
dets_b8_in_path: pathlib.Path,
obs_predictions_b8_out_path: pathlib.Path,
tmp_dir: pathlib.Path,
) -> None:
try:
import mwpf
except ImportError as ex:
raise mwpf_import_error() from ex

return super().decode_via_files(
num_shots=num_shots,
num_dets=num_dets,
num_obs=num_obs,
dem_path=dem_path,
dets_b8_in_path=dets_b8_in_path,
obs_predictions_b8_out_path=obs_predictions_b8_out_path,
tmp_dir=tmp_dir,
decoder_cls=mwpf.SolverSerialUnionFind,
)
def __init__(self):
super().__init__(decoder_cls="SolverSerialUnionFind", cluster_node_limit=0)


def iter_flatten_model(
@@ -193,16 +173,16 @@ def _helper(m: stim.DetectorErrorModel, reps: int):
_helper(instruction.body_copy(), instruction.repeat_count)
elif isinstance(instruction, stim.DemInstruction):
if instruction.type == "error":
dets: List[int] = []
frames: List[int] = []
dets: set[int] = set()
frames: set[int] = set()
t: stim.DemTarget
p = instruction.args_copy()[0]
for t in instruction.targets_copy():
if t.is_relative_detector_id():
dets.append(t.val + det_offset)
dets ^= {t.val + det_offset}
elif t.is_logical_observable_id():
frames.append(t.val)
handle_error(p, dets, frames)
frames ^= {t.val}
handle_error(p, list(dets), list(frames))
elif instruction.type == "shift_detectors":
det_offset += instruction.targets_copy()[0]
a = np.array(instruction.args_copy())
@@ -310,6 +290,8 @@ def handle_detector_coords(detector: int, coords: np.ndarray):
if decoder_cls is None:
# default to the solver with highest accuracy
decoder_cls = mwpf.SolverSerialJointSingleHair
elif isinstance(decoder_cls, str):
decoder_cls = getattr(mwpf, decoder_cls)
return (
(
decoder_cls(initializer, config={"cluster_node_limit": cluster_node_limit})
