Skip to content

Commit

Permalink
Update Stim backend to support conditionals and mid-circuit measureme…
Browse files Browse the repository at this point in the history
…nts (#2270)

* Update Stim backend to support conditionals and mid-circuit measurements

Signed-off-by: Ben Howe <[email protected]>

* Slight optimization: transpose `sample`

Signed-off-by: Ben Howe <[email protected]>

---------

Signed-off-by: Ben Howe <[email protected]>
  • Loading branch information
bmhowe23 authored Oct 17, 2024
1 parent 5fbbfcf commit 71abfc7
Show file tree
Hide file tree
Showing 11 changed files with 163 additions and 33 deletions.
10 changes: 9 additions & 1 deletion python/tests/kernel/test_kernel_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,15 @@ def kernel(theta: float):
assert np.isclose(want_exp, -1.13, atol=1e-2)


def test_dynamic_circuit():
@pytest.mark.parametrize('target', ['default', 'stim'])
def test_dynamic_circuit(target):
"""Test that we correctly sample circuits with
mid-circuit measurements and conditionals."""

if target == 'stim':
save_target = cudaq.get_target()
cudaq.set_target('stim')

@cudaq.kernel
def simple():
q = cudaq.qvector(2)
Expand Down Expand Up @@ -297,6 +302,9 @@ def simple():
assert '0' in c0 and '1' in c0
assert '00' in counts and '11' in counts

if target == 'stim':
cudaq.set_target(save_target)


def test_teleport():

Expand Down
154 changes: 124 additions & 30 deletions runtime/nvqir/stim/StimCircuitSimulator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,88 @@ namespace nvqir {
/// https://github.com/quantumlib/Stim.
class StimCircuitSimulator : public nvqir::CircuitSimulatorBase<double> {
protected:
stim::Circuit stimCircuit;
// Follow Stim naming convention (W) for bit width (required for templates).
static constexpr std::size_t W = stim::MAX_BITWORD_WIDTH;

/// @brief Number of measurements performed so far.
std::size_t num_measurements = 0;

/// @brief Top-level random engine. Stim simulator RNGs are based off of this
/// engine.
std::mt19937_64 randomEngine;

/// @brief Stim Tableau simulator (noiseless)
std::unique_ptr<stim::TableauSimulator<W>> tableau;

/// @brief Stim Frame/Flip simulator (used to generate multiple shots)
std::unique_ptr<stim::FrameSimulator<W>> sampleSim;

/// @brief Grow the state vector by one qubit.
void addQubitToState() override { addQubitsToState(1); }

/// @brief Get the batch size to use for the Stim sample simulator.
std::size_t getBatchSize() {
// Default to single shot
std::size_t batch_size = 1;
if (getExecutionContext() && getExecutionContext()->name == "sample" &&
!getExecutionContext()->hasConditionalsOnMeasureResults)
batch_size = getExecutionContext()->shots;
return batch_size;
}

/// @brief Override the default sized allocation of qubits
/// here to be a bit more efficient than the default implementation
void addQubitsToState(std::size_t qubitCount,
const void *stateDataIn = nullptr) override {
if (stateDataIn)
throw std::runtime_error("The Stim simulator does not support "
"initialization of qubits from state data.");
return;

if (!tableau) {
cudaq::info("Creating new Stim Tableau simulator");
// Bump the randomEngine before cloning and giving to the Tableau
// simulator.
randomEngine.discard(
std::uniform_int_distribution<int>(1, 30)(randomEngine));
tableau = std::make_unique<stim::TableauSimulator<W>>(
std::mt19937_64(randomEngine), /*num_qubits=*/0, /*sign_bias=*/+0);
}
if (!sampleSim) {
auto batch_size = getBatchSize();
cudaq::info("Creating new Stim frame simulator with batch size {}",
batch_size);
// Bump the randomEngine before cloning and giving to the sample
// simulator.
randomEngine.discard(
std::uniform_int_distribution<int>(1, 30)(randomEngine));
sampleSim = std::make_unique<stim::FrameSimulator<W>>(
stim::CircuitStats(),
stim::FrameSimulatorMode::STORE_MEASUREMENTS_TO_MEMORY, batch_size,
std::mt19937_64(randomEngine));
sampleSim->reset_all();
}
}

/// @brief Reset the qubit state.
void deallocateStateImpl() override { stimCircuit.clear(); }
void deallocateStateImpl() override {
tableau.reset();
// Update the randomEngine so that future invocations will use the updated
// RNG state.
if (sampleSim)
randomEngine = std::move(sampleSim->rng);
sampleSim.reset();
num_measurements = 0;
}

/// @brief Apply operation to all Stim simulators.
void applyOpToSims(const std::string &gate_name,
const std::vector<uint32_t> &targets) {
stim::Circuit tempCircuit;
cudaq::info("Calling applyOpToSims {} - {}", gate_name, targets);
tempCircuit.safe_append_u(gate_name, targets);
tableau->safe_do_circuit(tempCircuit);
sampleSim->safe_do_circuit(tempCircuit);
}

/// @brief Apply the noise channel on \p qubits
void applyNoiseChannel(const std::string_view gateName,
Expand Down Expand Up @@ -78,19 +142,21 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase<double> {
cudaq::info("Applying {} kraus channels to qubits {}", krausChannels.size(),
stimTargets);

stim::Circuit noiseOps;
for (auto &channel : krausChannels) {
if (channel.noise_type == cudaq::noise_model_type::bit_flip_channel)
stimCircuit.safe_append_ua("X_ERROR", stimTargets,
channel.parameters[0]);
noiseOps.safe_append_ua("X_ERROR", stimTargets, channel.parameters[0]);
else if (channel.noise_type ==
cudaq::noise_model_type::phase_flip_channel)
stimCircuit.safe_append_ua("Z_ERROR", stimTargets,
channel.parameters[0]);
noiseOps.safe_append_ua("Z_ERROR", stimTargets, channel.parameters[0]);
else if (channel.noise_type ==
cudaq::noise_model_type::depolarization_channel)
stimCircuit.safe_append_ua("DEPOLARIZE1", stimTargets,
channel.parameters[0]);
noiseOps.safe_append_ua("DEPOLARIZE1", stimTargets,
channel.parameters[0]);
}
// Only apply the noise operations to the sample simulator (not the Tableau
// simulator).
sampleSim->safe_do_circuit(noiseOps);
}

void applyGate(const GateApplicationTask &task) override {
Expand Down Expand Up @@ -119,7 +185,7 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase<double> {
for (auto t : task.targets)
stimTargets.push_back(t);
try {
stimCircuit.safe_append_u(gateName, stimTargets);
applyOpToSims(gateName, stimTargets);
} catch (std::out_of_range &e) {
throw std::runtime_error(
fmt::format("Gate not supported by Stim simulator: {}. Note that "
Expand All @@ -137,14 +203,31 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase<double> {
return 0;
}

/// @brief Measure the qubit and return the result. Collapse the
/// state vector.
bool measureQubit(const std::size_t index) override { return false; }
/// @brief Measure the qubit and return the result.
bool measureQubit(const std::size_t index) override {
// Perform measurement
applyOpToSims(
"M", std::vector<std::uint32_t>{static_cast<std::uint32_t>(index)});
num_measurements++;

// Get the tableau bit that was just generated.
const std::vector<bool> &v = tableau->measurement_record.storage;
const bool tableauBit = *v.crbegin();

// Get the mid-circuit sample to be XOR-ed with tableauBit.
bool sampleSimBit =
sampleSim->m_record.storage[num_measurements - 1][/*shot=*/0];

// Calculate the result.
bool result = tableauBit ^ sampleSimBit;

return result;
}

QubitOrdering getQubitOrdering() const override { return QubitOrdering::msb; }

public:
StimCircuitSimulator() {
StimCircuitSimulator() : randomEngine(std::random_device{}()) {
// Populate the correct name so it is printed correctly during
// deconstructor.
summaryData.name = name();
Expand All @@ -162,26 +245,38 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase<double> {
void resetQubit(const std::size_t index) override {
flushGateQueue();
flushAnySamplingTasks();
stimCircuit.safe_append_u(
applyOpToSims(
"R", std::vector<std::uint32_t>{static_cast<std::uint32_t>(index)});
}

/// @brief Sample the multi-qubit state.
cudaq::ExecutionResult sample(const std::vector<std::size_t> &qubits,
const int shots) override {
assert(shots <= sampleSim->batch_size);
std::vector<std::uint32_t> stimTargetQubits(qubits.begin(), qubits.end());
stimCircuit.safe_append_u("M", stimTargetQubits);
if (false) {
std::stringstream ss;
ss << stimCircuit << '\n';
cudaq::log("Stim circuit is\n{}", ss.str());
}
auto ref_sample = stim::TableauSimulator<
stim::MAX_BITWORD_WIDTH>::reference_sample_circuit(stimCircuit);
stim::simd_bit_table<stim::MAX_BITWORD_WIDTH> sample =
stim::sample_batch_measurements(stimCircuit, ref_sample, shots,
randomEngine, false);
size_t bits_per_sample = stimCircuit.count_measurements();
applyOpToSims("M", stimTargetQubits);
num_measurements += stimTargetQubits.size();

// Generate a reference sample
const std::vector<bool> &v = tableau->measurement_record.storage;
stim::simd_bits<W> ref(v.size());
for (size_t k = 0; k < v.size(); k++)
ref[k] ^= v[k];

// Now XOR results on a per-shot basis
stim::simd_bit_table<W> sample = sampleSim->m_record.storage;
auto nShots = sampleSim->batch_size;

// This is a slightly modified version of `sample_batch_measurements`, where
// we already have the `sample` from the frame simulator. It also places the
// `sample` in a layout amenable to the order of the loops below (shot
// major).
sample = sample.transposed();
if (ref.not_zero())
for (size_t s = 0; s < nShots; s++)
sample[s].word_range_ref(0, ref.num_simd_words) ^= ref;

size_t bits_per_sample = num_measurements;
std::vector<std::string> sequentialData;
// Only retain the final "qubits.size()" measurements. All other
// measurements were mid-circuit measurements that have been previously
Expand All @@ -191,9 +286,8 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase<double> {
CountsDictionary counts;
for (std::size_t shot = 0; shot < shots; shot++) {
std::string aShot(qubits.size(), '0');
for (std::size_t b = first_bit_to_save; b < bits_per_sample; b++) {
aShot[b - first_bit_to_save] = sample[b][shot] ? '1' : '0';
}
for (std::size_t b = first_bit_to_save; b < bits_per_sample; b++)
aShot[b - first_bit_to_save] = sample[shot][b] ? '1' : '0';
counts[aShot]++;
sequentialData.push_back(std::move(aShot));
}
Expand Down
1 change: 1 addition & 0 deletions targettests/execution/qir_cond_for_break.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
******************************************************************************/

// RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: CUDAQ_DEFAULT_SIMULATOR=stim nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: nvq++ -std=c++17 --enable-mlir %s -o %t

#include <cudaq.h>
Expand Down
1 change: 1 addition & 0 deletions targettests/execution/qir_cond_for_loop-1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

// clang-format off
// RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: CUDAQ_DEFAULT_SIMULATOR=stim nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: nvq++ -std=c++17 --enable-mlir %s -o %t
// clang-format on

Expand Down
1 change: 1 addition & 0 deletions targettests/execution/qir_cond_for_loop-2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

// clang-format off
// RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: CUDAQ_DEFAULT_SIMULATOR=stim nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: nvq++ -std=c++17 --enable-mlir %s -o %t
// clang-format on

Expand Down
1 change: 1 addition & 0 deletions targettests/execution/qir_cond_for_loop-3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

// clang-format off
// RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: CUDAQ_DEFAULT_SIMULATOR=stim nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: nvq++ -std=c++17 --enable-mlir %s -o %t
// clang-format on

Expand Down
1 change: 1 addition & 0 deletions targettests/execution/qir_cond_for_loop-4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

// clang-format off
// RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: CUDAQ_DEFAULT_SIMULATOR=stim nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: nvq++ -std=c++17 --enable-mlir %s -o %t
// clang-format on

Expand Down
1 change: 1 addition & 0 deletions targettests/execution/qir_cond_for_loop-5.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

// clang-format off
// RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: CUDAQ_DEFAULT_SIMULATOR=stim nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: nvq++ -std=c++17 --enable-mlir %s -o %t
// clang-format on

Expand Down
1 change: 1 addition & 0 deletions targettests/execution/qir_cond_for_loop-6.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

// clang-format off
// RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: CUDAQ_DEFAULT_SIMULATOR=stim nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: nvq++ -std=c++17 --enable-mlir %s -o %t
// clang-format on

Expand Down
1 change: 1 addition & 0 deletions targettests/execution/qir_simple_cond-1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

// clang-format off
// RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: nvq++ %cpp_std --target stim --enable-mlir %s -o %t && %t | FileCheck %s
// RUN: nvq++ -std=c++17 --enable-mlir %s -o %t
// clang-format on

Expand Down
24 changes: 22 additions & 2 deletions unittests/integration/builder_tester.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ CUDAQ_TEST(BuilderTester, checkSwap) {

// Conditional execution on the tensornet backend is slow for a large number of
// shots.
#if !defined(CUDAQ_BACKEND_TENSORNET) && !defined(CUDAQ_BACKEND_STIM)
#if !defined(CUDAQ_BACKEND_TENSORNET)
CUDAQ_TEST(BuilderTester, checkConditional) {
{
cudaq::set_random_seed(13);
Expand Down Expand Up @@ -985,7 +985,6 @@ CUDAQ_TEST(BuilderTester, checkMidCircuitMeasure) {

EXPECT_EQ(counts.count("0", "c1"), 1000);
EXPECT_EQ(counts.count("1", "c0"), 1000);
return;
}

{
Expand All @@ -1005,6 +1004,27 @@ CUDAQ_TEST(BuilderTester, checkMidCircuitMeasure) {
EXPECT_EQ(counts.count("1", "hello2"), 0);
EXPECT_EQ(counts.count("0", "hello2"), 1000);
}

{
// Force conditional sample
auto entryPoint = cudaq::make_kernel();
auto q = entryPoint.qalloc(2);
entryPoint.h(q[0]);
auto mres = entryPoint.mz(q[0], "res0");
entryPoint.c_if(mres, [&]() { entryPoint.x(q[1]); });
entryPoint.mz(q, "final");

printf("%s\n", entryPoint.to_quake().c_str());
auto counts = cudaq::sample(entryPoint);
counts.dump();

EXPECT_GT(counts.count("0", "res0"), 0);
EXPECT_GT(counts.count("1", "res0"), 0);
EXPECT_GT(counts.count("00", "final"), 0);
EXPECT_EQ(counts.count("01", "final"), 0);
EXPECT_EQ(counts.count("10", "final"), 0);
EXPECT_GT(counts.count("11", "final"), 0);
}
}
#endif

Expand Down

0 comments on commit 71abfc7

Please sign in to comment.