Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove chia-blockchain dependency for most of chia_rs #887

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 68 additions & 56 deletions crates/chia-protocol/src/block_record.rs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a really nice change. However, I think it would be good to check the correctness of this against the existing python implementation. One way to do that would be to break this off into a separate PR that still relies on the chia-blockchain dependency to compare the result between the rust and python implementation. Can you think of another way to be confident in the correctness of this port?

Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use chia_streamable_macro::streamable;

use crate::{Bytes32, ClassgroupElement, Coin, SubEpochSummary};
use chia_streamable_macro::streamable;

#[cfg(feature = "py-bindings")]
use pyo3::prelude::*;
Expand Down Expand Up @@ -67,10 +66,48 @@ impl BlockRecord {
pub fn is_challenge_block(&self, min_blocks_per_challenge_block: u8) -> bool {
self.deficit == min_blocks_per_challenge_block - 1
}
}

#[cfg(feature = "py-bindings")]
use pyo3::types::PyDict;
fn calculate_sp_interval_iters(&self, num_sps_sub_slot: u64) -> PyResult<u64> {
if self.sub_slot_iters % num_sps_sub_slot != 0 {
return Err(PyValueError::new_err(
"sub_slot_iters % constants.NUM_SPS_SUB_SLOT != 0",
));
}
Ok(self.sub_slot_iters / num_sps_sub_slot)
}

fn calculate_sp_iters(&self, num_sps_sub_slot: u32) -> PyResult<u64> {
if self.signage_point_index as u32 >= num_sps_sub_slot {
return Err(PyValueError::new_err("SP index too high"));
}
Ok(self.calculate_sp_interval_iters(num_sps_sub_slot as u64)?
* self.signage_point_index as u64)
}

fn calculate_ip_iters(
&self,
num_sps_sub_slot: u32,
num_sp_intervals_extra: u8,
) -> PyResult<u64> {
let sp_iters = self.calculate_sp_iters(num_sps_sub_slot)?;
let sp_interval_iters = self.calculate_sp_interval_iters(num_sps_sub_slot as u64)?;
if sp_iters % sp_interval_iters != 0 || sp_iters >= self.sub_slot_iters {
return Err(PyValueError::new_err(format!(
"Invalid sp iters {sp_iters} for this ssi {}",
self.sub_slot_iters
)));
} else if self.required_iters >= sp_interval_iters || self.required_iters == 0 {
return Err(PyValueError::new_err(format!(
"Required iters {} is not below the sp interval iters {} {} or not >=0",
self.required_iters, sp_interval_iters, self.sub_slot_iters
)));
}
Ok(
(sp_iters + num_sp_intervals_extra as u64 * sp_interval_iters + self.required_iters)
% self.sub_slot_iters,
)
}
}

#[cfg(feature = "py-bindings")]
use pyo3::exceptions::PyValueError;
Expand Down Expand Up @@ -102,16 +139,11 @@ impl BlockRecord {
))
}

// TODO: at some point it would be nice to port
// chia.consensus.pot_iterations to rust, and make this less hacky
fn sp_sub_slot_total_iters_impl(
&self,
py: Python<'_>,
constants: &Bound<'_, PyAny>,
) -> PyResult<u128> {
// TODO: these could be implemented as a total port of pot iterations
fn sp_sub_slot_total_iters_impl(&self, constants: &Bound<'_, PyAny>) -> PyResult<u128> {
let ret = self
.total_iters
.checked_sub(self.ip_iters_impl(py, constants)? as u128)
.checked_sub(self.ip_iters_impl(constants)? as u128)
.ok_or(PyValueError::new_err("uint128 overflow"))?;
if self.overflow {
ret.checked_sub(self.sub_slot_iters as u128)
Expand All @@ -121,48 +153,28 @@ impl BlockRecord {
}
}

fn ip_sub_slot_total_iters_impl(
&self,
py: Python<'_>,
constants: &Bound<'_, PyAny>,
) -> PyResult<u128> {
fn ip_sub_slot_total_iters_impl(&self, constants: &Bound<'_, PyAny>) -> PyResult<u128> {
self.total_iters
.checked_sub(self.ip_iters_impl(py, constants)? as u128)
.checked_sub(self.ip_iters_impl(constants)? as u128)
.ok_or(PyValueError::new_err("uint128 overflow"))
}

fn sp_iters_impl(&self, py: Python<'_>, constants: &Bound<'_, PyAny>) -> PyResult<u64> {
let ctx = PyDict::new(py);
ctx.set_item("sub_slot_iters", self.sub_slot_iters)?;
ctx.set_item("signage_point_index", self.signage_point_index)?;
ctx.set_item("constants", constants)?;
py.run(
c"from chia.consensus.pot_iterations import calculate_ip_iters, calculate_sp_iters\n\
ret = calculate_sp_iters(constants, sub_slot_iters, signage_point_index)\n",
None,
Some(&ctx),
)?;
ctx.get_item("ret").unwrap().unwrap().extract::<u64>()
}

fn ip_iters_impl(&self, py: Python<'_>, constants: &Bound<'_, PyAny>) -> PyResult<u64> {
let ctx = PyDict::new(py);
ctx.set_item("sub_slot_iters", self.sub_slot_iters)?;
ctx.set_item("signage_point_index", self.signage_point_index)?;
ctx.set_item("required_iters", self.required_iters)?;
ctx.set_item("constants", constants)?;
py.run(
c"from chia.consensus.pot_iterations import calculate_ip_iters, calculate_sp_iters\n\
ret = calculate_ip_iters(constants, sub_slot_iters, signage_point_index, required_iters)\n",
None,
Some(&ctx),
)?;
ctx.get_item("ret").unwrap().unwrap().extract::<u64>()
}

fn sp_total_iters_impl(&self, py: Python<'_>, constants: &Bound<'_, PyAny>) -> PyResult<u128> {
self.sp_sub_slot_total_iters_impl(py, constants)?
.checked_add(self.sp_iters_impl(py, constants)? as u128)
fn sp_iters_impl(&self, constants: &Bound<'_, PyAny>) -> PyResult<u64> {
let num_sps_sub_slot = constants.get_item("NUM_SPS_SUB_SLOT")?.extract::<u32>()?;
self.calculate_sp_iters(num_sps_sub_slot)
}

fn ip_iters_impl(&self, constants: &Bound<'_, PyAny>) -> PyResult<u64> {
let num_sps_sub_slot = constants.get_item("NUM_SPS_SUB_SLOT")?.extract::<u32>()?;
let num_sp_intervals_extra = constants
.get_item("NUM_SP_INTERVALS_EXTRA")?
.extract::<u8>()?;
self.calculate_ip_iters(num_sps_sub_slot, num_sp_intervals_extra)
}

fn sp_total_iters_impl(&self, constants: &Bound<'_, PyAny>) -> PyResult<u128> {
self.sp_sub_slot_total_iters_impl(constants)?
.checked_add(self.sp_iters_impl(constants)? as u128)
.ok_or(PyValueError::new_err("uint128 overflow"))
}

Expand All @@ -171,38 +183,38 @@ impl BlockRecord {
py: Python<'a>,
constants: &Bound<'_, PyAny>,
) -> PyResult<Bound<'a, PyAny>> {
ChiaToPython::to_python(&self.sp_sub_slot_total_iters_impl(py, constants)?, py)
ChiaToPython::to_python(&self.sp_sub_slot_total_iters_impl(constants)?, py)
}

fn ip_sub_slot_total_iters<'a>(
&self,
py: Python<'a>,
constants: &Bound<'_, PyAny>,
) -> PyResult<Bound<'a, PyAny>> {
ChiaToPython::to_python(&self.ip_sub_slot_total_iters_impl(py, constants)?, py)
ChiaToPython::to_python(&self.ip_sub_slot_total_iters_impl(constants)?, py)
}

fn sp_iters<'a>(
&self,
py: Python<'a>,
constants: &Bound<'_, PyAny>,
) -> PyResult<Bound<'a, PyAny>> {
ChiaToPython::to_python(&self.sp_iters_impl(py, constants)?, py)
ChiaToPython::to_python(&self.sp_iters_impl(constants)?, py)
}

fn ip_iters<'a>(
&self,
py: Python<'a>,
constants: &Bound<'_, PyAny>,
) -> PyResult<Bound<'a, PyAny>> {
ChiaToPython::to_python(&self.ip_iters_impl(py, constants)?, py)
ChiaToPython::to_python(&self.ip_iters_impl(constants)?, py)
}

fn sp_total_iters<'a>(
&self,
py: Python<'a>,
constants: &Bound<'_, PyAny>,
) -> PyResult<Bound<'a, PyAny>> {
ChiaToPython::to_python(&self.sp_total_iters_impl(py, constants)?, py)
ChiaToPython::to_python(&self.sp_total_iters_impl(constants)?, py)
}
}
17 changes: 3 additions & 14 deletions tests/test_blscache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
)
from chia_rs.sized_bytes import bytes32
from chia_rs.sized_ints import uint8, uint16, uint32, uint64, uint128
from chia.util.hash import std_hash
from chia.util.lru_cache import LRUCache
from chia.types.blockchain_format.program import Program as ChiaProgram
import pytest


Expand Down Expand Up @@ -174,7 +171,6 @@ def test_cached_bls():

# Use a small cache which can not accommodate all pairings
bls_cache = BLSCache(n_keys // 2)
local_cache = LRUCache(n_keys // 2)
# Verify signatures and cache pairings one at a time
for pk, msg, sig in zip(pks_half, msgs_half, sigs_half):
assert bls_cache.aggregate_verify([pk], [msg], sig)
Expand Down Expand Up @@ -221,13 +217,10 @@ def test_cached_bls_repeat_pk():
cached_bls = BLSCache()
n_keys = 400
seed = b"a" * 32
sks = [AugSchemeMPL.key_gen(seed) for i in range(n_keys)] + [
AugSchemeMPL.key_gen(std_hash(seed))
]
sks = [AugSchemeMPL.key_gen(seed) for _ in range(n_keys)]
pks = [sk.get_g1() for sk in sks]
pks_bytes = [bytes(sk.get_g1()) for sk in sks]

msgs = [("msg-%d" % (i,)).encode() for i in range(n_keys + 1)]
msgs = [("msg-%d" % (i,)).encode() for i in range(n_keys)]
sigs = [AugSchemeMPL.sign(sk, msg) for sk, msg in zip(sks, msgs)]
agg_sig = AugSchemeMPL.aggregate(sigs)

Expand Down Expand Up @@ -304,11 +297,7 @@ def test_validate_clvm_and_sig():
)
sig = AugSchemeMPL.sign(
sk,
(
ChiaProgram.to("hello").as_atom()
+ coin.name()
+ DEFAULT_CONSTANTS.AGG_SIG_ME_ADDITIONAL_DATA
), # noqa
(b"hello" + coin.name() + DEFAULT_CONSTANTS.AGG_SIG_ME_ADDITIONAL_DATA), # noqa
)

new_spend = SpendBundle(coin_spends, sig)
Expand Down
133 changes: 0 additions & 133 deletions tests/test_program_fidelity.py

This file was deleted.

Loading