Skip to content

Commit

Permalink
feat: abstract eigenStateModel
Browse files Browse the repository at this point in the history
  • Loading branch information
gpsanant committed Oct 20, 2024
1 parent 8f1e322 commit 9ac4f86
Show file tree
Hide file tree
Showing 19 changed files with 597 additions and 572 deletions.
119 changes: 55 additions & 64 deletions internal/eigenState/avsOperators/avsOperators.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ import (
"time"

"github.com/Layr-Labs/go-sidecar/internal/config"
"github.com/Layr-Labs/go-sidecar/internal/eigenState/base"
"github.com/Layr-Labs/go-sidecar/internal/eigenState/eigenStateModel"
"github.com/Layr-Labs/go-sidecar/internal/eigenState/stateManager"
"github.com/Layr-Labs/go-sidecar/internal/eigenState/types"
"github.com/Layr-Labs/go-sidecar/internal/eigenState/utils"
"github.com/Layr-Labs/go-sidecar/internal/storage"
"github.com/Layr-Labs/go-sidecar/internal/utils"
"go.uber.org/zap"
"golang.org/x/xerrors"
"gorm.io/gorm"
Expand Down Expand Up @@ -58,10 +58,9 @@ func NewSlotID(avs string, operator string) types.SlotID {
}

// EigenState model for AVS operators that implements IEigenStateModel.
type AvsOperatorsModel struct {
base.BaseEigenState
type AvsOperatorsBaseModel struct {
StateTransitions types.StateTransitions[AccumulatedStateChange]
DB *gorm.DB
db *gorm.DB
logger *zap.Logger
globalConfig *config.Config

Expand All @@ -72,51 +71,66 @@ type AvsOperatorsModel struct {
deltaAccumulator map[uint64][]*AvsOperatorStateChange
}

// NewAvsOperators creates a new AvsOperatorsModel.
func NewAvsOperators(
// NewAvsOperators creates a new AvsOperatorsBaseModel.
func NewAvsOperatorsModel(
esm *stateManager.EigenStateManager,
grm *gorm.DB,
logger *zap.Logger,
globalConfig *config.Config,
) (*AvsOperatorsModel, error) {
s := &AvsOperatorsModel{
BaseEigenState: base.BaseEigenState{
Logger: logger,
},
DB: grm,
) (*eigenStateModel.EigenStateModel, error) {
base := &AvsOperatorsBaseModel{
db: grm,
logger: logger,
globalConfig: globalConfig,

stateAccumulator: make(map[uint64]map[types.SlotID]*AccumulatedStateChange),

deltaAccumulator: make(map[uint64][]*AvsOperatorStateChange),
}
esm.RegisterState(s, 0)
return s, nil
m := eigenStateModel.NewEigenStateModel(base)

esm.RegisterState(m, 0)
return m, nil
}

func (a *AvsOperatorsBaseModel) Logger() *zap.Logger {
return a.logger
}

func (a *AvsOperatorsBaseModel) ModelName() string {
return "AvsOperatorsBaseModel"
}

func (a *AvsOperatorsBaseModel) TableName() string {
return "registered_avs_operators"
}

func (a *AvsOperatorsBaseModel) DB() *gorm.DB {
return a.db
}

func (a *AvsOperatorsModel) GetModelName() string {
return "AvsOperatorsModel"
func (a *AvsOperatorsBaseModel) Base() interface{} {
return a
}

// Get the state transitions for the AvsOperatorsModel state model
// Get the state transitions for the AvsOperatorsBaseModel state model
//
// Each state transition is function indexed by a block number.
// BlockNumber 0 is the catchall state
//
// Returns the map and a reverse sorted list of block numbers that can be traversed when
// processing a log to determine which state change to apply.
func (a *AvsOperatorsModel) GetStateTransitions() (types.StateTransitions[AccumulatedStateChange], []uint64) {
func (a *AvsOperatorsBaseModel) GetStateTransitions() (types.StateTransitions[AccumulatedStateChange], []uint64) {
stateChanges := make(types.StateTransitions[AccumulatedStateChange])

// TODO(seanmcgary): make this not a closure so this function doesnt get big an messy...
stateChanges[0] = func(log *storage.TransactionLog) (*AccumulatedStateChange, error) {
arguments, err := a.ParseLogArguments(log)
arguments, err := utils.ParseLogArguments(a.logger, log)
if err != nil {
return nil, err
}

outputData, err := a.ParseLogOutput(log)
outputData, err := utils.ParseLogOutput(a.logger, log)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -179,7 +193,7 @@ func (a *AvsOperatorsModel) GetStateTransitions() (types.StateTransitions[Accumu
}

// Returns a map of contract addresses to event names that are interesting to the state model.
func (a *AvsOperatorsModel) getContractAddressesForEnvironment() map[string][]string {
func (a *AvsOperatorsBaseModel) GetInterestingLogMap() map[string][]string {
contracts := a.globalConfig.GetContractsMapForChain()
return map[string][]string{
contracts.AvsDirectory: {
Expand All @@ -188,13 +202,7 @@ func (a *AvsOperatorsModel) getContractAddressesForEnvironment() map[string][]st
}
}

// Given a log, determine if it is interesting to the state model.
func (a *AvsOperatorsModel) IsInterestingLog(log *storage.TransactionLog) bool {
addresses := a.getContractAddressesForEnvironment()
return a.BaseEigenState.IsInterestingLog(addresses, log)
}

func (a *AvsOperatorsModel) InitBlockProcessing(blockNumber uint64) error {
func (a *AvsOperatorsBaseModel) InitBlockProcessing(blockNumber uint64) error {
a.stateAccumulator[blockNumber] = make(map[types.SlotID]*AccumulatedStateChange)
a.deltaAccumulator[blockNumber] = make([]*AvsOperatorStateChange, 0)
return nil
Expand All @@ -203,7 +211,7 @@ func (a *AvsOperatorsModel) InitBlockProcessing(blockNumber uint64) error {
// Handle the state change for the given log
//
// Takes a log and iterates over the state transitions to determine which state change to apply based on block number.
func (a *AvsOperatorsModel) HandleStateChange(log *storage.TransactionLog) (interface{}, error) {
func (a *AvsOperatorsBaseModel) HandleStateChange(log *storage.TransactionLog) (interface{}, error) {
stateChanges, sortedBlockNumbers := a.GetStateTransitions()

for _, blockNumber := range sortedBlockNumbers {
Expand All @@ -224,7 +232,7 @@ func (a *AvsOperatorsModel) HandleStateChange(log *storage.TransactionLog) (inte
return nil, nil
}

func (a *AvsOperatorsModel) clonePreviousBlocksToNewBlock(blockNumber uint64) error {
func (a *AvsOperatorsBaseModel) clonePreviousBlocksToNewBlock(blockNumber uint64) error {
query := `
insert into registered_avs_operators (avs, operator, block_number)
select
Expand All @@ -234,7 +242,7 @@ func (a *AvsOperatorsModel) clonePreviousBlocksToNewBlock(blockNumber uint64) er
from registered_avs_operators
where block_number = @previousBlock
`
res := a.DB.Exec(query,
res := a.db.Exec(query,
sql.Named("currentBlock", blockNumber),
sql.Named("previousBlock", blockNumber-1),
)
Expand All @@ -248,7 +256,7 @@ func (a *AvsOperatorsModel) clonePreviousBlocksToNewBlock(blockNumber uint64) er

// prepareState prepares the state for the current block by comparing the accumulated state changes.
// It separates out the changes into inserts and deletes.
func (a *AvsOperatorsModel) prepareState(blockNumber uint64) ([]RegisteredAvsOperators, []RegisteredAvsOperators, error) {
func (a *AvsOperatorsBaseModel) prepareState(blockNumber uint64) ([]RegisteredAvsOperators, []RegisteredAvsOperators, error) {
accumulatedState, ok := a.stateAccumulator[blockNumber]
if !ok {
err := xerrors.Errorf("No accumulated state found for block %d", blockNumber)
Expand All @@ -273,7 +281,7 @@ func (a *AvsOperatorsModel) prepareState(blockNumber uint64) ([]RegisteredAvsOpe
return inserts, deletes, nil
}

func (a *AvsOperatorsModel) writeDeltaRecordsToDeltaTable(blockNumber uint64) error {
func (a *AvsOperatorsBaseModel) writeDeltaRecordsToDeltaTable(blockNumber uint64) error {
records, ok := a.deltaAccumulator[blockNumber]
if !ok {
msg := "Delta accumulator was not initialized"
Expand All @@ -282,7 +290,7 @@ func (a *AvsOperatorsModel) writeDeltaRecordsToDeltaTable(blockNumber uint64) er
}

if len(records) > 0 {
res := a.DB.Model(&AvsOperatorStateChange{}).Clauses(clause.Returning{}).Create(&records)
res := a.db.Model(&AvsOperatorStateChange{}).Clauses(clause.Returning{}).Create(&records)
if res.Error != nil {
a.logger.Sugar().Errorw("Failed to insert delta records", zap.Error(res.Error))
return res.Error
Expand All @@ -292,7 +300,7 @@ func (a *AvsOperatorsModel) writeDeltaRecordsToDeltaTable(blockNumber uint64) er
}

// CommitFinalState commits the final state for the given block number.
func (a *AvsOperatorsModel) CommitFinalState(blockNumber uint64) error {
func (a *AvsOperatorsBaseModel) CommitFinalState(blockNumber uint64) error {
err := a.clonePreviousBlocksToNewBlock(blockNumber)
if err != nil {
return err
Expand All @@ -304,7 +312,7 @@ func (a *AvsOperatorsModel) CommitFinalState(blockNumber uint64) error {
}

for _, record := range recordsToDelete {
res := a.DB.Delete(&RegisteredAvsOperators{}, "avs = ? and operator = ? and block_number = ?", record.Avs, record.Operator, record.BlockNumber)
res := a.db.Delete(&RegisteredAvsOperators{}, "avs = ? and operator = ? and block_number = ?", record.Avs, record.Operator, record.BlockNumber)
if res.Error != nil {
a.logger.Sugar().Errorw("Failed to delete record",
zap.Error(res.Error),
Expand All @@ -316,7 +324,7 @@ func (a *AvsOperatorsModel) CommitFinalState(blockNumber uint64) error {
}
}
if len(recordsToInsert) > 0 {
res := a.DB.Model(&RegisteredAvsOperators{}).Clauses(clause.Returning{}).Create(&recordsToInsert)
res := a.db.Model(&RegisteredAvsOperators{}).Clauses(clause.Returning{}).Create(&recordsToInsert)
if res.Error != nil {
a.logger.Sugar().Errorw("Failed to insert records", zap.Error(res.Error))
return res.Error
Expand All @@ -330,60 +338,43 @@ func (a *AvsOperatorsModel) CommitFinalState(blockNumber uint64) error {
return nil
}

func (a *AvsOperatorsModel) ClearAccumulatedState(blockNumber uint64) error {
func (a *AvsOperatorsBaseModel) ClearAccumulatedState(blockNumber uint64) error {
delete(a.stateAccumulator, blockNumber)
delete(a.deltaAccumulator, blockNumber)
return nil
}

// GenerateStateRoot generates the state root for the given block number using the results of the state changes.
func (a *AvsOperatorsModel) GenerateStateRoot(blockNumber uint64) (types.StateRoot, error) {
func (a *AvsOperatorsBaseModel) GetStateDiffs(blockNumber uint64) ([]types.StateDiff, error) {
inserts, deletes, err := a.prepareState(blockNumber)
if err != nil {
return "", err
return nil, err
}

combinedResults := make([]*RegisteredAvsOperatorDiff, 0)
diffs := make([]*RegisteredAvsOperatorDiff, 0)
for _, record := range inserts {
combinedResults = append(combinedResults, &RegisteredAvsOperatorDiff{
diffs = append(diffs, &RegisteredAvsOperatorDiff{
Avs: record.Avs,
Operator: record.Operator,
BlockNumber: record.BlockNumber,
Registered: true,
})
}
for _, record := range deletes {
combinedResults = append(combinedResults, &RegisteredAvsOperatorDiff{
diffs = append(diffs, &RegisteredAvsOperatorDiff{
Avs: record.Avs,
Operator: record.Operator,
BlockNumber: record.BlockNumber,
Registered: false,
})
}

inputs := a.sortValuesForMerkleTree(combinedResults)

fullTree, err := a.MerkleizeState(blockNumber, inputs)
if err != nil {
return "", err
}
return types.StateRoot(utils.ConvertBytesToString(fullTree.Root())), nil
}

func (a *AvsOperatorsModel) sortValuesForMerkleTree(diffs []*RegisteredAvsOperatorDiff) []*base.MerkleTreeInput {
inputs := make([]*base.MerkleTreeInput, 0)
stateDiffs := make([]types.StateDiff, 0)
for _, diff := range diffs {
inputs = append(inputs, &base.MerkleTreeInput{
stateDiffs = append(stateDiffs, types.StateDiff{
SlotID: NewSlotID(diff.Avs, diff.Operator),
Value: []byte(fmt.Sprintf("%t", diff.Registered)),
})
}
slices.SortFunc(inputs, func(i, j *base.MerkleTreeInput) int {
return strings.Compare(string(i.SlotID), string(j.SlotID))
})
return inputs
}

func (a *AvsOperatorsModel) DeleteState(startBlockNumber uint64, endBlockNumber uint64) error {
return a.BaseEigenState.DeleteState("registered_avs_operators", startBlockNumber, endBlockNumber, a.DB)
return stateDiffs, nil
}
33 changes: 17 additions & 16 deletions internal/eigenState/avsOperators/avsOperators_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/Layr-Labs/go-sidecar/internal/config"
"github.com/Layr-Labs/go-sidecar/internal/eigenState/eigenStateModel"
"github.com/Layr-Labs/go-sidecar/internal/eigenState/stateManager"
"github.com/Layr-Labs/go-sidecar/internal/logger"
"github.com/Layr-Labs/go-sidecar/internal/sqlite/migrations"
Expand Down Expand Up @@ -38,23 +39,23 @@ func setup() (
return cfg, db, l, err
}

func teardown(model *AvsOperatorsModel) {
model.DB.Exec("delete from avs_operator_changes")
model.DB.Exec("delete from registered_avs_operators")
model.DB.Exec("delete from avs_operator_state_changes")
func teardown(model *eigenStateModel.EigenStateModel) {
model.DB().Exec("delete from avs_operator_changes")
model.DB().Exec("delete from registered_avs_operators")
model.DB().Exec("delete from avs_operator_state_changes")
}

func getInsertedDeltaRecordsForBlock(blockNumber uint64, model *AvsOperatorsModel) ([]*AvsOperatorStateChange, error) {
func getInsertedDeltaRecordsForBlock(blockNumber uint64, model *eigenStateModel.EigenStateModel) ([]*AvsOperatorStateChange, error) {
results := []*AvsOperatorStateChange{}

res := model.DB.Model(&AvsOperatorStateChange{}).Where("block_number = ?", blockNumber).Find(&results)
res := model.DB().Model(&AvsOperatorStateChange{}).Where("block_number = ?", blockNumber).Find(&results)
return results, res.Error
}

func getInsertedDeltaRecords(model *AvsOperatorsModel) ([]*AvsOperatorStateChange, error) {
func getInsertedDeltaRecords(model *eigenStateModel.EigenStateModel) ([]*AvsOperatorStateChange, error) {
results := []*AvsOperatorStateChange{}

res := model.DB.Model(&AvsOperatorStateChange{}).Order("block_number asc").Find(&results)
res := model.DB().Model(&AvsOperatorStateChange{}).Order("block_number asc").Find(&results)
return results, res.Error
}

Expand All @@ -67,7 +68,7 @@ func Test_AvsOperatorState(t *testing.T) {

t.Run("Should create a new AvsOperatorState", func(t *testing.T) {
esm := stateManager.NewEigenStateManager(l, grm)
avsOperatorState, err := NewAvsOperators(esm, grm, l, cfg)
avsOperatorState, err := NewAvsOperatorsModel(esm, grm, l, cfg)
assert.Nil(t, err)
assert.NotNil(t, avsOperatorState)
})
Expand All @@ -88,7 +89,7 @@ func Test_AvsOperatorState(t *testing.T) {
DeletedAt: time.Time{},
}

avsOperatorState, err := NewAvsOperators(esm, grm, l, cfg)
avsOperatorState, err := NewAvsOperatorsModel(esm, grm, l, cfg)
assert.Nil(t, err)

assert.Equal(t, true, avsOperatorState.IsInterestingLog(&log))
Expand Down Expand Up @@ -135,7 +136,7 @@ func Test_AvsOperatorState(t *testing.T) {
DeletedAt: time.Time{},
}

avsOperatorState, err := NewAvsOperators(esm, grm, l, cfg)
avsOperatorState, err := NewAvsOperatorsModel(esm, grm, l, cfg)
assert.Nil(t, err)

assert.Equal(t, true, avsOperatorState.IsInterestingLog(&log))
Expand All @@ -151,7 +152,7 @@ func Test_AvsOperatorState(t *testing.T) {
assert.Nil(t, err)

states := []RegisteredAvsOperators{}
statesRes := avsOperatorState.DB.
statesRes := avsOperatorState.DB().
Model(&RegisteredAvsOperators{}).
Raw("select * from registered_avs_operators where block_number = @blockNumber", sql.Named("blockNumber", blockNumber)).
Scan(&states)
Expand Down Expand Up @@ -205,7 +206,7 @@ func Test_AvsOperatorState(t *testing.T) {
},
}

avsOperatorState, err := NewAvsOperators(esm, grm, l, cfg)
avsOperatorState, err := NewAvsOperatorsModel(esm, grm, l, cfg)
assert.Nil(t, err)

for _, log := range logs {
Expand All @@ -222,7 +223,7 @@ func Test_AvsOperatorState(t *testing.T) {
assert.Nil(t, err)

states := []RegisteredAvsOperators{}
statesRes := avsOperatorState.DB.
statesRes := avsOperatorState.DB().
Model(&RegisteredAvsOperators{}).
Raw("select * from registered_avs_operators where block_number = @blockNumber", sql.Named("blockNumber", log.BlockNumber)).
Scan(&states)
Expand All @@ -233,13 +234,13 @@ func Test_AvsOperatorState(t *testing.T) {

if log.BlockNumber == blocks[0] {
assert.Equal(t, 1, len(states))
inserts, deletes, err := avsOperatorState.prepareState(log.BlockNumber)
inserts, deletes, err := avsOperatorState.Base().(*AvsOperatorsBaseModel).prepareState(log.BlockNumber)
assert.Nil(t, err)
assert.Equal(t, 1, len(inserts))
assert.Equal(t, 0, len(deletes))
} else if log.BlockNumber == blocks[1] {
assert.Equal(t, 0, len(states))
inserts, deletes, err := avsOperatorState.prepareState(log.BlockNumber)
inserts, deletes, err := avsOperatorState.Base().(*AvsOperatorsBaseModel).prepareState(log.BlockNumber)
assert.Nil(t, err)
assert.Equal(t, 0, len(inserts))
assert.Equal(t, 1, len(deletes))
Expand Down
Loading

0 comments on commit 9ac4f86

Please sign in to comment.