diff --git a/pkg/solana/chainwriter/chain_writer.go b/pkg/solana/chainwriter/chain_writer.go index 1bf8a3a8b..8d3928150 100644 --- a/pkg/solana/chainwriter/chain_writer.go +++ b/pkg/solana/chainwriter/chain_writer.go @@ -3,6 +3,7 @@ package chainwriter import ( "context" "encoding/json" + "errors" "fmt" "math/big" @@ -10,6 +11,7 @@ import ( addresslookuptable "github.com/gagliardetto/solana-go/programs/address-lookup-table" "github.com/gagliardetto/solana-go/rpc" + "github.com/smartcontractkit/chainlink-ccip/chains/solana/utils/tokens" commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/services" @@ -55,6 +57,7 @@ type MethodConfig struct { FromAddress string InputModifications commoncodec.ModifiersConfig ChainSpecificName string + ATAs []ATALookup LookupTables LookupTables Accounts []Lookup // Location in the args where the debug ID is stored @@ -214,6 +217,79 @@ func (s *SolanaChainWriterService) FilterLookupTableAddresses( return filteredLookupTables } +// CreateATAs first checks if a specified location exists, then checks if the accounts derived from the +// ATALookups in the ChainWriter's configuration exist on-chain and creates them if they do not. +func CreateATAs(ctx context.Context, args any, lookups []ATALookup, derivedTableMap map[string]map[string][]*solana.AccountMeta, reader client.Reader, idl string, feePayer solana.PublicKey) ([]solana.Instruction, error) { + createATAInstructions := []solana.Instruction{} + for _, lookup := range lookups { + // Check if location exists + if lookup.Location != "" { + // TODO refactor GetValuesAtLocation to not return an error if the field doesn't exist + _, err := GetValuesAtLocation(args, lookup.Location) + if err != nil { + // field doesn't exist, so ignore ATA creation + if errors.Is(err, errFieldNotFound) { + continue + } + return nil, fmt.Errorf("error getting values at location: %w", err) + } + } + walletAddresses, err := GetAddresses(ctx, args, []Lookup{lookup.WalletAddress}, derivedTableMap, reader, idl) + if err != nil { + return nil, fmt.Errorf("error resolving wallet address: %w", err) + } + if len(walletAddresses) != 1 { + return nil, fmt.Errorf("expected exactly one wallet address, got %d", len(walletAddresses)) + } + wallet := walletAddresses[0].PublicKey + + tokenPrograms, err := GetAddresses(ctx, args, []Lookup{lookup.TokenProgram}, derivedTableMap, reader, idl) + if err != nil { + return nil, fmt.Errorf("error resolving token program address: %w", err) + } + + mints, err := GetAddresses(ctx, args, []Lookup{lookup.MintAddress}, derivedTableMap, reader, idl) + if err != nil { + return nil, fmt.Errorf("error resolving mint address: %w", err) + } + + if len(tokenPrograms) != len(mints) { + return nil, fmt.Errorf("expected equal number of token programs and mints, got %d tokenPrograms and %d mints", len(tokenPrograms), len(mints)) + } + + for i := range tokenPrograms { + tokenProgram := tokenPrograms[i].PublicKey + mint := mints[i].PublicKey + + ataAddress, _, err := tokens.FindAssociatedTokenAddress(tokenProgram, mint, wallet) + if err != nil { + return nil, fmt.Errorf("error deriving ATA: %w", err) + } + + accountInfo, err := reader.GetAccountInfoWithOpts(ctx, ataAddress, &rpc.GetAccountInfoOpts{ + Encoding: "base64", + Commitment: rpc.CommitmentFinalized, + }) + if err != nil { + return nil, fmt.Errorf("error checking ATA %s on-chain: %w", ataAddress, err) + } + + // Check if account exists on-chain + if accountInfo.Value != nil { + continue + } + + ins, _, err := tokens.CreateAssociatedTokenAccount(tokenProgram, mint, wallet, feePayer) + if err != nil { + return nil, fmt.Errorf("error creating associated token account: %w", err) + } + createATAInstructions = append(createATAInstructions, ins) + } + } + + return createATAInstructions, nil +} + // SubmitTransaction builds, encodes, and enqueues a transaction using the provided program // configuration and method details. It relies on the configured IDL, account lookups, and // lookup tables to gather the necessary accounts and data. The function retrieves the latest @@ -274,6 +350,11 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra return errorWithDebugID(fmt.Errorf("error parsing fee payer address: %w", err), debugID) } + createATAinstructions, err := CreateATAs(ctx, args, methodConfig.ATAs, derivedTableMap, s.reader, programConfig.IDL, feePayer) + if err != nil { + return errorWithDebugID(fmt.Errorf("error resolving account addresses: %w", err), debugID) + } + // Filter the lookup table addresses based on which accounts are actually used filteredLookupTableMap := s.FilterLookupTableAddresses(accounts, derivedTableMap, staticTableMap) @@ -310,10 +391,13 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra discriminator := GetDiscriminator(methodConfig.ChainSpecificName) encodedPayload = append(discriminator[:], encodedPayload...) + // Combine the two sets of instructions into one slice + var instructions []solana.Instruction + instructions = append(instructions, createATAinstructions...) + instructions = append(instructions, solana.NewInstruction(programID, accounts, encodedPayload)) + tx, err := solana.NewTransaction( - []solana.Instruction{ - solana.NewInstruction(programID, accounts, encodedPayload), - }, + instructions, blockhash.Value.Blockhash, solana.TransactionPayer(feePayer), solana.TransactionAddressTables(filteredLookupTableMap), diff --git a/pkg/solana/chainwriter/chain_writer_test.go b/pkg/solana/chainwriter/chain_writer_test.go index 9c984b69a..81a330c74 100644 --- a/pkg/solana/chainwriter/chain_writer_test.go +++ b/pkg/solana/chainwriter/chain_writer_test.go @@ -799,221 +799,6 @@ func TestChainWriter_CCIPRouter(t *testing.T) { }) } -func TestChainWriter_CCIPRouter(t *testing.T) { - t.Parallel() - - // setup admin key - adminPk, err := solana.NewRandomPrivateKey() - require.NoError(t, err) - admin := adminPk.PublicKey() - - routerAddr := chainwriter.GetRandomPubKey(t) - destTokenAddr := chainwriter.GetRandomPubKey(t) - - poolKeys := []solana.PublicKey{destTokenAddr} - poolKeys = append(poolKeys, chainwriter.CreateTestPubKeys(t, 3)...) - - // simplified CCIP Config - does not contain full account list - ccipCWConfig := chainwriter.ChainWriterConfig{ - Programs: map[string]chainwriter.ProgramConfig{ - "ccip_router": { - Methods: map[string]chainwriter.MethodConfig{ - "execute": { - FromAddress: admin.String(), - InputModifications: []codec.ModifierConfig{ - &codec.RenameModifierConfig{ - Fields: map[string]string{"ReportContextByteWords": "ReportContext"}, - }, - &codec.RenameModifierConfig{ - Fields: map[string]string{"RawExecutionReport": "Report"}, - }, - }, - ChainSpecificName: "execute", - ArgsTransform: "CCIP", - LookupTables: chainwriter.LookupTables{}, - Accounts: []chainwriter.Lookup{ - chainwriter.AccountConstant{ - Name: "testAcc1", - Address: chainwriter.GetRandomPubKey(t).String(), - }, - chainwriter.AccountConstant{ - Name: "testAcc2", - Address: chainwriter.GetRandomPubKey(t).String(), - }, - chainwriter.AccountConstant{ - Name: "testAcc3", - Address: chainwriter.GetRandomPubKey(t).String(), - }, - chainwriter.AccountConstant{ - Name: "poolAddr1", - Address: poolKeys[0].String(), - }, - chainwriter.AccountConstant{ - Name: "poolAddr2", - Address: poolKeys[1].String(), - }, - chainwriter.AccountConstant{ - Name: "poolAddr3", - Address: poolKeys[2].String(), - }, - chainwriter.AccountConstant{ - Name: "poolAddr4", - Address: poolKeys[3].String(), - }, - }, - }, - "commit": { - FromAddress: admin.String(), - InputModifications: []codec.ModifierConfig{ - &codec.RenameModifierConfig{ - Fields: map[string]string{"ReportContextByteWords": "ReportContext"}, - }, - &codec.RenameModifierConfig{ - Fields: map[string]string{"RawReport": "Report"}, - }, - }, - ChainSpecificName: "commit", - ArgsTransform: "", - LookupTables: chainwriter.LookupTables{}, - Accounts: []chainwriter.Lookup{ - chainwriter.AccountConstant{ - Name: "testAcc1", - Address: chainwriter.GetRandomPubKey(t).String(), - }, - chainwriter.AccountConstant{ - Name: "testAcc2", - Address: chainwriter.GetRandomPubKey(t).String(), - }, - chainwriter.AccountConstant{ - Name: "testAcc3", - Address: chainwriter.GetRandomPubKey(t).String(), - }, - }, - }, - }, - IDL: ccipRouterIDL, - }, - }, - } - - ctx := tests.Context(t) - // mock client - rw := clientmocks.NewReaderWriter(t) - // mock estimator - ge := feemocks.NewEstimator(t) - - t.Run("CCIP execute is encoded successfully and ArgsTransform is applied correctly.", func(t *testing.T) { - // mock txm - txm := txmMocks.NewTxManager(t) - // initialize chain writer - cw, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), rw, txm, ge, ccipCWConfig) - require.NoError(t, err) - - recentBlockHash := solana.Hash{} - rw.On("LatestBlockhash", mock.Anything).Return(&rpc.GetLatestBlockhashResult{Value: &rpc.LatestBlockhashResult{Blockhash: recentBlockHash, LastValidBlockHeight: uint64(100)}}, nil).Once() - - pda, _, err := solana.FindProgramAddress([][]byte{[]byte("token_admin_registry"), destTokenAddr.Bytes()}, routerAddr) - require.NoError(t, err) - - lookupTable := mockTokenAdminRegistryLookupTable(t, rw, pda) - - mockFetchLookupTableAddresses(t, rw, lookupTable, poolKeys) - - txID := uuid.NewString() - txm.On("Enqueue", mock.Anything, admin.String(), mock.MatchedBy(func(tx *solana.Transaction) bool { - txData := tx.Message.Instructions[0].Data - payload := txData[8:] - var decoded ccip_router.Execute - dec := ag_binary.NewBorshDecoder(payload) - err = dec.Decode(&decoded) - require.NoError(t, err) - - tokenIndexes := *decoded.TokenIndexes - - require.Len(t, tokenIndexes, 1) - require.Equal(t, uint8(3), tokenIndexes[0]) - return true - }), &txID, mock.Anything).Return(nil).Once() - - // stripped back report just for purposes of example - abstractReport := ccipocr3.ExecutePluginReportSingleChain{ - Messages: []ccipocr3.Message{ - { - TokenAmounts: []ccipocr3.RampTokenAmount{ - { - DestTokenAddress: destTokenAddr.Bytes(), - }, - }, - }, - }, - } - - // Marshal the abstract report to json just for testing purposes. - encodedReport, err := json.Marshal(abstractReport) - require.NoError(t, err) - - args := chainwriter.ReportPreTransform{ - ReportContext: [2][32]byte{{0x01}, {0x02}}, - Report: encodedReport, - Info: ccipocr3.ExecuteReportInfo{ - MerkleRoots: []ccipocr3.MerkleRootChain{}, - AbstractReports: []ccipocr3.ExecutePluginReportSingleChain{abstractReport}, - }, - } - - submitErr := cw.SubmitTransaction(ctx, "ccip_router", "execute", args, txID, routerAddr.String(), nil, nil) - require.NoError(t, submitErr) - }) - - t.Run("CCIP commit is encoded successfully", func(t *testing.T) { - // mock txm - txm := txmMocks.NewTxManager(t) - // initialize chain writer - cw, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), rw, txm, ge, ccipCWConfig) - require.NoError(t, err) - - recentBlockHash := solana.Hash{} - rw.On("LatestBlockhash", mock.Anything).Return(&rpc.GetLatestBlockhashResult{Value: &rpc.LatestBlockhashResult{Blockhash: recentBlockHash, LastValidBlockHeight: uint64(100)}}, nil).Once() - - type CommitArgs struct { - ReportContext [2][32]byte - Report []byte - Rs [][32]byte - Ss [][32]byte - RawVs [32]byte - Info ccipocr3.CommitReportInfo - } - - txID := uuid.NewString() - - // TODO: Replace with actual type from ccipocr3 - args := CommitArgs{ - ReportContext: [2][32]byte{{0x01}, {0x02}}, - Report: []byte{0x01, 0x02}, - Rs: [][32]byte{{0x01, 0x02}}, - Ss: [][32]byte{{0x01, 0x02}}, - RawVs: [32]byte{0x01, 0x02}, - Info: ccipocr3.CommitReportInfo{ - RemoteF: 1, - MerkleRoots: []ccipocr3.MerkleRootChain{}, - }, - } - - txm.On("Enqueue", mock.Anything, admin.String(), mock.MatchedBy(func(tx *solana.Transaction) bool { - txData := tx.Message.Instructions[0].Data - payload := txData[8:] - var decoded ccip_router.Commit - dec := ag_binary.NewBorshDecoder(payload) - err := dec.Decode(&decoded) - require.NoError(t, err) - return true - }), &txID, mock.Anything).Return(nil).Once() - - submitErr := cw.SubmitTransaction(ctx, "ccip_router", "commit", args, txID, routerAddr.String(), nil, nil) - require.NoError(t, submitErr) - }) -} - func TestChainWriter_GetTransactionStatus(t *testing.T) { t.Parallel() diff --git a/pkg/solana/chainwriter/helpers.go b/pkg/solana/chainwriter/helpers.go index 6e2a3e5be..1c135fb69 100644 --- a/pkg/solana/chainwriter/helpers.go +++ b/pkg/solana/chainwriter/helpers.go @@ -43,39 +43,48 @@ func FetchTestContractIDL() string { return testContractIDL } +var ( + errFieldNotFound = errors.New("key not found") +) + // GetValuesAtLocation parses through nested types and arrays to find all locations of values func GetValuesAtLocation(args any, location string) ([][]byte, error) { var vals [][]byte + // If the user specified no location, just return empty (no-op). + if location == "" { + return nil, nil + } + path := strings.Split(location, ".") - addressList, err := traversePath(args, path) + items, err := traversePath(args, path) if err != nil { return nil, err } - for _, value := range addressList { - // Dereference if it's a pointer - rv := reflect.ValueOf(value) + + for _, item := range items { + rv := reflect.ValueOf(item) if rv.Kind() == reflect.Ptr && !rv.IsNil() { - value = rv.Elem().Interface() + item = rv.Elem().Interface() } - if byteArray, ok := value.([]byte); ok { - vals = append(vals, byteArray) - } else if address, ok := value.(solana.PublicKey); ok { - vals = append(vals, address.Bytes()) - } else if num, ok := value.(uint64); ok { + switch value := item.(type) { + case []byte: + vals = append(vals, value) + case solana.PublicKey: + vals = append(vals, value.Bytes()) + case ccipocr3.UnknownAddress: + vals = append(vals, value) + case uint64: buf := make([]byte, 8) - binary.LittleEndian.PutUint64(buf, num) + binary.LittleEndian.PutUint64(buf, value) vals = append(vals, buf) - } else if addr, ok := value.(ccipocr3.UnknownAddress); ok { - vals = append(vals, addr) - } else if arr, ok := value.([32]uint8); ok { - vals = append(vals, arr[:]) - } else { + case [32]uint8: + vals = append(vals, value[:]) + default: return nil, fmt.Errorf("invalid value format at path: %s, type: %s", location, reflect.TypeOf(value).String()) } } - return vals, nil } @@ -135,7 +144,7 @@ func traversePath(data any, path []string) ([]any, error) { case reflect.Struct: field := val.FieldByName(path[0]) if !field.IsValid() { - return nil, errors.New("field not found: " + path[0]) + return []any{}, errFieldNotFound } return traversePath(field.Interface(), path[1:]) @@ -150,13 +159,13 @@ func traversePath(data any, path []string) ([]any, error) { if len(result) > 0 { return result, nil } - return nil, errors.New("no matching field found in array") + return []any{}, errFieldNotFound case reflect.Map: key := reflect.ValueOf(path[0]) value := val.MapIndex(key) if !value.IsValid() { - return nil, errors.New("key not found: " + path[0]) + return []any{}, errFieldNotFound } return traversePath(value.Interface(), path[1:]) default: diff --git a/pkg/solana/chainwriter/lookups.go b/pkg/solana/chainwriter/lookups.go index 36719538a..982201b23 100644 --- a/pkg/solana/chainwriter/lookups.go +++ b/pkg/solana/chainwriter/lookups.go @@ -83,6 +83,16 @@ type AccountsFromLookupTable struct { IncludeIndexes []int } +type ATALookup struct { + // Field that determines whether the ATA lookup is necessary. Basically + // just need to check this field exists. Dot separated location. + Location string + // If the field exists, initialize a ATA account using the Wallet, Token Program, and Mint addresses below + WalletAddress Lookup + TokenProgram Lookup + MintAddress Lookup +} + func (ac AccountConstant) Resolve(_ context.Context, _ any, _ map[string]map[string][]*solana.AccountMeta, _ client.Reader, _ string) ([]*solana.AccountMeta, error) { address, err := solana.PublicKeyFromBase58(ac.Address) if err != nil {