diff --git a/internal/eigenState/avsOperators/avsOperators.go b/internal/eigenState/avsOperators/avsOperators.go index 6f2fc40c..c3df7a3d 100644 --- a/internal/eigenState/avsOperators/avsOperators.go +++ b/internal/eigenState/avsOperators/avsOperators.go @@ -13,6 +13,7 @@ import ( "github.com/wealdtech/go-merkletree/v2/keccak256" orderedmap "github.com/wk8/go-ordered-map/v2" "go.uber.org/zap" + "golang.org/x/xerrors" "gorm.io/gorm" "gorm.io/gorm/clause" "slices" @@ -29,38 +30,44 @@ type RegisteredAvsOperators struct { CreatedAt time.Time } -// Schema for avs_operator_changes table -type AvsOperatorChange struct { - Id uint64 `gorm:"type:serial"` - Operator string - Avs string - Registered bool - TransactionHash string - TransactionIndex uint64 - LogIndex uint64 - BlockNumber uint64 - CreatedAt time.Time +// AccumulatedStateChange represents the accumulated state change for a given block +type AccumulatedStateChange struct { + Avs string + Operator string + Registered bool + BlockNumber uint64 +} + +// RegisteredAvsOperatorDiff represents the diff between the registered_avs_operators table and the accumulated state +type RegisteredAvsOperatorDiff struct { + Avs string + Operator string + BlockNumber uint64 + Registered bool +} + +// SlotId represents a unique identifier for a slot +type SlotId string + +func NewSlotId(avs string, operator string) SlotId { + return SlotId(fmt.Sprintf("%s_%s", avs, operator)) } // EigenState model for AVS operators that implements IEigenStateModel -type AvsOperators struct { +type AvsOperatorsModel struct { base.BaseEigenState - StateTransitions types.StateTransitions[AvsOperatorChange] + StateTransitions types.StateTransitions[AccumulatedStateChange] Db *gorm.DB Network config.Network Environment config.Environment logger *zap.Logger globalConfig *config.Config -} -type RegisteredAvsOperatorDiff struct { - Operator string - Avs string - BlockNumber uint64 - Registered bool + // Accumulates state changes for SlotIds, grouped by block number + stateAccumulator map[uint64]map[SlotId]*AccumulatedStateChange } -// Create new instance of AvsOperators state model +// Create new instance of AvsOperatorsModel state model func NewAvsOperators( esm *stateManager.EigenStateManager, grm *gorm.DB, @@ -68,8 +75,8 @@ func NewAvsOperators( Environment config.Environment, logger *zap.Logger, globalConfig *config.Config, -) (*AvsOperators, error) { - s := &AvsOperators{ +) (*AvsOperatorsModel, error) { + s := &AvsOperatorsModel{ BaseEigenState: base.BaseEigenState{ Logger: logger, }, @@ -78,27 +85,29 @@ func NewAvsOperators( Environment: Environment, logger: logger, globalConfig: globalConfig, + + stateAccumulator: make(map[uint64]map[SlotId]*AccumulatedStateChange), } esm.RegisterState(s, 0) return s, nil } -func (a *AvsOperators) GetModelName() string { - return "AvsOperators" +func (a *AvsOperatorsModel) GetModelName() string { + return "AvsOperatorsModel" } -// Get the state transitions for the AvsOperators state model +// Get the state transitions for the AvsOperatorsModel state model // // Each state transition is function indexed by a block number. // BlockNumber 0 is the catchall state // // Returns the map and a reverse sorted list of block numbers that can be traversed when // processing a log to determine which state change to apply. -func (a *AvsOperators) GetStateTransitions() (types.StateTransitions[AvsOperatorChange], []uint64) { - stateChanges := make(types.StateTransitions[AvsOperatorChange]) +func (a *AvsOperatorsModel) GetStateTransitions() (types.StateTransitions[AccumulatedStateChange], []uint64) { + stateChanges := make(types.StateTransitions[AccumulatedStateChange]) // TODO(seanmcgary): make this not a closure so this function doesnt get big an messy... - stateChanges[0] = func(log *storage.TransactionLog) (*AvsOperatorChange, error) { + stateChanges[0] = func(log *storage.TransactionLog) (*AccumulatedStateChange, error) { arguments, err := a.ParseLogArguments(log) if err != nil { return nil, err @@ -109,21 +118,39 @@ func (a *AvsOperators) GetStateTransitions() (types.StateTransitions[AvsOperator return nil, err } + // Sanity check to make sure we've got an initialized accumulator map for the block + if _, ok := a.stateAccumulator[log.BlockNumber]; !ok { + return nil, xerrors.Errorf("No state accumulator found for block %d", log.BlockNumber) + } + + avs := arguments[0].Value.(string) + operator := arguments[1].Value.(string) + registered := false if val, ok := outputData["status"]; ok { registered = uint64(val.(float64)) == 1 } - change := &AvsOperatorChange{ - Operator: arguments[0].Value.(string), - Avs: arguments[1].Value.(string), - Registered: registered, - TransactionHash: log.TransactionHash, - TransactionIndex: log.TransactionIndex, - LogIndex: log.LogIndex, - BlockNumber: log.BlockNumber, + slotId := NewSlotId(avs, operator) + record, ok := a.stateAccumulator[log.BlockNumber][slotId] + if !ok { + record = &AccumulatedStateChange{ + Avs: avs, + Operator: operator, + BlockNumber: log.BlockNumber, + } + a.stateAccumulator[log.BlockNumber][slotId] = record } - return change, nil + if registered == false && ok { + // In this situation, we've encountered a register and unregister in the same block + // which functionally results in no state change at all so we want to remove the record + // from the accumulated state. + delete(a.stateAccumulator[log.BlockNumber], slotId) + return nil, nil + } + record.Registered = registered + + return record, nil } // Create an ordered list of block numbers @@ -140,7 +167,7 @@ func (a *AvsOperators) GetStateTransitions() (types.StateTransitions[AvsOperator } // Returns a map of contract addresses to event names that are interesting to the state model -func (a *AvsOperators) getContractAddressesForEnvironment() map[string][]string { +func (a *AvsOperatorsModel) getContractAddressesForEnvironment() map[string][]string { contracts := a.globalConfig.GetContractsMapForEnvAndNetwork() return map[string][]string{ contracts.AvsDirectory: []string{ @@ -150,19 +177,20 @@ func (a *AvsOperators) getContractAddressesForEnvironment() map[string][]string } // Given a log, determine if it is interesting to the state model -func (a *AvsOperators) IsInterestingLog(log *storage.TransactionLog) bool { +func (a *AvsOperatorsModel) IsInterestingLog(log *storage.TransactionLog) bool { addresses := a.getContractAddressesForEnvironment() return a.BaseEigenState.IsInterestingLog(addresses, log) } -func (a *AvsOperators) InitBlockProcessing(blockNumber uint64) error { +func (a *AvsOperatorsModel) InitBlockProcessing(blockNumber uint64) error { + a.stateAccumulator[blockNumber] = make(map[SlotId]*AccumulatedStateChange) return nil } // Handle the state change for the given log // // Takes a log and iterates over the state transitions to determine which state change to apply based on block number. -func (a *AvsOperators) HandleStateChange(log *storage.TransactionLog) (interface{}, error) { +func (a *AvsOperatorsModel) HandleStateChange(log *storage.TransactionLog) (interface{}, error) { stateChanges, sortedBlockNumbers := a.GetStateTransitions() for _, blockNumber := range sortedBlockNumbers { @@ -174,176 +202,136 @@ func (a *AvsOperators) HandleStateChange(log *storage.TransactionLog) (interface return nil, err } - if change != nil { - wroteChange, err := a.writeStateChange(change) - if err != nil { - return wroteChange, err - } - return wroteChange, nil + if change == nil { + return nil, xerrors.Errorf("No state change found for block %d", blockNumber) } + return change, nil } } return nil, nil } -// Write the state change to the database -func (a *AvsOperators) writeStateChange(change *AvsOperatorChange) (*AvsOperatorChange, error) { - a.logger.Sugar().Debugw("Writing state change", zap.Any("change", change)) - res := a.Db.Model(&AvsOperatorChange{}).Clauses(clause.Returning{}).Create(change) - if res.Error != nil { - a.logger.Error("Failed to insert into avs_operator_changes", zap.Error(res.Error)) - return change, res.Error - } - return change, nil -} - -// Write the new final state to the database. -// -// 1. Get latest distinct change value for each avs/operator -// 2. Join the latest unique change value with the previous blocks state to overlay new changes -// 3. Filter joined set on registered = false to get unregistrations -// 4. Determine which rows from the previous block should be carried over and which shouldnt (i.e. deregistrations) -// 5. Geneate the final state by unioning the carryover and the new registrations -// 6. Insert the final state into the registered_avs_operators table -func (a *AvsOperators) CommitFinalState(blockNumber uint64) error { +func (a *AvsOperatorsModel) clonePreviousBlocksToNewBlock(blockNumber uint64) error { query := ` - with new_changes as ( + insert into registered_avs_operators (avs, operator, block_number) select avs, operator, - block_number, - max(transaction_index) as transaction_index, - max(log_index) as log_index - from avs_operator_changes - where block_number = @currentBlock - group by 1, 2, 3 - ), - unique_registrations as ( - select - nc.avs, - nc.operator, - aoc.log_index, - aoc.registered, - nc.block_number - from new_changes as nc - left join avs_operator_changes as aoc on ( - aoc.avs = nc.avs - and aoc.operator = nc.operator - and aoc.log_index = nc.log_index - and aoc.transaction_index = nc.transaction_index - and aoc.block_number = nc.block_number - ) - ), - unregistrations as ( - select - concat(avs, '_', operator) as operator_avs - from unique_registrations - where registered = false - ), - carryover as ( - select - rao.avs, - rao.operator, @currentBlock as block_number - from registered_avs_operators as rao - where - rao.block_number = @previousBlock - and concat(rao.avs, '_', rao.operator) not in (select operator_avs from unregistrations) - ), - final_state as ( - (select avs, operator, block_number::bigint from carryover) - union all - (select avs, operator, block_number::bigint from unique_registrations where registered = true) - ) - insert into registered_avs_operators (avs, operator, block_number) - select avs, operator, block_number from final_state + from registered_avs_operators + where block_number = @previousBlock ` - res := a.Db.Exec(query, sql.Named("currentBlock", blockNumber), sql.Named("previousBlock", blockNumber-1), ) + if res.Error != nil { - a.logger.Sugar().Errorw("Failed to insert into registered_avs_operators", zap.Error(res.Error)) + a.logger.Sugar().Errorw("Failed to clone previous block state to new block", zap.Error(res.Error)) return res.Error } return nil } -func (a *AvsOperators) getDifferenceInStates(blockNumber uint64) ([]RegisteredAvsOperatorDiff, error) { - query := ` - with new_states as ( - select - avs, - operator, - block_number, - true as registered - from registered_avs_operators - where block_number = @currentBlock - ), - previous_states as ( - select - avs, - operator, - block_number, - true as registered - from registered_avs_operators - where block_number = @previousBlock - ), - unregistered as ( - (select avs, operator, registered from previous_states) - except - (select avs, operator, registered from new_states) - ), - new_registered as ( - (select avs, operator, registered from new_states) - except - (select avs, operator, registered from previous_states) - ) - select avs, operator, false as registered from unregistered - union all - select avs, operator, true as registered from new_registered; - ` - results := make([]RegisteredAvsOperatorDiff, 0) - res := a.Db.Model(&RegisteredAvsOperatorDiff{}). - Raw(query, - sql.Named("currentBlock", blockNumber), - sql.Named("previousBlock", blockNumber-1), - ). - Scan(&results) +// prepareState prepares the state for the current block by comparing the accumulated state changes. +// It separates out the changes into inserts and deletes +func (a *AvsOperatorsModel) prepareState(blockNumber uint64) ([]RegisteredAvsOperators, []RegisteredAvsOperators, error) { + accumulatedState, ok := a.stateAccumulator[blockNumber] + if !ok { + err := xerrors.Errorf("No accumulated state found for block %d", blockNumber) + a.logger.Sugar().Errorw(err.Error(), zap.Error(err), zap.Uint64("blockNumber", blockNumber)) + return nil, nil, err + } - if res.Error != nil { - a.logger.Sugar().Errorw("Failed to fetch registered_avs_operators", zap.Error(res.Error)) - return nil, res.Error + inserts := make([]RegisteredAvsOperators, 0) + deletes := make([]RegisteredAvsOperators, 0) + for _, stateChange := range accumulatedState { + record := RegisteredAvsOperators{ + Avs: stateChange.Avs, + Operator: stateChange.Operator, + BlockNumber: blockNumber, + } + if stateChange.Registered { + inserts = append(inserts, record) + } else { + deletes = append(deletes, record) + } } - return results, nil + return inserts, deletes, nil } -func (a *AvsOperators) ClearAccumulatedState(blockNumber uint64) error { - panic("implement me") +// CommitFinalState commits the final state for the given block number +func (a *AvsOperatorsModel) CommitFinalState(blockNumber uint64) error { + err := a.clonePreviousBlocksToNewBlock(blockNumber) + if err != nil { + return err + } + + recordsToInsert, recordsToDelete, err := a.prepareState(blockNumber) + if err != nil { + return err + } + + for _, record := range recordsToDelete { + res := a.Db.Delete(&RegisteredAvsOperators{}, "avs = ? and operator = ? and block_number = ?", record.Avs, record.Operator, record.BlockNumber) + if res.Error != nil { + a.logger.Sugar().Errorw("Failed to delete record", + zap.Error(res.Error), + zap.String("avs", record.Avs), + zap.String("operator", record.Operator), + zap.Uint64("blockNumber", blockNumber), + ) + return res.Error + } + } + if len(recordsToInsert) > 0 { + res := a.Db.Model(&RegisteredAvsOperators{}).Clauses(clause.Returning{}).Create(&recordsToInsert) + if res.Error != nil { + a.logger.Sugar().Errorw("Failed to insert records", zap.Error(res.Error)) + return res.Error + } + } + return nil } -// Generates a state root for the given block number. -// -// 1. Select all registered_avs_operators for the given block number ordered by avs and operator asc -// 2. Create an ordered map, with AVSs at the top level that point to an ordered map of operators and block numbers -// 3. Create a merkle tree for each AVS, with the operator:block_number pairs as leaves -// 4. Create a merkle tree for all AVS trees -// 5. Return the root of the full tree -func (a *AvsOperators) GenerateStateRoot(blockNumber uint64) (types.StateRoot, error) { - results, err := a.getDifferenceInStates(blockNumber) +func (a *AvsOperatorsModel) ClearAccumulatedState(blockNumber uint64) error { + delete(a.stateAccumulator, blockNumber) + return nil +} + +// GenerateStateRoot generates the state root for the given block number using the results of the state changes +func (a *AvsOperatorsModel) GenerateStateRoot(blockNumber uint64) (types.StateRoot, error) { + inserts, deletes, err := a.prepareState(blockNumber) if err != nil { return "", err } - fullTree, err := a.merkelizeState(blockNumber, results) + combinedResults := make([]RegisteredAvsOperatorDiff, 0) + for _, record := range inserts { + combinedResults = append(combinedResults, RegisteredAvsOperatorDiff{ + Avs: record.Avs, + Operator: record.Operator, + BlockNumber: record.BlockNumber, + Registered: true, + }) + } + for _, record := range deletes { + combinedResults = append(combinedResults, RegisteredAvsOperatorDiff{ + Avs: record.Avs, + Operator: record.Operator, + BlockNumber: record.BlockNumber, + Registered: false, + }) + } + + fullTree, err := a.merkelizeState(blockNumber, combinedResults) if err != nil { return "", err } return types.StateRoot(utils.ConvertBytesToString(fullTree.Root())), nil } -func (a *AvsOperators) merkelizeState(blockNumber uint64, avsOperators []RegisteredAvsOperatorDiff) (*merkletree.MerkleTree, error) { +func (a *AvsOperatorsModel) merkelizeState(blockNumber uint64, avsOperators []RegisteredAvsOperatorDiff) (*merkletree.MerkleTree, error) { // Avs -> operator:registered om := orderedmap.New[string, *orderedmap.OrderedMap[string, bool]]() diff --git a/internal/eigenState/avsOperators/avsOperators_test.go b/internal/eigenState/avsOperators/avsOperators_test.go index beff1aaa..ea07eaa4 100644 --- a/internal/eigenState/avsOperators/avsOperators_test.go +++ b/internal/eigenState/avsOperators/avsOperators_test.go @@ -28,7 +28,7 @@ func setup() ( return cfg, grm, l, err } -func teardown(model *AvsOperators) { +func teardown(model *AvsOperatorsModel) { model.Db.Exec("truncate table avs_operator_changes cascade") model.Db.Exec("truncate table registered_avs_operators cascade") } @@ -68,6 +68,9 @@ func Test_AvsOperatorState(t *testing.T) { assert.Equal(t, true, avsOperatorState.IsInterestingLog(&log)) + err = avsOperatorState.InitBlockProcessing(blockNumber) + assert.Nil(t, err) + res, err := avsOperatorState.HandleStateChange(&log) assert.Nil(t, err) assert.NotNil(t, res) @@ -98,6 +101,9 @@ func Test_AvsOperatorState(t *testing.T) { assert.Equal(t, true, avsOperatorState.IsInterestingLog(&log)) + err = avsOperatorState.InitBlockProcessing(blockNumber) + assert.Nil(t, err) + stateChange, err := avsOperatorState.HandleStateChange(&log) assert.Nil(t, err) assert.NotNil(t, stateChange) @@ -166,6 +172,9 @@ func Test_AvsOperatorState(t *testing.T) { for _, log := range logs { assert.True(t, avsOperatorState.IsInterestingLog(log)) + err = avsOperatorState.InitBlockProcessing(log.BlockNumber) + assert.Nil(t, err) + stateChange, err := avsOperatorState.HandleStateChange(log) assert.Nil(t, err) assert.NotNil(t, stateChange) @@ -185,16 +194,16 @@ func Test_AvsOperatorState(t *testing.T) { if log.BlockNumber == blocks[0] { assert.Equal(t, 1, len(states)) - diffs, err := avsOperatorState.getDifferenceInStates(log.BlockNumber) + inserts, deletes, err := avsOperatorState.prepareState(log.BlockNumber) assert.Nil(t, err) - assert.Equal(t, 1, len(diffs)) - assert.Equal(t, true, diffs[0].Registered) + assert.Equal(t, 1, len(inserts)) + assert.Equal(t, 0, len(deletes)) } else if log.BlockNumber == blocks[1] { assert.Equal(t, 0, len(states)) - diffs, err := avsOperatorState.getDifferenceInStates(log.BlockNumber) + inserts, deletes, err := avsOperatorState.prepareState(log.BlockNumber) assert.Nil(t, err) - assert.Equal(t, 1, len(diffs)) - assert.Equal(t, false, diffs[0].Registered) + assert.Equal(t, 0, len(inserts)) + assert.Equal(t, 1, len(deletes)) } stateRoot, err := avsOperatorState.GenerateStateRoot(log.BlockNumber) diff --git a/internal/eigenState/stakerDelegations/stakerDelegations.go b/internal/eigenState/stakerDelegations/stakerDelegations.go index 1ff0c49d..d5cdfc47 100644 --- a/internal/eigenState/stakerDelegations/stakerDelegations.go +++ b/internal/eigenState/stakerDelegations/stakerDelegations.go @@ -251,6 +251,9 @@ func (s *StakerDelegationsModel) CommitFinalState(blockNumber uint64) error { } recordsToInsert, recordsToDelete, err := s.prepareState(blockNumber) + if err != nil { + return err + } // TODO(seanmcgary): should probably wrap the operations of this function in a db transaction for _, record := range recordsToDelete {