Skip to content

Commit

Permalink
Merge pull request #6847 from The-K-R-O-K/illia-malachyn/6845-unify-s…
Browse files Browse the repository at this point in the history
…ubscription-and-message-id

[Access] Unify subscription id with client message id
  • Loading branch information
Guitarheroua authored Jan 21, 2025
2 parents fedc48a + fe70a58 commit a550d37
Show file tree
Hide file tree
Showing 33 changed files with 362 additions and 238 deletions.
111 changes: 63 additions & 48 deletions engine/access/rest/websockets/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,12 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"sync"
"time"

"golang.org/x/time/rate"

"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -129,7 +129,7 @@ type Controller struct {
// issues such as sending on a closed channel while maintaining proper cleanup.
multiplexedStream chan interface{}

dataProviders *concurrentmap.Map[uuid.UUID, dp.DataProvider]
dataProviders *concurrentmap.Map[SubscriptionID, dp.DataProvider]
dataProviderFactory dp.DataProviderFactory
dataProvidersGroup *sync.WaitGroup
limiter *rate.Limiter
Expand All @@ -146,7 +146,7 @@ func NewWebSocketController(
config: config,
conn: conn,
multiplexedStream: make(chan interface{}),
dataProviders: concurrentmap.New[uuid.UUID, dp.DataProvider](),
dataProviders: concurrentmap.New[SubscriptionID, dp.DataProvider](),
dataProviderFactory: dataProviderFactory,
dataProvidersGroup: &sync.WaitGroup{},
limiter: rate.NewLimiter(rate.Limit(config.MaxResponsesPerSecond), 1),
Expand Down Expand Up @@ -246,7 +246,7 @@ func (c *Controller) keepalive(ctx context.Context) error {
// If no messages are sent within InactivityTimeout and no active data providers exist,
// the connection will be closed.
func (c *Controller) writeMessages(ctx context.Context) error {
inactivityTicker := time.NewTicker(c.config.InactivityTimeout / 10)
inactivityTicker := time.NewTicker(c.inactivityTickerPeriod())
defer inactivityTicker.Stop()

lastMessageSentAt := time.Now()
Expand Down Expand Up @@ -301,6 +301,10 @@ func (c *Controller) writeMessages(ctx context.Context) error {
}
}

func (c *Controller) inactivityTickerPeriod() time.Duration {
return c.config.InactivityTimeout / 10
}

// readMessages continuously reads messages from a client WebSocket connection,
// validates each message, and processes it based on the message type.
func (c *Controller) readMessages(ctx context.Context) error {
Expand All @@ -314,7 +318,8 @@ func (c *Controller) readMessages(ctx context.Context) error {
c.writeErrorResponse(
ctx,
err,
wrapErrorMessage(InvalidMessage, "error reading message", "", "", ""))
wrapErrorMessage(http.StatusBadRequest, "error reading message", "", ""),
)
continue
}

Expand All @@ -323,7 +328,8 @@ func (c *Controller) readMessages(ctx context.Context) error {
c.writeErrorResponse(
ctx,
err,
wrapErrorMessage(InvalidMessage, "error parsing message", "", "", ""))
wrapErrorMessage(http.StatusBadRequest, "error parsing message", "", ""),
)
continue
}
}
Expand Down Expand Up @@ -366,24 +372,34 @@ func (c *Controller) handleMessage(ctx context.Context, message json.RawMessage)
}

func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMessageRequest) {
subscriptionID, err := c.parseOrCreateSubscriptionID(msg.SubscriptionID)
if err != nil {
c.writeErrorResponse(
ctx,
err,
wrapErrorMessage(http.StatusBadRequest, "error parsing subscription id",
models.SubscribeAction, msg.SubscriptionID),
)
return
}

// register new provider
provider, err := c.dataProviderFactory.NewDataProvider(ctx, msg.Topic, msg.Arguments, c.multiplexedStream)
provider, err := c.dataProviderFactory.NewDataProvider(ctx, subscriptionID.String(), msg.Topic, msg.Arguments, c.multiplexedStream)
if err != nil {
c.writeErrorResponse(
ctx,
err,
wrapErrorMessage(InvalidArgument, "error creating data provider", msg.ClientMessageID, models.SubscribeAction, ""),
wrapErrorMessage(http.StatusBadRequest, "error creating data provider",
models.SubscribeAction, subscriptionID.String()),
)
return
}
c.dataProviders.Add(provider.ID(), provider)
c.dataProviders.Add(subscriptionID, provider)

// write OK response to client
responseOk := models.SubscribeMessageResponse{
BaseMessageResponse: models.BaseMessageResponse{
ClientMessageID: msg.ClientMessageID,
Success: true,
SubscriptionID: provider.ID().String(),
SubscriptionID: subscriptionID.String(),
},
}
c.writeResponse(ctx, responseOk)
Expand All @@ -396,72 +412,63 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe
c.writeErrorResponse(
ctx,
err,
wrapErrorMessage(SubscriptionError, "subscription finished with error", "", "", ""),
wrapErrorMessage(http.StatusInternalServerError, "internal error",
models.SubscribeAction, subscriptionID.String()),
)
}

c.dataProvidersGroup.Done()
c.dataProviders.Remove(provider.ID())
c.dataProviders.Remove(subscriptionID)
}()
}

func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.UnsubscribeMessageRequest) {
id, err := uuid.Parse(msg.SubscriptionID)
subscriptionID, err := ParseClientSubscriptionID(msg.SubscriptionID)
if err != nil {
c.writeErrorResponse(
ctx,
err,
wrapErrorMessage(InvalidArgument, "error parsing subscription ID", msg.ClientMessageID, models.UnsubscribeAction, msg.SubscriptionID),
wrapErrorMessage(http.StatusBadRequest, "error parsing subscription id",
models.UnsubscribeAction, msg.SubscriptionID),
)
return
}

provider, ok := c.dataProviders.Get(id)
provider, ok := c.dataProviders.Get(subscriptionID)
if !ok {
c.writeErrorResponse(
ctx,
err,
wrapErrorMessage(NotFound, "subscription not found", msg.ClientMessageID, models.UnsubscribeAction, msg.SubscriptionID),
wrapErrorMessage(http.StatusNotFound, "subscription not found",
models.UnsubscribeAction, subscriptionID.String()),
)
return
}

provider.Close()
c.dataProviders.Remove(id)
c.dataProviders.Remove(subscriptionID)

responseOk := models.UnsubscribeMessageResponse{
BaseMessageResponse: models.BaseMessageResponse{
ClientMessageID: msg.ClientMessageID,
Success: true,
SubscriptionID: msg.SubscriptionID,
SubscriptionID: subscriptionID.String(),
},
}
c.writeResponse(ctx, responseOk)
}

func (c *Controller) handleListSubscriptions(ctx context.Context, msg models.ListSubscriptionsMessageRequest) {
func (c *Controller) handleListSubscriptions(ctx context.Context, _ models.ListSubscriptionsMessageRequest) {
var subs []*models.SubscriptionEntry
err := c.dataProviders.ForEach(func(id uuid.UUID, provider dp.DataProvider) error {
_ = c.dataProviders.ForEach(func(id SubscriptionID, provider dp.DataProvider) error {
subs = append(subs, &models.SubscriptionEntry{
ID: id.String(),
Topic: provider.Topic(),
SubscriptionID: id.String(),
Topic: provider.Topic(),
})
return nil
})

if err != nil {
c.writeErrorResponse(
ctx,
err,
wrapErrorMessage(NotFound, "error listing subscriptions", msg.ClientMessageID, models.ListSubscriptionsAction, ""),
)
return
}

responseOk := models.ListSubscriptionsMessageResponse{
Success: true,
ClientMessageID: msg.ClientMessageID,
Subscriptions: subs,
Subscriptions: subs,
Action: models.ListSubscriptionsAction,
}
c.writeResponse(ctx, responseOk)
}
Expand All @@ -472,13 +479,10 @@ func (c *Controller) shutdownConnection() {
c.logger.Debug().Err(err).Msg("error closing connection")
}

err = c.dataProviders.ForEach(func(_ uuid.UUID, provider dp.DataProvider) error {
_ = c.dataProviders.ForEach(func(_ SubscriptionID, provider dp.DataProvider) error {
provider.Close()
return nil
})
if err != nil {
c.logger.Debug().Err(err).Msg("error closing data provider")
}

c.dataProviders.Clear()
c.dataProvidersGroup.Wait()
Expand All @@ -498,15 +502,26 @@ func (c *Controller) writeResponse(ctx context.Context, response interface{}) {
}
}

func wrapErrorMessage(code Code, message string, msgId string, action string, subscriptionID string) models.BaseMessageResponse {
func wrapErrorMessage(code int, message string, action string, subscriptionID string) models.BaseMessageResponse {
return models.BaseMessageResponse{
ClientMessageID: msgId,
Success: false,
SubscriptionID: subscriptionID,
SubscriptionID: subscriptionID,
Error: models.ErrorMessage{
Code: int(code),
Code: code,
Message: message,
Action: action,
},
Action: action,
}
}

func (c *Controller) parseOrCreateSubscriptionID(id string) (SubscriptionID, error) {
newId, err := NewSubscriptionID(id)
if err != nil {
return SubscriptionID{}, err
}

if c.dataProviders.Has(newId) {
return SubscriptionID{}, fmt.Errorf("subscription ID is already in use: %s", newId)
}

return newId, nil
}
Loading

0 comments on commit a550d37

Please sign in to comment.