Skip to content

Commit

Permalink
Merge pull request #17 from getAlby/feature/block-malformed-keysend
Browse files Browse the repository at this point in the history
add a check on wallet id tlv
  • Loading branch information
kiwiidb authored Sep 11, 2023
2 parents f3ce209 + e5f6322 commit c033db9
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 5 deletions.
74 changes: 74 additions & 0 deletions check_invoice_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package main

import (
"testing"

"github.com/lightningnetwork/lnd/lnrpc"
"github.com/stretchr/testify/assert"
)

func TestCheckInvoice(t *testing.T) {
//test non keysend
assert.True(t, shouldPublishInvoice(&lnrpc.Invoice{
State: lnrpc.Invoice_SETTLED,
IsKeysend: false,
Htlcs: []*lnrpc.InvoiceHTLC{
{
ChanId: 0,
HtlcIndex: 0,
AmtMsat: 0,
AcceptHeight: 0,
AcceptTime: 0,
ResolveTime: 0,
ExpiryHeight: 0,
State: 0,
MppTotalAmtMsat: 0,
Amp: &lnrpc.AMP{},
},
},
}))
//test keysend with wallet id tlv
assert.True(t, shouldPublishInvoice(&lnrpc.Invoice{
State: lnrpc.Invoice_SETTLED,
IsKeysend: true,
Htlcs: []*lnrpc.InvoiceHTLC{
{
ChanId: 0,
HtlcIndex: 0,
AmtMsat: 0,
AcceptHeight: 0,
AcceptTime: 0,
ResolveTime: 0,
ExpiryHeight: 0,
State: 0,
CustomRecords: map[uint64][]byte{
696969: {69, 69, 69},
},
MppTotalAmtMsat: 0,
Amp: &lnrpc.AMP{},
},
},
}))
//test keysend without wallet id tlv
assert.False(t, shouldPublishInvoice(&lnrpc.Invoice{
State: lnrpc.Invoice_SETTLED,
IsKeysend: true,
Htlcs: []*lnrpc.InvoiceHTLC{
{
ChanId: 0,
HtlcIndex: 0,
AmtMsat: 0,
AcceptHeight: 0,
AcceptTime: 0,
ResolveTime: 0,
ExpiryHeight: 0,
State: 0,
CustomRecords: map[uint64][]byte{
696970: {69, 69, 70},
},
MppTotalAmtMsat: 0,
Amp: &lnrpc.AMP{},
},
},
}))
}
15 changes: 15 additions & 0 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,21 @@ func (mlnd *MockLND) mockPaidInvoice(amtPaid int64, memo string) error {
AmtPaidSat: amtPaid,
AmtPaidMsat: 1000 * amtPaid,
State: lnrpc.Invoice_SETTLED,
Htlcs: []*lnrpc.InvoiceHTLC{
{
ChanId: 0,
HtlcIndex: 0,
AmtMsat: 0,
AcceptHeight: 0,
AcceptTime: 0,
ResolveTime: 0,
ExpiryHeight: 0,
State: 0,
CustomRecords: map[uint64][]byte{},
MppTotalAmtMsat: 0,
Amp: &lnrpc.AMP{},
},
},
}
mlnd.Sub.invoiceChan <- incoming
return nil
Expand Down
30 changes: 25 additions & 5 deletions service.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ const (
LNDInvoiceRoutingKey = "invoice.incoming.settled"
LNDPaymentSuccessRoutingKey = "payment.outgoing.settled"
LNDPaymentErrorRoutingKey = "payment.outgoing.error"

TLV_WALLET_ID = 696969
)

type Service struct {
Expand Down Expand Up @@ -327,7 +329,7 @@ func (svc *Service) ProcessPayment(ctx context.Context, payment *lnrpc.Payment)
}

func (svc *Service) ProcessInvoice(ctx context.Context, invoice *lnrpc.Invoice) error {
if invoice.State == lnrpc.Invoice_SETTLED {
if shouldPublishInvoice(invoice) {
startTime := time.Now()
err := svc.PublishPayload(ctx, invoice, LNDInvoiceExchange, LNDInvoiceRoutingKey)
if err != nil {
Expand All @@ -348,14 +350,32 @@ func (svc *Service) ProcessInvoice(ctx context.Context, invoice *lnrpc.Invoice)
"settle_date": invoice.SettleDate,
"payment_hash": hex.EncodeToString(invoice.RHash),
}).Info("published invoice")
//add it to the database if we have one
if svc.db != nil {
return svc.AddLastPublishedInvoice(ctx, invoice)
}
return svc.AddLastPublishedInvoice(ctx, invoice)
}
logrus.
WithField("payment_hash", hex.EncodeToString(invoice.RHash)).
WithField("state", invoice.State).
WithField("keysend", invoice.IsKeysend).
Info("not publishing invoice")
return nil
}

// check if we need to publish an invoice
func shouldPublishInvoice(invoice *lnrpc.Invoice) (ok bool) {

//don't publish unsettled invoice
if invoice.State != lnrpc.Invoice_SETTLED {
return false
}
//if the invoice is keysend, it needs record 696969
//(invoices always have always at least one htlc in them)
recs := invoice.Htlcs[0].CustomRecords
if invoice.IsKeysend {
return recs[TLV_WALLET_ID] != nil
}
return true
}

func (svc *Service) PublishPayload(ctx context.Context, payload interface{}, exchange, key string) error {
payloadBytes := new(bytes.Buffer)
err := json.NewEncoder(payloadBytes).Encode(payload)
Expand Down

0 comments on commit c033db9

Please sign in to comment.