diff --git a/controllers/invoicestream.ctrl.go b/controllers/invoicestream.ctrl.go deleted file mode 100644 index bff6be81..00000000 --- a/controllers/invoicestream.ctrl.go +++ /dev/null @@ -1,152 +0,0 @@ -package controllers - -import ( - "net/http" - "strconv" - "time" - - "github.com/getAlby/lndhub.go/common" - "github.com/getAlby/lndhub.go/lib/service" - "github.com/getAlby/lndhub.go/lib/tokens" - "github.com/gorilla/websocket" - "github.com/labstack/echo/v4" -) - -// GetTXSController : GetTXSController struct -type InvoiceStreamController struct { - svc *service.LndhubService -} - -type InvoiceEventWrapper struct { - Type string `json:"type"` - Invoice *IncomingInvoice `json:"invoice,omitempty"` -} - -func NewInvoiceStreamController(svc *service.LndhubService) *InvoiceStreamController { - return &InvoiceStreamController{svc: svc} -} - -func (controller *InvoiceStreamController) StreamInvoices(c echo.Context) error { - userId, err := tokens.ParseToken(controller.svc.Config.JWTSecret, (c.QueryParam("token")), false) - if err != nil { - return err - } - ticker := time.NewTicker(30 * time.Second) - ws, done, err := createWebsocketUpgrader(c) - if err != nil { - return err - } - defer ws.Close() - //start subscription - invoiceChan, subId, err := controller.svc.InvoicePubSub.Subscribe(strconv.FormatInt(userId, 10)) - if err != nil { - controller.svc.Logger.Error(err) - return err - } - //start with keepalive message - err = ws.WriteJSON(&InvoiceEventWrapper{Type: "keepalive"}) - if err != nil { - controller.svc.Logger.Error(err) - controller.svc.InvoicePubSub.Unsubscribe(subId, strconv.FormatInt(userId, 10)) - return err - } - fromPaymentHash := c.QueryParam("since_payment_hash") - if fromPaymentHash != "" { - err = controller.writeMissingInvoices(c, userId, ws, fromPaymentHash) - if err != nil { - controller.svc.Logger.Error(err) - controller.svc.InvoicePubSub.Unsubscribe(subId, strconv.FormatInt(userId, 10)) - return err - } - } -SocketLoop: - for { - select { - case <-done: - break SocketLoop - case <-ticker.C: - err := ws.WriteJSON(&InvoiceEventWrapper{Type: "keepalive"}) - if err != nil { - controller.svc.Logger.Error(err) - break SocketLoop - } - case invoice := <-invoiceChan: - err := ws.WriteJSON( - &InvoiceEventWrapper{ - Type: "invoice", - Invoice: &IncomingInvoice{ - PaymentHash: invoice.RHash, - PaymentRequest: invoice.PaymentRequest, - Description: invoice.Memo, - PayReq: invoice.PaymentRequest, - Timestamp: invoice.CreatedAt.Unix(), - Type: common.InvoiceTypeUser, - Amount: invoice.Amount, - IsPaid: invoice.State == common.InvoiceStateSettled, - CustomRecords: invoice.DestinationCustomRecords, - }}) - if err != nil { - controller.svc.Logger.Error(err) - break SocketLoop - } - } - } - controller.svc.InvoicePubSub.Unsubscribe(subId, strconv.FormatInt(userId, 10)) - return nil -} - -// open the websocket and start listening for close messages in a goroutine -func createWebsocketUpgrader(c echo.Context) (conn *websocket.Conn, done chan struct{}, err error) { - upgrader := websocket.Upgrader{} - upgrader.CheckOrigin = func(r *http.Request) bool { return true } - ws, err := upgrader.Upgrade(c.Response(), c.Request(), nil) - if err != nil { - return nil, nil, err - } - - //start listening for close messages - done = make(chan struct{}) - go func() { - defer close(done) - for { - _, _, err := ws.ReadMessage() - if err != nil { - return - } - } - }() - return ws, done, nil -} - -func (controller *InvoiceStreamController) writeMissingInvoices(c echo.Context, userId int64, ws *websocket.Conn, hash string) error { - invoices, err := controller.svc.InvoicesFor(c.Request().Context(), userId, common.InvoiceTypeIncoming) - if err != nil { - return err - } - for _, inv := range invoices { - //invoices are order from newest to oldest (with a maximum of 100 invoices being returned) - //so if we get a match on the hash, we have processed all missing invoices for this client - if inv.RHash == hash { - break - } - if inv.State == common.InvoiceStateSettled { - err := ws.WriteJSON( - &InvoiceEventWrapper{ - Type: "invoice", - Invoice: &IncomingInvoice{ - PaymentHash: inv.RHash, - PaymentRequest: inv.PaymentRequest, - Description: inv.Memo, - PayReq: inv.PaymentRequest, - Timestamp: inv.CreatedAt.Unix(), - Type: common.InvoiceTypeUser, - Amount: inv.Amount, - IsPaid: inv.State == common.InvoiceStateSettled, - }}) - if err != nil { - return err - } - } - } - return nil -} diff --git a/integration_tests/websocket_test.go b/integration_tests/websocket_test.go deleted file mode 100644 index cc930661..00000000 --- a/integration_tests/websocket_test.go +++ /dev/null @@ -1,241 +0,0 @@ -package integration_tests - -import ( - "context" - "encoding/json" - "fmt" - "log" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/getAlby/lndhub.go/controllers" - "github.com/getAlby/lndhub.go/lib" - "github.com/getAlby/lndhub.go/lib/responses" - "github.com/getAlby/lndhub.go/lib/service" - "github.com/getAlby/lndhub.go/lib/tokens" - "github.com/go-playground/validator/v10" - "github.com/gorilla/websocket" - "github.com/labstack/echo/v4" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/suite" -) - -type KeepAlive struct { - Type string -} - -type WebSocketTestSuite struct { - TestSuite - service *service.LndhubService - mlnd *MockLND - userLogin ExpectedCreateUserResponseBody - userToken string - userToken2 string - invoiceUpdateSubCancelFn context.CancelFunc - websocketServer *httptest.Server - wsUrl string - wsUrl2 string -} -type WsHandler struct { - handler echo.HandlerFunc -} - -func (h *WsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - e := echo.New() - c := e.NewContext(r, w) - - err := h.handler(c) - if err != nil { - _, _ = w.Write([]byte(err.Error())) - } -} - -func (suite *WebSocketTestSuite) SetupSuite() { - mlnd := newDefaultMockLND() - svc, err := LndHubTestServiceInit(mlnd) - if err != nil { - log.Fatalf("Error initializing test service: %v", err) - } - suite.mlnd = mlnd - users, userTokens, err := createUsers(svc, 2) - if err != nil { - log.Fatalf("Error creating test users: %v", err) - } - // Subscribe to LND invoice updates in the background - // store cancel func to be called in tear down suite - ctx, cancel := context.WithCancel(context.Background()) - suite.invoiceUpdateSubCancelFn = cancel - go svc.InvoiceUpdateSubscription(ctx) - suite.service = svc - e := echo.New() - - e.HTTPErrorHandler = responses.HTTPErrorHandler - e.Validator = &lib.CustomValidator{Validator: validator.New()} - suite.echo = e - assert.Equal(suite.T(), 2, len(users)) - assert.Equal(suite.T(), 2, len(userTokens)) - suite.userLogin = users[0] - suite.userToken = userTokens[0] - suite.userToken2 = userTokens[1] - suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret))) - suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice) - - //websocket server - h := WsHandler{handler: controllers.NewInvoiceStreamController(suite.service).StreamInvoices} - server := httptest.NewServer(http.HandlerFunc(h.ServeHTTP)) - suite.websocketServer = server - suite.wsUrl = "ws" + strings.TrimPrefix(suite.websocketServer.URL, "http") + fmt.Sprintf("?token=%s", suite.userToken) - suite.wsUrl2 = "ws" + strings.TrimPrefix(suite.websocketServer.URL, "http") + fmt.Sprintf("?token=%s", suite.userToken2) -} - -func (suite *WebSocketTestSuite) TestWebSocket() { - //start listening to websocket - ws, _, err := websocket.DefaultDialer.Dial(suite.wsUrl, nil) - assert.NoError(suite.T(), err) - _, msg, err := ws.ReadMessage() - assert.NoError(suite.T(), err) - keepAlive := KeepAlive{} - err = json.Unmarshal([]byte(msg), &keepAlive) - assert.NoError(suite.T(), err) - assert.Equal(suite.T(), "keepalive", keepAlive.Type) - - // create incoming invoice and fund account - invoice := suite.createAddInvoiceReq(1000, "integration test websocket 1", suite.userToken) - err = suite.mlnd.mockPaidInvoice(invoice, 0, false, nil) - assert.NoError(suite.T(), err) - - _, msg, err = ws.ReadMessage() - assert.NoError(suite.T(), err) - event := ExpectedInvoiceEventWrapper{} - err = json.Unmarshal([]byte(msg), &event) - assert.NoError(suite.T(), err) - assert.Equal(suite.T(), event.Type, "invoice") - assert.Equal(suite.T(), int64(1000), event.Invoice.Amount) - assert.Equal(suite.T(), "integration test websocket 1", event.Invoice.Description) -} - -func (suite *WebSocketTestSuite) TestWebSocketDoubeSubscription() { - //create 1st subscription - ws1, _, err := websocket.DefaultDialer.Dial(suite.wsUrl, nil) - assert.NoError(suite.T(), err) - //read keepalive msg - _, _, err = ws1.ReadMessage() - //create 2nd subscription, create invoice, pay invoice, assert that invoice is received twice - //start listening to websocket - ws2, _, err := websocket.DefaultDialer.Dial(suite.wsUrl, nil) - assert.NoError(suite.T(), err) - //read keepalive msg - _, _, err = ws2.ReadMessage() - assert.NoError(suite.T(), err) - invoice := suite.createAddInvoiceReq(1000, "integration test websocket 2", suite.userToken) - err = suite.mlnd.mockPaidInvoice(invoice, 0, false, nil) - assert.NoError(suite.T(), err) - _, msg1, err := ws1.ReadMessage() - assert.NoError(suite.T(), err) - _, msg2, err := ws2.ReadMessage() - assert.NoError(suite.T(), err) - - event1 := ExpectedInvoiceEventWrapper{} - err = json.Unmarshal([]byte(msg1), &event1) - assert.NoError(suite.T(), err) - event2 := ExpectedInvoiceEventWrapper{} - err = json.Unmarshal([]byte(msg2), &event2) - assert.NoError(suite.T(), err) - assert.Equal(suite.T(), "integration test websocket 2", event1.Invoice.Description) - assert.Equal(suite.T(), "integration test websocket 2", event2.Invoice.Description) - //close 1 subscription, assert that the existing sub still receives their invoices - ws1.Close() - invoice = suite.createAddInvoiceReq(1000, "integration test websocket 3", suite.userToken) - err = suite.mlnd.mockPaidInvoice(invoice, 0, false, nil) - assert.NoError(suite.T(), err) - _, msg3, err := ws2.ReadMessage() - assert.NoError(suite.T(), err) - event3 := ExpectedInvoiceEventWrapper{} - err = json.Unmarshal([]byte(msg3), &event3) - assert.NoError(suite.T(), err) - assert.Equal(suite.T(), "integration test websocket 3", event3.Invoice.Description) - -} -func (suite *WebSocketTestSuite) TestWebSocketDoubleUser() { - //create subs for 2 different users, assert that they each get their own invoice updates - user1Ws, _, err := websocket.DefaultDialer.Dial(suite.wsUrl, nil) - assert.NoError(suite.T(), err) - //read keepalive msg - _, _, err = user1Ws.ReadMessage() - assert.NoError(suite.T(), err) - //create subs for 2 different users, assert that they each get their own invoice updates - user2Ws, _, err := websocket.DefaultDialer.Dial(suite.wsUrl2, nil) - assert.NoError(suite.T(), err) - //read keepalive msg - _, _, err = user2Ws.ReadMessage() - assert.NoError(suite.T(), err) - // add invoice for user 1 - user1Invoice := suite.createAddInvoiceReq(1000, "integration test websocket user 1", suite.userToken) - // add invoice for user 2 - user2Invoice := suite.createAddInvoiceReq(1000, "integration test websocket user 2", suite.userToken2) - //mock pay invoices - err = suite.mlnd.mockPaidInvoice(user1Invoice, 0, false, nil) - assert.NoError(suite.T(), err) - err = suite.mlnd.mockPaidInvoice(user2Invoice, 0, false, nil) - assert.NoError(suite.T(), err) - //read user 1 received msg - _, user1Msg, err := user1Ws.ReadMessage() - assert.NoError(suite.T(), err) - //assert it's their's - eventUser1 := ExpectedInvoiceEventWrapper{} - err = json.Unmarshal([]byte(user1Msg), &eventUser1) - assert.NoError(suite.T(), err) - assert.Equal(suite.T(), "integration test websocket user 1", eventUser1.Invoice.Description) - //read user 2 received msg - _, user2Msg, err := user2Ws.ReadMessage() - assert.NoError(suite.T(), err) - //assert it's their's - eventUser2 := ExpectedInvoiceEventWrapper{} - err = json.Unmarshal([]byte(user2Msg), &eventUser2) - assert.NoError(suite.T(), err) - assert.Equal(suite.T(), "integration test websocket user 2", eventUser2.Invoice.Description) - -} -func (suite *WebSocketTestSuite) TestWebSocketMissingInvoice() { - // create incoming invoice and fund account - invoice1 := suite.createAddInvoiceReq(1000, "integration test websocket missing invoices", suite.userToken) - err := suite.mlnd.mockPaidInvoice(invoice1, 0, false, nil) - assert.NoError(suite.T(), err) - - //create 2nd invoice and pay it as well - invoice2 := suite.createAddInvoiceReq(1000, "integration test websocket missing invoices 2nd", suite.userToken) - err = suite.mlnd.mockPaidInvoice(invoice2, 0, false, nil) - assert.NoError(suite.T(), err) - - //start listening to websocket after 2nd invoice has been paid - //we should get an event for the 2nd invoice if we specify the hash as the query parameter - ws, _, err := websocket.DefaultDialer.Dial(fmt.Sprintf("%s&since_payment_hash=%s", suite.wsUrl, invoice1.RHash), nil) - assert.NoError(suite.T(), err) - _, msg, err := ws.ReadMessage() - assert.NoError(suite.T(), err) - keepAlive := KeepAlive{} - err = json.Unmarshal([]byte(msg), &keepAlive) - assert.NoError(suite.T(), err) - assert.Equal(suite.T(), "keepalive", keepAlive.Type) - - _, msg, err = ws.ReadMessage() - assert.NoError(suite.T(), err) - event := ExpectedInvoiceEventWrapper{} - err = json.Unmarshal([]byte(msg), &event) - assert.NoError(suite.T(), err) - assert.Equal(suite.T(), event.Type, "invoice") - assert.Equal(suite.T(), int64(1000), event.Invoice.Amount) - assert.Equal(suite.T(), "integration test websocket missing invoices 2nd", event.Invoice.Description) -} - -func (suite *WebSocketTestSuite) TearDownSuite() { - suite.invoiceUpdateSubCancelFn() - suite.websocketServer.Close() - clearTable(suite.service, "invoices") -} - -func TestWebSocketSuite(t *testing.T) { - suite.Run(t, new(WebSocketTestSuite)) -} diff --git a/main.go b/main.go index 99b28aea..24d35742 100644 --- a/main.go +++ b/main.go @@ -16,7 +16,6 @@ import ( cache "github.com/SporkHubr/echo-http-cache" "github.com/SporkHubr/echo-http-cache/adapter/memory" - "github.com/getAlby/lndhub.go/controllers" "github.com/getAlby/lndhub.go/db" "github.com/getAlby/lndhub.go/db/migrations" "github.com/getAlby/lndhub.go/docs" @@ -194,10 +193,6 @@ func main() { RegisterLegacyEndpoints(svc, e, secured, securedWithStrictRateLimit, strictRateLimitMiddleware, tokens.AdminTokenMiddleware(c.AdminToken)) RegisterV2Endpoints(svc, e, secured, securedWithStrictRateLimit, strictRateLimitMiddleware, tokens.AdminTokenMiddleware(c.AdminToken)) - //invoice streaming - //Authentication should be done through the query param because this is a websocket - e.GET("/invoices/stream", controllers.NewInvoiceStreamController(svc).StreamInvoices) - //Swagger API spec docs.SwaggerInfo.Host = c.Host e.GET("/swagger/*", echoSwagger.WrapHandler)