Skip to content

Commit

Permalink
Update blockifier and add support for SkipFeeCharge flag (#1107)
Browse files Browse the repository at this point in the history
  • Loading branch information
omerfirmak authored Aug 17, 2023
1 parent bfe8e8b commit 2aae6b5
Show file tree
Hide file tree
Showing 9 changed files with 72 additions and 24 deletions.
8 changes: 4 additions & 4 deletions mocks/mock_vm.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 12 additions & 5 deletions rpc/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,7 @@ func (h *Handler) TransactionStatus(ctx context.Context, hash felt.Felt) (*Trans
}

func (h *Handler) EstimateFee(broadcastedTxns []BroadcastedTransaction, id BlockID) ([]FeeEstimate, *jsonrpc.Error) {
result, err := h.SimulateTransactions(id, broadcastedTxns, nil)
result, err := h.SimulateTransactions(id, broadcastedTxns, []SimulationFlag{SkipFeeChargeFlag})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1137,9 +1137,15 @@ func (h *Handler) TraceTransaction(hash felt.Felt) (json.RawMessage, *jsonrpc.Er
func (h *Handler) SimulateTransactions(id BlockID, transactions []BroadcastedTransaction,
simulationFlags []SimulationFlag,
) ([]SimulatedTransaction, *jsonrpc.Error) {
if len(simulationFlags) > 0 {
return nil, jsonrpc.Err(jsonrpc.InvalidParams, "Simulation flags are not supported")
skipValidate := utils.Any(simulationFlags, func(f SimulationFlag) bool {
return f == SkipValidateFlag
})
if skipValidate {
return nil, jsonrpc.Err(jsonrpc.InvalidParams, "Skip validate is not supported")
}
skipFeeCharge := utils.Any(simulationFlags, func(f SimulationFlag) bool {
return f == SkipFeeChargeFlag
})

state, closer, err := h.stateByBlockID(&id)
if err != nil {
Expand Down Expand Up @@ -1185,7 +1191,8 @@ func (h *Handler) SimulateTransactions(id BlockID, transactions []BroadcastedTra
if sequencerAddress == nil {
sequencerAddress = core.NetworkBlockHashMetaInfo(h.network).FallBackSequencerAddress
}
gasesConsumed, traces, err := h.vm.Execute(txns, classes, blockNumber, header.Timestamp, sequencerAddress, state, h.network, paidFeesOnL1)
gasesConsumed, traces, err := h.vm.Execute(txns, classes, blockNumber, header.Timestamp, sequencerAddress,
state, h.network, paidFeesOnL1, skipFeeCharge)
if err != nil {
rpcErr := *ErrContractError
rpcErr.Data = err.Error()
Expand Down Expand Up @@ -1275,7 +1282,7 @@ func (h *Handler) traceBlockTransactions(block *core.Block, numTxns int) ([]Trac
}

_, traces, err := h.vm.Execute(transactions, classes, blockNumber, header.Timestamp,
sequencerAddress, state, h.network, paidFeesOnL1)
sequencerAddress, state, h.network, paidFeesOnL1, false)
if err != nil {
rpcErr := *ErrContractError
rpcErr.Data = err.Error()
Expand Down
15 changes: 8 additions & 7 deletions rpc/handlers_test.go

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions utils/slices.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,8 @@ func IndexFunc[T comparable](slice []T, f func(T) bool) int {
func All[T comparable](slice []T, f func(T) bool) bool {
return IndexFunc(slice, func(e T) bool { return !f(e) }) == -1
}

// Any returns true if any of the elements match the given predicate
func Any[T comparable](slice []T, f func(T) bool) bool {
return IndexFunc(slice, f) != -1
}
24 changes: 24 additions & 0 deletions utils/slices_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,27 @@ func TestAll(t *testing.T) {
assert.True(t, allOdd)
})
}

func TestAny(t *testing.T) {
t.Run("nil slice", func(t *testing.T) {
var input []int
v := Any(input, func(int) bool {
return false
})
assert.False(t, v)
})
t.Run("not found", func(t *testing.T) {
input := []int{1, 2, 3, 4}
found := Any(input, func(v int) bool {
return v == 5
})
assert.False(t, found)
})
t.Run("found", func(t *testing.T) {
input := []int{1, 2, 3, 4, 5}
found := Any(input, func(v int) bool {
return v == 5
})
assert.True(t, found)
})
}
4 changes: 2 additions & 2 deletions vm/rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
serde = "1.0.160"
serde = "1.0.171"
serde_json = { version = "1.0.96", features = ["raw_value"] }
blockifier = {git = "https://github.com/starkware-libs/blockifier", rev = "4cd75c7e6d8e5534cbf9074d7a7d0192283266d7"}
blockifier = {git = "https://github.com/starkware-libs/blockifier", rev = "5ba0fb4"}
starknet_api = { git = "https://github.com/starkware-libs/starknet-api", rev = "8f620bc" }
cairo-vm = "0.8.2"
cairo-lang-casm = "2.1.0"
Expand Down
6 changes: 4 additions & 2 deletions vm/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ pub extern "C" fn cairoVMExecute(
chain_id: *const c_char,
sequencer_address: *const c_uchar,
paid_fees_on_l1_json: *const c_char,
skip_charge_fee: c_uchar,
) {
let reader = JunoStateReader::new(reader_handle);
let chain_id_str = unsafe { CStr::from_ptr(chain_id) }.to_str().unwrap();
Expand Down Expand Up @@ -210,10 +211,11 @@ pub extern "C" fn cairoVMExecute(
return;
}

let charge_fee = skip_charge_fee == 0;
let res = match txn.unwrap() {
Transaction::AccountTransaction(t) => t.execute(&mut state, &block_context),
Transaction::AccountTransaction(t) => t.execute(&mut state, &block_context, charge_fee),
Transaction::L1HandlerTransaction(t) => {
let maybe_execution_info = t.execute(&mut state, &block_context);
let maybe_execution_info = t.execute(&mut state, &block_context, charge_fee);
if maybe_execution_info.is_err() {
maybe_execution_info
} else {
Expand Down
13 changes: 11 additions & 2 deletions vm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ package vm
// char* chain_id);
//
// extern void cairoVMExecute(char* txns_json, char* classes_json, uintptr_t readerHandle, unsigned long long block_number,
// unsigned long long block_timestamp, char* chain_id, char* sequencer_address, char* paid_fees_on_l1_json);
// unsigned long long block_timestamp, char* chain_id, char* sequencer_address, char* paid_fees_on_l1_json,
// unsigned char skip_charge_fee);
//
// #cgo LDFLAGS: -L./rust/target/release -ljuno_starknet_rs -lm -ldl
import "C"
Expand All @@ -31,6 +32,7 @@ type VM interface {
) ([]*felt.Felt, error)
Execute(txns []core.Transaction, declaredClasses []core.Class, blockNumber, blockTimestamp uint64,
sequencerAddress *felt.Felt, state core.StateReader, network utils.Network, paidFeesOnL1 []*felt.Felt,
skipChargeFee bool,
) ([]*felt.Felt, []json.RawMessage, error)
}

Expand Down Expand Up @@ -142,6 +144,7 @@ func (*vm) Call(contractAddr, selector *felt.Felt, calldata []felt.Felt, blockNu
// Execute executes a given transaction set and returns the gas spent per transaction
func (*vm) Execute(txns []core.Transaction, declaredClasses []core.Class, blockNumber, blockTimestamp uint64,
sequencerAddress *felt.Felt, state core.StateReader, network utils.Network, paidFeesOnL1 []*felt.Felt,
skipChargeFee bool,
) ([]*felt.Felt, []json.RawMessage, error) {
context := &callContext{
state: state,
Expand All @@ -164,6 +167,10 @@ func (*vm) Execute(txns []core.Transaction, declaredClasses []core.Class, blockN
classesJSONCStr := C.CString(string(classesJSON))

sequencerAddressBytes := sequencerAddress.Bytes()
var skipChargeFeeByte byte
if skipChargeFee {
skipChargeFeeByte = 1
}

chainID := C.CString(network.ChainIDString())
C.cairoVMExecute(txnsJSONCstr,
Expand All @@ -173,7 +180,9 @@ func (*vm) Execute(txns []core.Transaction, declaredClasses []core.Class, blockN
C.ulonglong(blockTimestamp),
chainID,
(*C.char)(unsafe.Pointer(&sequencerAddressBytes[0])),
paidFeesOnL1CStr)
paidFeesOnL1CStr,
C.uchar(skipChargeFeeByte),
)

C.free(unsafe.Pointer(classesJSONCStr))
C.free(unsafe.Pointer(paidFeesOnL1CStr))
Expand Down
4 changes: 2 additions & 2 deletions vm/vm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,11 @@ func TestExecute(t *testing.T) {
address = utils.HexToFelt(t, "0x46a89ae102987331d369645031b49c27738ed096f2789c24449966da4c6de6b")
timestamp = uint64(1666877926)
)
_, _, err := New().Execute([]core.Transaction{}, []core.Class{}, 0, timestamp, address, state, network, []*felt.Felt{})
_, _, err := New().Execute([]core.Transaction{}, []core.Class{}, 0, timestamp, address, state, network, []*felt.Felt{}, false)
require.NoError(t, err)
})
t.Run("zero data", func(t *testing.T) {
_, _, err := New().Execute(nil, nil, 0, 0, &felt.Zero, state, network, []*felt.Felt{})
_, _, err := New().Execute(nil, nil, 0, 0, &felt.Zero, state, network, []*felt.Felt{}, false)
require.NoError(t, err)
})
}

0 comments on commit 2aae6b5

Please sign in to comment.