Skip to content

Commit

Permalink
add downcasting step to streamable classmethods that support it
Browse files Browse the repository at this point in the history
  • Loading branch information
matt-o-how committed Sep 5, 2024
1 parent 91205ae commit 5f0bab7
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 7 deletions.
14 changes: 13 additions & 1 deletion crates/chia-bls/src/gtelement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ use blst::*;
use chia_traits::chia_error::Result;
use chia_traits::{read_bytes, Streamable};
use clvmr::sha2::Sha256;
#[cfg(feature = "py-bindings")]
use pyo3::prelude::*;
#[cfg(feature = "py-bindings")]
use pyo3::types::PyType;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::io::Cursor;
Expand Down Expand Up @@ -112,6 +116,15 @@ impl GTElement {
hex::encode(self.to_bytes())
}

#[classmethod]
#[pyo3(name = "from_parent")]
pub fn from_parent(_cls: &Bound<'_, PyType>, instance: Self) -> PyResult<PyObject> {
Python::with_gil(|py| {
// ignore child case
Ok(instance.into_py(py))
})
}

#[must_use]
pub fn __mul__(&self, rhs: &Self) -> Self {
let mut ret = self.clone();
Expand All @@ -130,7 +143,6 @@ mod pybindings {

use crate::parse_hex::parse_hex_string;
use chia_traits::{FromJsonDict, ToJsonDict};
use pyo3::prelude::*;

impl ToJsonDict for GTElement {
fn to_json_dict(&self, py: Python<'_>) -> PyResult<PyObject> {
Expand Down
13 changes: 13 additions & 0 deletions crates/chia-bls/src/public_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ use crate::{DerivableKey, Error, Result};
use blst::*;
use chia_traits::{read_bytes, Streamable};
use clvmr::sha2::Sha256;
#[cfg(feature = "py-bindings")]
use pyo3::prelude::*;
#[cfg(feature = "py-bindings")]
use pyo3::types::PyType;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::io::Cursor;
Expand Down Expand Up @@ -321,6 +325,15 @@ impl PublicKey {
other.pair(self)
}

#[classmethod]
#[pyo3(name = "from_parent")]
pub fn from_parent(_cls: &Bound<'_, PyType>, instance: Self) -> PyResult<PyObject> {
Python::with_gil(|py| {
// ignore child case
Ok(instance.into_py(py))
})
}

#[pyo3(name = "get_fingerprint")]
pub fn py_get_fingerprint(&self) -> u32 {
self.get_fingerprint()
Expand Down
13 changes: 13 additions & 0 deletions crates/chia-bls/src/secret_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ use blst::*;
use chia_traits::{read_bytes, Streamable};
use clvmr::sha2::Sha256;
use hkdf::HkdfExtract;
#[cfg(feature = "py-bindings")]
use pyo3::prelude::*;
#[cfg(feature = "py-bindings")]
use pyo3::types::PyType;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::io::Cursor;
Expand Down Expand Up @@ -266,6 +270,15 @@ impl SecretKey {
hex::encode(self.to_bytes())
}

#[classmethod]
#[pyo3(name = "from_parent")]
pub fn from_parent(_cls: &Bound<'_, PyType>, instance: Self) -> PyResult<PyObject> {
Python::with_gil(|py| {
// ignore child case
Ok(instance.into_py(py))
})
}

#[pyo3(name = "derive_hardened")]
#[must_use]
pub fn py_derive_hardened(&self, idx: u32) -> Self {
Expand Down
14 changes: 13 additions & 1 deletion crates/chia-bls/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ use crate::{Error, GTElement, PublicKey, Result, SecretKey};
use blst::*;
use chia_traits::{read_bytes, Streamable};
use clvmr::sha2::Sha256;
#[cfg(feature = "py-bindings")]
use pyo3::prelude::*;
#[cfg(feature = "py-bindings")]
use pyo3::types::PyType;
use std::borrow::Borrow;
use std::fmt;
use std::hash::{Hash, Hasher};
Expand Down Expand Up @@ -486,6 +490,15 @@ impl Signature {
Self::default()
}

#[classmethod]
#[pyo3(name = "from_parent")]
pub fn from_parent(_cls: &Bound<'_, PyType>, instance: Self) -> PyResult<PyObject> {
Python::with_gil(|py| {
// ignore child case
Ok(instance.into_py(py))
})
}

#[pyo3(name = "pair")]
pub fn py_pair(&self, other: &PublicKey) -> GTElement {
self.pair(other)
Expand Down Expand Up @@ -518,7 +531,6 @@ mod pybindings {
use crate::parse_hex::parse_hex_string;

use chia_traits::{FromJsonDict, ToJsonDict};
use pyo3::prelude::*;

impl ToJsonDict for Signature {
fn to_json_dict(&self, py: Python<'_>) -> PyResult<PyObject> {
Expand Down
17 changes: 17 additions & 0 deletions crates/chia-protocol/src/coin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use clvmr::sha2::Sha256;

#[cfg(feature = "py-bindings")]
use pyo3::prelude::*;
#[cfg(feature = "py-bindings")]
use pyo3::types::PyType;

#[streamable]
#[derive(Copy)]
Expand Down Expand Up @@ -55,6 +57,21 @@ impl Coin {
}
}

#[cfg(feature = "py-bindings")]
#[pymethods]
impl Coin {
#[classmethod]
#[pyo3(name = "from_parent")]
pub fn from_parent(cls: &Bound<'_, PyType>, coin: Self) -> PyResult<PyObject> {
Python::with_gil(|py| {
// Convert result into potential child class
let instance = cls.call1((coin.parent_coin_info, coin.puzzle_hash, coin.amount))?;

Ok(instance.into_py(py))
})
}
}

impl<N, E: ClvmEncoder<Node = N>> ToClvm<E> for Coin {
fn to_clvm(&self, encoder: &mut E) -> Result<N, ToClvmError> {
clvm_list!(self.parent_coin_info, self.puzzle_hash, self.amount).to_clvm(encoder)
Expand Down
19 changes: 19 additions & 0 deletions crates/chia-protocol/src/coin_spend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,29 @@ use chia_streamable_macro::streamable;

use crate::coin::Coin;
use crate::program::Program;
#[cfg(feature = "py-bindings")]
use pyo3::prelude::*;
#[cfg(feature = "py-bindings")]
use pyo3::types::PyType;

#[streamable]
pub struct CoinSpend {
coin: Coin,
puzzle_reveal: Program,
solution: Program,
}

#[cfg(feature = "py-bindings")]
#[pymethods]
impl CoinSpend {
#[classmethod]
#[pyo3(name = "from_parent")]
pub fn from_parent(cls: &Bound<'_, PyType>, cs: Self) -> PyResult<PyObject> {
Python::with_gil(|py| {
// Convert result into potential child class
let instance = cls.call1((cs.coin, cs.puzzle_reveal, cs.solution))?;

Ok(instance.into_py(py))
})
}
}
17 changes: 17 additions & 0 deletions crates/chia-protocol/src/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ use clvmr::serde::{
};
use clvmr::sha2::Sha256;
use clvmr::{Allocator, ChiaDialect};
#[cfg(feature = "py-bindings")]
use pyo3::prelude::*;
#[cfg(feature = "py-bindings")]
use pyo3::types::PyType;
use std::io::Cursor;
use std::ops::Deref;

Expand Down Expand Up @@ -486,6 +490,19 @@ impl ToJsonDict for Program {
}
}

#[cfg(feature = "py-bindings")]
#[pymethods]
impl Program {
#[classmethod]
#[pyo3(name = "from_parent")]
pub fn from_parent(_cls: &Bound<'_, PyType>, instance: Self) -> PyResult<PyObject> {
Python::with_gil(|py| {
// ignore child case
Ok(instance.into_py(py))
})
}
}

#[cfg(feature = "py-bindings")]
impl FromJsonDict for Program {
fn from_json_dict(o: &Bound<'_, PyAny>) -> PyResult<Self> {
Expand Down
14 changes: 14 additions & 0 deletions crates/chia-protocol/src/spend_bundle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,20 @@ impl SpendBundle {
})
}

#[classmethod]
#[pyo3(name = "from_parent")]
pub fn from_parent(cls: &Bound<'_, PyType>, spend_bundle: Self) -> PyResult<PyObject> {
Python::with_gil(|py| {
// Convert result into potential child class
let instance = cls.call(
(spend_bundle.coin_spends, spend_bundle.aggregated_signature),
None,
)?;

Ok(instance.into_py(py))
})
}

#[pyo3(name = "name")]
fn py_name(&self) -> Bytes32 {
self.name()
Expand Down
36 changes: 31 additions & 5 deletions crates/chia_py_streamable_macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,21 @@ pub fn py_streamable_macro(input: proc_macro::TokenStream) -> proc_macro::TokenS
impl #ident {
#[classmethod]
#[pyo3(signature=(json_dict))]
pub fn from_json_dict(_cls: &pyo3::Bound<'_, pyo3::types::PyType>, json_dict: &pyo3::Bound<pyo3::PyAny>) -> pyo3::PyResult<Self> {
<Self as #crate_name::from_json_dict::FromJsonDict>::from_json_dict(json_dict)
pub fn from_json_dict(cls: &pyo3::Bound<'_, pyo3::types::PyType>, json_dict: &pyo3::Bound<pyo3::PyAny>) -> pyo3::PyResult<pyo3::PyObject> {
use pyo3::prelude::PyAnyMethods;
use pyo3::IntoPy;
let rust_obj = <Self as #crate_name::from_json_dict::FromJsonDict>::from_json_dict(json_dict);
match rust_obj {
Ok(obk) => {
pyo3::Python::with_gil(|py| {
// Convert result into potential child class
// let instance = cls.call(py, (rust_obj,))?;
let instance = cls.call_method1("from_parent", (obk.into_py(py),))?;
Ok(instance.into_py(py))
})
},
Err(e) => Err(e)
}
}

pub fn to_json_dict(&self, py: pyo3::Python) -> pyo3::PyResult<pyo3::PyObject> {
Expand Down Expand Up @@ -191,7 +204,7 @@ pub fn py_streamable_macro(input: proc_macro::TokenStream) -> proc_macro::TokenS
pyo3::Python::with_gil(|py| {
// Convert result into potential child class
// let instance = cls.call(py, (rust_obj,))?;
let instance = cls.call1((obk.into_py(py),))?;
let instance = cls.call_method1("from_parent", (obk.into_py(py),))?;
Ok(instance.into_py(py))
})
},
Expand All @@ -201,14 +214,27 @@ pub fn py_streamable_macro(input: proc_macro::TokenStream) -> proc_macro::TokenS

#[classmethod]
#[pyo3(name = "from_bytes_unchecked")]
pub fn py_from_bytes_unchecked(_cls: &pyo3::Bound<'_, pyo3::types::PyType>, blob: pyo3::buffer::PyBuffer<u8>) -> pyo3::PyResult<Self> {
pub fn py_from_bytes_unchecked(cls: &pyo3::Bound<'_, pyo3::types::PyType>, blob: pyo3::buffer::PyBuffer<u8>) -> pyo3::PyResult<pyo3::PyObject> {
use pyo3::prelude::PyAnyMethods;
use pyo3::IntoPy;
if !blob.is_c_contiguous() {
panic!("from_bytes_unchecked() must be called with a contiguous buffer");
}
let slice = unsafe {
std::slice::from_raw_parts(blob.buf_ptr() as *const u8, blob.len_bytes())
};
<Self as #crate_name::Streamable>::from_bytes_unchecked(slice).map_err(|e| <#crate_name::chia_error::Error as Into<pyo3::PyErr>>::into(e))
let rust_obj = <Self as #crate_name::Streamable>::from_bytes_unchecked(slice).map_err(|e| <#crate_name::chia_error::Error as Into<pyo3::PyErr>>::into(e));
match rust_obj {
Ok(obk) => {
pyo3::Python::with_gil(|py| {
// Convert result into potential child class
// let instance = cls.call(py, (rust_obj,))?;
let instance = cls.call_method1("from_parent", (obk.into_py(py),))?;
Ok(instance.into_py(py))
})
},
Err(e) => Err(e)
}
}

// returns the type as well as the number of bytes read from the buffer
Expand Down

0 comments on commit 5f0bab7

Please sign in to comment.