Skip to content

Commit

Permalink
Enabled automatic ATA creation in CW
Browse files Browse the repository at this point in the history
  • Loading branch information
silaslenihan committed Jan 30, 2025
1 parent aa71d84 commit ed8ad2f
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 238 deletions.
90 changes: 87 additions & 3 deletions pkg/solana/chainwriter/chain_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package chainwriter
import (
"context"
"encoding/json"
"errors"
"fmt"
"math/big"

"github.com/gagliardetto/solana-go"
addresslookuptable "github.com/gagliardetto/solana-go/programs/address-lookup-table"
"github.com/gagliardetto/solana-go/rpc"

"github.com/smartcontractkit/chainlink-ccip/chains/solana/utils/tokens"
commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec"
"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/smartcontractkit/chainlink-common/pkg/services"
Expand Down Expand Up @@ -55,6 +57,7 @@ type MethodConfig struct {
FromAddress string
InputModifications commoncodec.ModifiersConfig
ChainSpecificName string
ATAs []ATALookup
LookupTables LookupTables
Accounts []Lookup
// Location in the args where the debug ID is stored
Expand Down Expand Up @@ -214,6 +217,79 @@ func (s *SolanaChainWriterService) FilterLookupTableAddresses(
return filteredLookupTables
}

// CreateATAs first checks if a specified location exists, then checks if the accounts derived from the
// ATALookups in the ChainWriter's configuration exist on-chain and creates them if they do not.
func CreateATAs(ctx context.Context, args any, lookups []ATALookup, derivedTableMap map[string]map[string][]*solana.AccountMeta, reader client.Reader, idl string, feePayer solana.PublicKey) ([]solana.Instruction, error) {
createATAInstructions := []solana.Instruction{}
for _, lookup := range lookups {
// Check if location exists
if lookup.Location != "" {
// TODO refactor GetValuesAtLocation to not return an error if the field doesn't exist
_, err := GetValuesAtLocation(args, lookup.Location)
if err != nil {
// field doesn't exist, so ignore ATA creation
if errors.Is(err, errFieldNotFound) {
continue
}
return nil, fmt.Errorf("error getting values at location: %w", err)
}
}
walletAddresses, err := GetAddresses(ctx, args, []Lookup{lookup.WalletAddress}, derivedTableMap, reader, idl)
if err != nil {
return nil, fmt.Errorf("error resolving wallet address: %w", err)
}
if len(walletAddresses) != 1 {
return nil, fmt.Errorf("expected exactly one wallet address, got %d", len(walletAddresses))
}
wallet := walletAddresses[0].PublicKey

tokenPrograms, err := GetAddresses(ctx, args, []Lookup{lookup.TokenProgram}, derivedTableMap, reader, idl)
if err != nil {
return nil, fmt.Errorf("error resolving token program address: %w", err)
}

mints, err := GetAddresses(ctx, args, []Lookup{lookup.MintAddress}, derivedTableMap, reader, idl)
if err != nil {
return nil, fmt.Errorf("error resolving mint address: %w", err)
}

if len(tokenPrograms) != len(mints) {
return nil, fmt.Errorf("expected equal number of token programs and mints, got %d tokenPrograms and %d mints", len(tokenPrograms), len(mints))
}

for i := range tokenPrograms {
tokenProgram := tokenPrograms[i].PublicKey
mint := mints[i].PublicKey

ataAddress, _, err := tokens.FindAssociatedTokenAddress(tokenProgram, mint, wallet)
if err != nil {
return nil, fmt.Errorf("error deriving ATA: %w", err)
}

accountInfo, err := reader.GetAccountInfoWithOpts(ctx, ataAddress, &rpc.GetAccountInfoOpts{
Encoding: "base64",
Commitment: rpc.CommitmentFinalized,
})
if err != nil {
return nil, fmt.Errorf("error checking ATA %s on-chain: %w", ataAddress, err)
}

// Check if account exists on-chain
if accountInfo.Value != nil {
continue
}

ins, _, err := tokens.CreateAssociatedTokenAccount(tokenProgram, mint, wallet, feePayer)
if err != nil {
return nil, fmt.Errorf("error creating associated token account: %w", err)
}
createATAInstructions = append(createATAInstructions, ins)
}
}

return createATAInstructions, nil
}

// SubmitTransaction builds, encodes, and enqueues a transaction using the provided program
// configuration and method details. It relies on the configured IDL, account lookups, and
// lookup tables to gather the necessary accounts and data. The function retrieves the latest
Expand Down Expand Up @@ -274,6 +350,11 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra
return errorWithDebugID(fmt.Errorf("error parsing fee payer address: %w", err), debugID)
}

createATAinstructions, err := CreateATAs(ctx, args, methodConfig.ATAs, derivedTableMap, s.reader, programConfig.IDL, feePayer)
if err != nil {
return errorWithDebugID(fmt.Errorf("error resolving account addresses: %w", err), debugID)
}

// Filter the lookup table addresses based on which accounts are actually used
filteredLookupTableMap := s.FilterLookupTableAddresses(accounts, derivedTableMap, staticTableMap)

Expand Down Expand Up @@ -310,10 +391,13 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra
discriminator := GetDiscriminator(methodConfig.ChainSpecificName)
encodedPayload = append(discriminator[:], encodedPayload...)

// Combine the two sets of instructions into one slice
var instructions []solana.Instruction
instructions = append(instructions, createATAinstructions...)
instructions = append(instructions, solana.NewInstruction(programID, accounts, encodedPayload))

tx, err := solana.NewTransaction(
[]solana.Instruction{
solana.NewInstruction(programID, accounts, encodedPayload),
},
instructions,
blockhash.Value.Blockhash,
solana.TransactionPayer(feePayer),
solana.TransactionAddressTables(filteredLookupTableMap),
Expand Down
215 changes: 0 additions & 215 deletions pkg/solana/chainwriter/chain_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -799,221 +799,6 @@ func TestChainWriter_CCIPRouter(t *testing.T) {
})
}

func TestChainWriter_CCIPRouter(t *testing.T) {
t.Parallel()

// setup admin key
adminPk, err := solana.NewRandomPrivateKey()
require.NoError(t, err)
admin := adminPk.PublicKey()

routerAddr := chainwriter.GetRandomPubKey(t)
destTokenAddr := chainwriter.GetRandomPubKey(t)

poolKeys := []solana.PublicKey{destTokenAddr}
poolKeys = append(poolKeys, chainwriter.CreateTestPubKeys(t, 3)...)

// simplified CCIP Config - does not contain full account list
ccipCWConfig := chainwriter.ChainWriterConfig{
Programs: map[string]chainwriter.ProgramConfig{
"ccip_router": {
Methods: map[string]chainwriter.MethodConfig{
"execute": {
FromAddress: admin.String(),
InputModifications: []codec.ModifierConfig{
&codec.RenameModifierConfig{
Fields: map[string]string{"ReportContextByteWords": "ReportContext"},
},
&codec.RenameModifierConfig{
Fields: map[string]string{"RawExecutionReport": "Report"},
},
},
ChainSpecificName: "execute",
ArgsTransform: "CCIP",
LookupTables: chainwriter.LookupTables{},
Accounts: []chainwriter.Lookup{
chainwriter.AccountConstant{
Name: "testAcc1",
Address: chainwriter.GetRandomPubKey(t).String(),
},
chainwriter.AccountConstant{
Name: "testAcc2",
Address: chainwriter.GetRandomPubKey(t).String(),
},
chainwriter.AccountConstant{
Name: "testAcc3",
Address: chainwriter.GetRandomPubKey(t).String(),
},
chainwriter.AccountConstant{
Name: "poolAddr1",
Address: poolKeys[0].String(),
},
chainwriter.AccountConstant{
Name: "poolAddr2",
Address: poolKeys[1].String(),
},
chainwriter.AccountConstant{
Name: "poolAddr3",
Address: poolKeys[2].String(),
},
chainwriter.AccountConstant{
Name: "poolAddr4",
Address: poolKeys[3].String(),
},
},
},
"commit": {
FromAddress: admin.String(),
InputModifications: []codec.ModifierConfig{
&codec.RenameModifierConfig{
Fields: map[string]string{"ReportContextByteWords": "ReportContext"},
},
&codec.RenameModifierConfig{
Fields: map[string]string{"RawReport": "Report"},
},
},
ChainSpecificName: "commit",
ArgsTransform: "",
LookupTables: chainwriter.LookupTables{},
Accounts: []chainwriter.Lookup{
chainwriter.AccountConstant{
Name: "testAcc1",
Address: chainwriter.GetRandomPubKey(t).String(),
},
chainwriter.AccountConstant{
Name: "testAcc2",
Address: chainwriter.GetRandomPubKey(t).String(),
},
chainwriter.AccountConstant{
Name: "testAcc3",
Address: chainwriter.GetRandomPubKey(t).String(),
},
},
},
},
IDL: ccipRouterIDL,
},
},
}

ctx := tests.Context(t)
// mock client
rw := clientmocks.NewReaderWriter(t)
// mock estimator
ge := feemocks.NewEstimator(t)

t.Run("CCIP execute is encoded successfully and ArgsTransform is applied correctly.", func(t *testing.T) {
// mock txm
txm := txmMocks.NewTxManager(t)
// initialize chain writer
cw, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), rw, txm, ge, ccipCWConfig)
require.NoError(t, err)

recentBlockHash := solana.Hash{}
rw.On("LatestBlockhash", mock.Anything).Return(&rpc.GetLatestBlockhashResult{Value: &rpc.LatestBlockhashResult{Blockhash: recentBlockHash, LastValidBlockHeight: uint64(100)}}, nil).Once()

pda, _, err := solana.FindProgramAddress([][]byte{[]byte("token_admin_registry"), destTokenAddr.Bytes()}, routerAddr)
require.NoError(t, err)

lookupTable := mockTokenAdminRegistryLookupTable(t, rw, pda)

mockFetchLookupTableAddresses(t, rw, lookupTable, poolKeys)

txID := uuid.NewString()
txm.On("Enqueue", mock.Anything, admin.String(), mock.MatchedBy(func(tx *solana.Transaction) bool {
txData := tx.Message.Instructions[0].Data
payload := txData[8:]
var decoded ccip_router.Execute
dec := ag_binary.NewBorshDecoder(payload)
err = dec.Decode(&decoded)
require.NoError(t, err)

tokenIndexes := *decoded.TokenIndexes

require.Len(t, tokenIndexes, 1)
require.Equal(t, uint8(3), tokenIndexes[0])
return true
}), &txID, mock.Anything).Return(nil).Once()

// stripped back report just for purposes of example
abstractReport := ccipocr3.ExecutePluginReportSingleChain{
Messages: []ccipocr3.Message{
{
TokenAmounts: []ccipocr3.RampTokenAmount{
{
DestTokenAddress: destTokenAddr.Bytes(),
},
},
},
},
}

// Marshal the abstract report to json just for testing purposes.
encodedReport, err := json.Marshal(abstractReport)
require.NoError(t, err)

args := chainwriter.ReportPreTransform{
ReportContext: [2][32]byte{{0x01}, {0x02}},
Report: encodedReport,
Info: ccipocr3.ExecuteReportInfo{
MerkleRoots: []ccipocr3.MerkleRootChain{},
AbstractReports: []ccipocr3.ExecutePluginReportSingleChain{abstractReport},
},
}

submitErr := cw.SubmitTransaction(ctx, "ccip_router", "execute", args, txID, routerAddr.String(), nil, nil)
require.NoError(t, submitErr)
})

t.Run("CCIP commit is encoded successfully", func(t *testing.T) {
// mock txm
txm := txmMocks.NewTxManager(t)
// initialize chain writer
cw, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), rw, txm, ge, ccipCWConfig)
require.NoError(t, err)

recentBlockHash := solana.Hash{}
rw.On("LatestBlockhash", mock.Anything).Return(&rpc.GetLatestBlockhashResult{Value: &rpc.LatestBlockhashResult{Blockhash: recentBlockHash, LastValidBlockHeight: uint64(100)}}, nil).Once()

type CommitArgs struct {
ReportContext [2][32]byte
Report []byte
Rs [][32]byte
Ss [][32]byte
RawVs [32]byte
Info ccipocr3.CommitReportInfo
}

txID := uuid.NewString()

// TODO: Replace with actual type from ccipocr3
args := CommitArgs{
ReportContext: [2][32]byte{{0x01}, {0x02}},
Report: []byte{0x01, 0x02},
Rs: [][32]byte{{0x01, 0x02}},
Ss: [][32]byte{{0x01, 0x02}},
RawVs: [32]byte{0x01, 0x02},
Info: ccipocr3.CommitReportInfo{
RemoteF: 1,
MerkleRoots: []ccipocr3.MerkleRootChain{},
},
}

txm.On("Enqueue", mock.Anything, admin.String(), mock.MatchedBy(func(tx *solana.Transaction) bool {
txData := tx.Message.Instructions[0].Data
payload := txData[8:]
var decoded ccip_router.Commit
dec := ag_binary.NewBorshDecoder(payload)
err := dec.Decode(&decoded)
require.NoError(t, err)
return true
}), &txID, mock.Anything).Return(nil).Once()

submitErr := cw.SubmitTransaction(ctx, "ccip_router", "commit", args, txID, routerAddr.String(), nil, nil)
require.NoError(t, submitErr)
})
}

func TestChainWriter_GetTransactionStatus(t *testing.T) {
t.Parallel()

Expand Down
Loading

0 comments on commit ed8ad2f

Please sign in to comment.