From 1fb9cfe07da6ce3fcadf145b5d16f17f7d33be3c Mon Sep 17 00:00:00 2001 From: Mike Mason Date: Fri, 4 Aug 2023 10:29:36 -0500 Subject: [PATCH] rewrite events handling to remove watermill (#130) As discussed in the community call, watermill doesn't give us the necessary features we'd like to utilize with the underlying nats message. We decided to switch to using nats directly but still wanted some support for possibly changing this out later. This rewrites events to use our own interfaces to allow for the possibility of a different event driver later. Additionally this switches to using pull subscriptions instead of push, supports Ack, Nak and Term as well as Request/Reply semantics. Due to the Request/Reply semantics, no longer are there separate Publisher and Subscriber configurations as the driver needs to be able to handle both. --------- Signed-off-by: Mike Mason --- entx/annotation.go | 8 +- entx/template/event_hooks.tmpl | 63 +++++-- events/config.go | 74 +++----- events/connection.go | 69 +++++++ events/errors.go | 32 ++++ events/message.go | 160 +++++++++++++++- events/nats.go | 109 ----------- events/nats_config.go | 193 +++++++++++++++++++ events/nats_connection.go | 150 +++++++++++++++ events/nats_errors.go | 14 ++ events/nats_message.go | 235 +++++++++++++++++++++++ events/nats_publish.go | 173 +++++++++++++++++ events/nats_subscribe.go | 180 ++++++++++++++++++ events/nats_test.go | 192 ++++++++++--------- events/publisher.go | 260 -------------------------- events/subscriber.go | 95 ---------- go.mod | 8 +- go.sum | 13 -- testing/eventtools/mock_connection.go | 72 +++++++ testing/eventtools/mock_message.go | 101 ++++++++++ testing/eventtools/nats.go | 22 +-- testing/eventtools/nats_test.go | 49 +++-- 22 files changed, 1584 insertions(+), 688 deletions(-) create mode 100644 events/connection.go create mode 100644 events/errors.go delete mode 100644 events/nats.go create mode 100644 events/nats_config.go create mode 100644 events/nats_connection.go create mode 100644 events/nats_errors.go create mode 100644 events/nats_message.go create mode 100644 events/nats_publish.go create mode 100644 events/nats_subscribe.go delete mode 100644 events/publisher.go delete mode 100644 events/subscriber.go create mode 100644 testing/eventtools/mock_connection.go create mode 100644 testing/eventtools/mock_message.go diff --git a/entx/annotation.go b/entx/annotation.go index fee7cf72..c352bd86 100644 --- a/entx/annotation.go +++ b/entx/annotation.go @@ -19,8 +19,8 @@ var EventsHookAnnotationName = "INFRA9_EVENTHOOKS" // EventsHookAnnotation provides a ent.Annotation spec. These shouldn't be set directly, you should use EventsHookAdditionalSubject() and EventsHookSubjectName instead type EventsHookAnnotation struct { - SubjectName string - IsAdditionalSubjectField bool + SubjectName string + AdditionalSubjectRelation string } // Name implements the ent Annotation interface. @@ -29,9 +29,9 @@ func (a EventsHookAnnotation) Name() string { } // EventsHookAdditionalSubject marks this field as a field to return as an additional subject -func EventsHookAdditionalSubject() *EventsHookAnnotation { +func EventsHookAdditionalSubject(relation string) *EventsHookAnnotation { return &EventsHookAnnotation{ - IsAdditionalSubjectField: true, + AdditionalSubjectRelation: relation, } } diff --git a/entx/template/event_hooks.tmpl b/entx/template/event_hooks.tmpl index df0bfc0e..20e985f1 100644 --- a/entx/template/event_hooks.tmpl +++ b/entx/template/event_hooks.tmpl @@ -7,6 +7,8 @@ {{ $genPackage := base $.Config.Package }} + import "go.infratographer.com/permissions-api/pkg/permissions" + {{- range $node := $.Nodes }} {{- if $nodeAnnotation := $node.Annotations.INFRA9_EVENTHOOKS }} {{- if ne $nodeAnnotation.SubjectName "" }} @@ -17,6 +19,7 @@ return hook.{{ $node.Name }}Func(func(ctx context.Context, m *generated.{{ $node.Name }}Mutation) (ent.Value, error) { var err error additionalSubjects := []gidx.PrefixedID{} + relationships := []events.AuthRelationshipRelation{} objID, ok := m.{{ $node.ID.MutationGet }}() if !ok { @@ -40,7 +43,7 @@ {{ $currentValue }} := "" {{ $f.Name }}, ok := m.{{ $f.MutationGet }}() {{- $annotation := $f.Annotations.INFRA9_EVENTHOOKS }} - {{- if $annotation.IsAdditionalSubjectField }} + {{- if $annotation.AdditionalSubjectRelation }} if !ok && !m.Op().Is(ent.OpCreate) { // since we are doing an update or delete and these fields didn't change, load the "old" value {{ $f.Name }}, err = m.{{ $f.MutationGetOld }}(ctx) @@ -51,9 +54,19 @@ {{- if $f.Optional }} if {{ $f.Name }} != gidx.NullPrefixedID { additionalSubjects = append(additionalSubjects, {{ $f.Name }}) + + relationships = append(relationships, events.AuthRelationshipRelation{ + Relation: "{{ $annotation.AdditionalSubjectRelation }}", + SubjectID: {{ $f.Name }}, + }) } {{- else }} additionalSubjects = append(additionalSubjects, {{ $f.Name }}) + + relationships = append(relationships, events.AuthRelationshipRelation{ + Relation: "{{ $annotation.AdditionalSubjectRelation }}", + SubjectID: {{ $f.Name }}, + }) {{- end }} {{ end }} @@ -99,13 +112,19 @@ } {{ end }} {{ end }} + + if len(relationships) != 0 { + if err := permissions.CreateAuthRelationships(ctx, "{{ $nodeAnnotation.SubjectName }}", objID, relationships...); err != nil { + return nil, fmt.Errorf("relationship request failed with error: %w", err) + } + } msg := events.ChangeMessage{ - EventType: eventType(m.Op()), - SubjectID: objID, - AdditionalSubjectIDs: additionalSubjects, - Timestamp: time.Now().UTC(), - FieldChanges: changeset, + EventType: eventType(m.Op()), + SubjectID: objID, + AdditionalSubjectIDs: additionalSubjects, + Timestamp: time.Now().UTC(), + FieldChanges: changeset, } // complete the mutation before we process the event @@ -114,7 +133,7 @@ return retValue, err } - if err := m.EventsPublisher.PublishChange(ctx, "{{ $nodeAnnotation.SubjectName }}", msg); err != nil { + if _, err := m.EventsPublisher.PublishChange(ctx, "{{ $nodeAnnotation.SubjectName }}", msg); err != nil { return nil, fmt.Errorf("failed to publish change: %w", err) } @@ -128,6 +147,7 @@ func(next ent.Mutator) ent.Mutator { return hook.{{ $node.Name }}Func(func(ctx context.Context, m *generated.{{ $node.Name }}Mutation) (ent.Value, error) { additionalSubjects := []gidx.PrefixedID{} + relationships := []events.AuthRelationshipRelation{} objID, ok := m.{{ $node.ID.MutationGet }}() if !ok { @@ -142,18 +162,34 @@ {{- range $f := $node.Fields }} {{- if not $f.Sensitive }} {{- $annotation := $f.Annotations.INFRA9_EVENTHOOKS }} - {{- if $annotation.IsAdditionalSubjectField }} + {{- if $annotation.AdditionalSubjectRelation }} {{- if $f.Optional }} if dbObj.{{ $f.MutationGet }} != gidx.NullPrefixedID { additionalSubjects = append(additionalSubjects, dbObj.{{ $f.MutationGet }}) + + relationships = append(relationships, events.AuthRelationshipRelation{ + Relation: "{{ $annotation.AdditionalSubjectRelation }}", + SubjectID: dbObj.{{ $f.MutationGet }}, + }) } {{- else }} additionalSubjects = append(additionalSubjects, dbObj.{{ $f.MutationGet }}) + + relationships = append(relationships, events.AuthRelationshipRelation{ + Relation: "{{ $annotation.AdditionalSubjectRelation }}", + SubjectID: dbObj.{{ $f.MutationGet }}, + }) {{- end }} {{ end }} {{ end }} {{ end }} + if len(relationships) != 0 { + if err := permissions.DeleteAuthRelationships(ctx, "{{ $nodeAnnotation.SubjectName }}", objID, relationships...); err != nil { + return nil, fmt.Errorf("relationship request failed with error: %w", err) + } + } + // we have all the info we need, now complete the mutation before we process the event retValue, err := next.Mutate(ctx, m) if err != nil { @@ -161,14 +197,13 @@ } msg := events.ChangeMessage{ - EventType: eventType(m.Op()), - SubjectID: objID, - AdditionalSubjectIDs: additionalSubjects, - Timestamp: time.Now().UTC(), + EventType: eventType(m.Op()), + SubjectID: objID, + AdditionalSubjectIDs: additionalSubjects, + Timestamp: time.Now().UTC(), } - - if err := m.EventsPublisher.PublishChange(ctx, "{{ $nodeAnnotation.SubjectName }}", msg); err != nil { + if _, err := m.EventsPublisher.PublishChange(ctx, "{{ $nodeAnnotation.SubjectName }}", msg); err != nil { return nil, fmt.Errorf("failed to publish change: %w", err) } diff --git a/events/config.go b/events/config.go index acef5be5..27cace29 100644 --- a/events/config.go +++ b/events/config.go @@ -19,62 +19,46 @@ import ( "github.com/spf13/pflag" "github.com/spf13/viper" - - "go.infratographer.com/x/viperx" + "go.uber.org/multierr" + "go.uber.org/zap" ) -var defaultTimeout = time.Second * 10 - -// PublisherConfig handles reading in all the config values available for setting up a pubsub publisher -type PublisherConfig struct { - URL string `mapstructure:"url"` - Timeout time.Duration `mapstructure:"timeout"` - Prefix string `mapstructure:"prefix"` - Source string `mapstructure:"source"` - NATSConfig NATSConfig `mapstructure:"nats"` -} +const ( + defaultTimeout = time.Second * 10 + tracerName = "go.infratographer.com/x/events" +) -// SubscriberConfig handles reading in all the config values available for setting up a pubsub publisher -type SubscriberConfig struct { - URL string `mapstructure:"url"` - Timeout time.Duration `mapstructure:"timeout"` - Prefix string `mapstructure:"prefix"` - QueueGroup string `mapstructure:"queueGroup"` - NATSConfig NATSConfig `mapstructure:"nats"` +// Config contains event provider configs. +type Config struct { + NATS NATSConfig `mapstructure:"nats"` } -// NATSConfig handles reading in all pubsub values specific to NATS -type NATSConfig struct { - Token string `mapstructure:"token"` - CredsFile string `mapstructure:"credsFile"` +// MustViperFlags returns the cobra flags and viper config for events. +func MustViperFlags(v *viper.Viper, flags *pflag.FlagSet, appName string) { + MustViperFlagsForNATS(v, flags, appName) } -// MustViperFlagsForPublisher returns the cobra flags and viper config for an event publisher -func MustViperFlagsForPublisher(v *viper.Viper, flags *pflag.FlagSet, appName string) { - flags.String("events-publisher-url", "nats://nats:4222", "nats server connection url") - viperx.MustBindFlag(v, "events.publisher.url", flags.Lookup("events-publisher-url")) +// Option configures a connection option. +type Option func(config *Config) error - v.MustBindEnv("events.publisher.timeout") - v.MustBindEnv("events.publisher.prefix") - v.MustBindEnv("events.publisher.source") - v.MustBindEnv("events.publisher.nats.token") - v.MustBindEnv("events.publisher.nats.credsFile") +// WithLogger sets the logger for the connection. +func WithLogger(logger *zap.SugaredLogger) Option { + return func(config *Config) error { + config.NATS.logger = logger - v.SetDefault("events.publisher.timeout", defaultTimeout) - v.SetDefault("events.publisher.source", appName) + return nil + } } -// MustViperFlagsForSubscriber returns the cobra flags and viper config for an event subscriber -func MustViperFlagsForSubscriber(v *viper.Viper, flags *pflag.FlagSet) { - flags.String("events-subscriber-url", "nats://nats:4222", "nats server connection url") - viperx.MustBindFlag(v, "events.subscriber.url", flags.Lookup("events-subscriber-url")) - flags.String("events-subscriber-queuegroup", "", "subscriber queue group") - viperx.MustBindFlag(v, "events.subscriber.queueGroup", flags.Lookup("events-subscriber-queuegroup")) +// WithNATSOptions configures nats options. +func WithNATSOptions(options ...NATSOption) Option { + return func(config *Config) error { + var err error - v.MustBindEnv("events.subscriber.timeout") - v.MustBindEnv("events.subscriber.prefix") - v.MustBindEnv("events.subscriber.nats.token") - v.MustBindEnv("events.subscriber.nats.credsFile") + for _, opt := range options { + err = multierr.Append(err, opt(&config.NATS)) + } - v.SetDefault("events.subscriber.timeout", defaultTimeout) + return err + } } diff --git a/events/connection.go b/events/connection.go new file mode 100644 index 00000000..e8dd8e93 --- /dev/null +++ b/events/connection.go @@ -0,0 +1,69 @@ +package events + +import ( + "context" + + "go.uber.org/multierr" +) + +// Connection defines a connection handler. +type Connection interface { + // Gracefully close the connection. + Shutdown(ctx context.Context) error + + // Source gives you the raw underlying connection object. + Source() any + + Subscriber + Publisher + + AuthRelationshipSubscriber + AuthRelationshipPublisher +} + +// Subscriber specifies subscriber methods. +type Subscriber interface { + // SubscribeChanges subscribes to the provided topic responding with an ChangeMessage message. + SubscribeChanges(ctx context.Context, topic string) (<-chan Message[ChangeMessage], error) + // SubscribeEvents subscribes to the provided topic responding with an EventMessage message. + SubscribeEvents(ctx context.Context, topic string) (<-chan Message[EventMessage], error) +} + +// Publisher specifies publisher methods. +type Publisher interface { + // PublishChange publishes to the specified topic with the message given. + PublishChange(ctx context.Context, topic string, message ChangeMessage) (Message[ChangeMessage], error) + // PublishEvent publishes to the specified topic with the message given. + PublishEvent(ctx context.Context, topic string, message EventMessage) (Message[EventMessage], error) +} + +// AuthRelationshipSubscriber specifies the auth relationship subscriber methods. +type AuthRelationshipSubscriber interface { + // SubscribeAuthRelationshipRequests subscribes to the provided topic responding with an AuthRelationshipRequest message. + SubscribeAuthRelationshipRequests(ctx context.Context, topic string) (<-chan Request[AuthRelationshipRequest, AuthRelationshipResponse], error) +} + +// AuthRelationshipPublisher specifies the auth relationship publisher methods. +type AuthRelationshipPublisher interface { + // PublishAuthRelationshipRequest publishes to the specified topic with the message given. + PublishAuthRelationshipRequest(ctx context.Context, topic string, message AuthRelationshipRequest) (Message[AuthRelationshipResponse], error) +} + +// NewConnection creates a new Connection from the provided config. +func NewConnection(config Config, options ...Option) (Connection, error) { + var err error + + for _, opt := range options { + err = multierr.Append(err, opt(&config)) + } + + if err != nil { + return nil, err + } + + if config.NATS.Configured() { + return NewNATSConnection(config.NATS) + } + + return nil, ErrProviderNotConfigured +} diff --git a/events/errors.go b/events/errors.go new file mode 100644 index 00000000..591ee4f7 --- /dev/null +++ b/events/errors.go @@ -0,0 +1,32 @@ +package events + +import "errors" + +var ( + // ErrProviderNotConfigured is an error packages should return if no events provider is configured. + ErrProviderNotConfigured = errors.New("events provider not configured") + + // ErrMissingChangeMessageEventType is returned when the event message has the incorrect field EventType value. + ErrMissingChangeMessageEventType = errors.New("change message EventType field required") + // ErrMissingChangeMessageSubjectID is returned when the event message has the incorrect field SubjectID value. + ErrMissingChangeMessageSubjectID = errors.New("change message SubjectID field required") + + // ErrMissingEventMessageEventType is returned when the event message has the incorrect field EventType value. + ErrMissingEventMessageEventType = errors.New("event message EventType field required") + // ErrMissingEventMessageSubjectID is returned when the event message has the incorrect field SubjectID value. + ErrMissingEventMessageSubjectID = errors.New("event message SubjectID field required") + + // ErrInvalidAuthRelationshipRequestAction is returned when the event message has the incorrect field Action value. + ErrInvalidAuthRelationshipRequestAction = errors.New("auth relationship request message Action field must be write or delete") + // ErrMissingAuthRelationshipRequestObjectID is returned when the event message has the incorrect field ObjectID value. + ErrMissingAuthRelationshipRequestObjectID = errors.New("auth relationship request message ObjectID field required") + // ErrMissingAuthRelationshipRequestRelation is returned when the event message has no relations defined. + ErrMissingAuthRelationshipRequestRelation = errors.New("auth relationship request message Relations field required") + // ErrMissingAuthRelationshipRequestRelationRelation is returned when the event message Relations has the incorrect field for Relation value. + ErrMissingAuthRelationshipRequestRelationRelation = errors.New("auth relationship request message Relations Relation field required") + // ErrMissingAuthRelationshipRequestRelationSubjectID is returned when the event message Relations has the incorrect field SubjectID value. + ErrMissingAuthRelationshipRequestRelationSubjectID = errors.New("auth relationship request message Relations SubjectID field required") + + // ErrRequestNoResponders is returned when a request is attempted but no responder is listening. + ErrRequestNoResponders = errors.New("no responders for request") +) diff --git a/events/message.go b/events/message.go index 324f4a00..15a7be00 100644 --- a/events/message.go +++ b/events/message.go @@ -16,12 +16,55 @@ package events import ( + "context" "encoding/json" + "fmt" "time" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/propagation" + "go.uber.org/multierr" + "go.infratographer.com/x/gidx" ) +// Message contains a message which has been published or received from a subscription. +type Message[T any] interface { + // Connection returns the underlying connection the message was received on. + Connection() Connection + + // ID returns the unique message id. + ID() string + // Topic returns the topic the message was sent to. + Topic() string + // Message returns the decoded message object. + Message() T + // Ack acks the message. + Ack() error + // Nak nacks the message. + Nak(delay time.Duration) error + // Term terminates the message. + Term() error + // Timestamp returns the time the message was submitted. + Timestamp() time.Time + // Deliveries returns the number of times the message was delivered. + Deliveries() uint64 + + // Error returns any error encountered while decoding the message + Error() error + + // Source returns the underlying message object. + Source() any +} + +// Request extends Message by allowing replies to be sent for the received message. +type Request[TRequest, TResponse any] interface { + Message[TRequest] + + // Reply publishes a response to the received message. + Reply(ctx context.Context, message TResponse) (Message[TResponse], error) +} + // ChangeType represents the possible event types for a ChangeMessage type ChangeType string @@ -82,6 +125,28 @@ type ChangeMessage struct { AdditionalData map[string]interface{} `json:"additionalData"` } +// GetTraceContext creates a new OpenTelementry context for the message. +func (m ChangeMessage) GetTraceContext(ctx context.Context) context.Context { + tp := otel.GetTextMapPropagator() + + return tp.Extract(ctx, propagation.MapCarrier(m.TraceContext)) +} + +// Validate ensures the message has all the required fields. +func (m ChangeMessage) Validate() error { + var err error + + if m.SubjectID == "" { + err = multierr.Append(err, ErrMissingChangeMessageSubjectID) + } + + if m.EventType == "" { + err = multierr.Append(err, ErrMissingChangeMessageEventType) + } + + return err +} + // EventMessage contains the data structure expected to be received when picking // an event from an events message queue type EventMessage struct { @@ -107,6 +172,28 @@ type EventMessage struct { Data map[string]interface{} `json:"data"` } +// GetTraceContext creates a new OpenTelementry context for the message. +func (m EventMessage) GetTraceContext(ctx context.Context) context.Context { + tp := otel.GetTextMapPropagator() + + return tp.Extract(ctx, propagation.MapCarrier(m.TraceContext)) +} + +// Validate ensures the message has all the required fields. +func (m EventMessage) Validate() error { + var err error + + if m.SubjectID == "" { + err = multierr.Append(err, ErrMissingEventMessageSubjectID) + } + + if m.EventType == "" { + err = multierr.Append(err, ErrMissingEventMessageEventType) + } + + return err +} + // AuthRelationshipRequest contains the data structure expected to be used to write or delete // an auth relationship from PermissionsAPI type AuthRelationshipRequest struct { @@ -114,10 +201,8 @@ type AuthRelationshipRequest struct { Action AuthRelationshipAction `json:"action"` // ObjectID is the PrefixedID of the object the permissions will be granted on ObjectID gidx.PrefixedID `json:"objectID"` - // RelationshipName is the relationship being created on the object for the subject - RelationshipName string `json:"relationshipName"` - // SubjectID is the PrefixedID of the object the permissions apply to - SubjectID gidx.PrefixedID `json:"subjectID"` + // Relations defines all relations which should be written or deleted for this object. + Relations []AuthRelationshipRelation `json:"relations"` // ConditionName represents the name of a conditional check that will be applied to this relationship. (Optional) // In SpiceDB this would be a caveat name ConditionName string `json:"conditionName"` @@ -133,6 +218,61 @@ type AuthRelationshipRequest struct { SpanID string `json:"spanID"` } +// GetTraceContext creates a new OpenTelementry context for the message. +func (m AuthRelationshipRequest) GetTraceContext(ctx context.Context) context.Context { + tp := otel.GetTextMapPropagator() + + return tp.Extract(ctx, propagation.MapCarrier(m.TraceContext)) +} + +// Validate ensures the message has all the required fields. +func (m AuthRelationshipRequest) Validate() error { + var err error + + if m.Action == "" || m.Action != WriteAuthRelationshipAction && m.Action != DeleteAuthRelationshipAction { + err = multierr.Append(err, ErrInvalidAuthRelationshipRequestAction) + } + + if m.ObjectID == "" { + err = multierr.Append(err, ErrMissingAuthRelationshipRequestObjectID) + } + + if len(m.Relations) == 0 { + err = multierr.Append(err, ErrMissingAuthRelationshipRequestRelation) + } + + for i, rel := range m.Relations { + if rErr := rel.Validate(); rErr != nil { + err = multierr.Append(err, fmt.Errorf("%w: relation %d", rErr, i)) + } + } + + return err +} + +// AuthRelationshipRelation defines the relation an object from an AuthRelationshipRequest has to a subject. +type AuthRelationshipRelation struct { + // Relation is the name of the relation the object from AuthRelationshipRequest has to the subject. + Relation string `json:"relation"` + // The subject the relation is to. + SubjectID gidx.PrefixedID `json:"subjectID"` +} + +// Validate ensures the message has all the required fields. +func (r AuthRelationshipRelation) Validate() error { + var err error + + if r.Relation == "" { + err = multierr.Append(err, ErrMissingAuthRelationshipRequestRelationRelation) + } + + if r.SubjectID == "" { + err = multierr.Append(err, ErrMissingAuthRelationshipRequestRelationSubjectID) + } + + return err +} + // AuthRelationshipResponse contains the data structure expected to be received from an AuthRelationshipRequest // message to write or delete an auth relationship from PermissionsAPI type AuthRelationshipResponse struct { @@ -148,6 +288,18 @@ type AuthRelationshipResponse struct { SpanID string `json:"spanID"` } +// GetTraceContext creates a new OpenTelementry context for the message. +func (m AuthRelationshipResponse) GetTraceContext(ctx context.Context) context.Context { + tp := otel.GetTextMapPropagator() + + return tp.Extract(ctx, propagation.MapCarrier(m.TraceContext)) +} + +// Validate ensures the message has all the required fields. +func (m AuthRelationshipResponse) Validate() error { + return nil +} + // UnmarshalChangeMessage returns a ChangeMessage from a json []byte. func UnmarshalChangeMessage(b []byte) (ChangeMessage, error) { var c ChangeMessage diff --git a/events/nats.go b/events/nats.go deleted file mode 100644 index 5a851890..00000000 --- a/events/nats.go +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2023 The Infratographer Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package events - -import ( - "crypto/md5" - "encoding/hex" - - "github.com/ThreeDotsLabs/watermill-nats/v2/pkg/nats" - "github.com/ThreeDotsLabs/watermill/message" - "github.com/garsue/watermillzap" - nc "github.com/nats-io/nats.go" - "go.uber.org/zap" -) - -var natsMarshaler = &nats.JSONMarshaler{} - -func newNATSPublisher(cfg PublisherConfig, logger *zap.SugaredLogger) (message.Publisher, error) { - logAdapter := watermillzap.NewLogger(logger.Desugar()) - - options := []nc.Option{ - nc.Timeout(cfg.Timeout), - } - - switch { - case cfg.NATSConfig.CredsFile != "": - options = append(options, nc.UserCredentials(cfg.NATSConfig.CredsFile)) - case cfg.NATSConfig.Token != "": - options = append(options, nc.Token(cfg.NATSConfig.Token)) - } - - jsConfig := nats.JetStreamConfig{ - Disabled: false, - AutoProvision: false, - ConnectOptions: nil, - PublishOptions: nil, - TrackMsgId: false, - AckAsync: false, - DurablePrefix: "", - } - - return nats.NewPublisher( - nats.PublisherConfig{ - URL: cfg.URL, - NatsOptions: options, - Marshaler: natsMarshaler, - JetStream: jsConfig, - }, - logAdapter, - ) -} - -func newNATSSubscriber(cfg SubscriberConfig, logger *zap.SugaredLogger, subOpts ...nc.SubOpt) (message.Subscriber, error) { - logAdapter := watermillzap.NewLogger(logger.Desugar()) - - options := []nc.Option{ - nc.Timeout(cfg.Timeout), - } - - switch { - case cfg.NATSConfig.CredsFile != "": - options = append(options, nc.UserCredentials(cfg.NATSConfig.CredsFile)) - case cfg.NATSConfig.Token != "": - options = append(options, nc.Token(cfg.NATSConfig.Token)) - } - - jsConfig := nats.JetStreamConfig{ - Disabled: false, - AutoProvision: false, - ConnectOptions: nil, - PublishOptions: nil, - SubscribeOptions: subOpts, - TrackMsgId: false, - AckAsync: false, - DurablePrefix: "", - } - - if cfg.QueueGroup != "" { - jsConfig.DurableCalculator = func(_ string, topic string) string { - hash := md5.Sum([]byte(topic)) - return cfg.QueueGroup + hex.EncodeToString(hash[:]) - } - } - - sub, err := nats.NewSubscriber( - nats.SubscriberConfig{ - URL: cfg.URL, - NatsOptions: options, - Unmarshaler: natsMarshaler, - JetStream: jsConfig, - QueueGroupPrefix: cfg.QueueGroup, - }, - logAdapter, - ) - - return sub, err -} diff --git a/events/nats_config.go b/events/nats_config.go new file mode 100644 index 00000000..d07ce7ce --- /dev/null +++ b/events/nats_config.go @@ -0,0 +1,193 @@ +package events + +import ( + "time" + + "github.com/nats-io/nats.go" + "github.com/spf13/pflag" + "github.com/spf13/viper" + "go.uber.org/multierr" + "go.uber.org/zap" + + "go.infratographer.com/x/viperx" +) + +var ( + // NATSDefaultConnectTimeout is the default connection timeout. + NATSDefaultConnectTimeout = 5 * time.Second + // NATSDefaultSubscriberFetchBatchSize is the default pull subscribe batch size. + NATSDefaultSubscriberFetchBatchSize = 20 + // NATSDefaultSubscriberFetchTimeout is the max time a fetch will block before releasing. + NATSDefaultSubscriberFetchTimeout = 5 * time.Second + // NATSDefaultSubscriberFetchBackoff is the delay between a batch attempts. + NATSDefaultSubscriberFetchBackoff = 5 * time.Second + // NATSDefaultShutdownTimeout is the timeout for a shutdown to complete. + NATSDefaultShutdownTimeout = 5 * time.Second +) + +// NATSConfig defines the NATS connection configuration. +type NATSConfig struct { + URL string + SubscribePrefix string + PublishPrefix string + QueueGroup string + Token string + CredsFile string + Source string + + ConnectTimeout time.Duration + ShutdownTimeout time.Duration + SubscriberFetchBatchSize int + SubscriberFetchTimeout time.Duration + SubscriberFetchBackoff time.Duration + SubscriberNoAckExplicit bool + SubscriberNoManualAck bool + + SubscriberDeliveryPolicy string + SubscriberStartSequence uint64 + + logger *zap.SugaredLogger + connectOptions []nats.Option + jetStreamOptions []nats.JSOpt + subscribeOptions []nats.SubOpt +} + +// Configured checks whether the provider has been configured. +func (c NATSConfig) Configured() bool { + return c.URL != "" || c.QueueGroup != "" +} + +// Validate ensures the configuration is valid. +func (c NATSConfig) Validate() error { + var err error + + if c.Token != "" && c.CredsFile != "" { + err = multierr.Append(err, ErrNATSInvalidAuthConfiguration) + } + + switch c.SubscriberDeliveryPolicy { + case "", "all", "start-sequence": + default: + err = multierr.Append(err, ErrNATSInvalidDeliveryPolicy) + } + + return err +} + +// WithDefaults sets default values for the field unset. +func (c NATSConfig) WithDefaults() NATSConfig { + if c.logger == nil { + c.logger = zap.NewNop().Sugar() + } + + if c.SubscriberFetchBatchSize == 0 { + c.SubscriberFetchBatchSize = NATSDefaultSubscriberFetchBatchSize + } + + if c.SubscriberFetchTimeout == 0 { + c.SubscriberFetchTimeout = NATSDefaultSubscriberFetchTimeout + } + + if c.SubscriberFetchBackoff == 0 { + c.SubscriberFetchBackoff = NATSDefaultSubscriberFetchBackoff + } + + if !c.SubscriberNoAckExplicit { + c.subscribeOptions = append(c.subscribeOptions, nats.AckExplicit()) + } + + if !c.SubscriberNoManualAck { + c.subscribeOptions = append(c.subscribeOptions, nats.ManualAck()) + } + + switch c.SubscriberDeliveryPolicy { + case "start-sequence": + c.subscribeOptions = append(c.subscribeOptions, nats.StartSequence(c.SubscriberStartSequence)) + default: + c.subscribeOptions = append(c.subscribeOptions, nats.DeliverAll()) + } + + if c.ShutdownTimeout == 0 { + c.ShutdownTimeout = NATSDefaultShutdownTimeout + } + + if c.ConnectTimeout == 0 { + c.ConnectTimeout = NATSDefaultConnectTimeout + } + + c.connectOptions = append(c.connectOptions, nats.Timeout(c.ConnectTimeout)) + + if c.Token != "" { + c.connectOptions = append(c.connectOptions, nats.Token(c.Token)) + } + + if c.CredsFile != "" { + c.connectOptions = append(c.connectOptions, nats.UserCredentials(c.CredsFile)) + } + + return c +} + +// NATSOption defines a nats configuration option. +type NATSOption func(c *NATSConfig) error + +// WithNATSLogger sets the logger for the nats connection. +func WithNATSLogger(logger *zap.SugaredLogger) NATSOption { + return func(c *NATSConfig) error { + c.logger = logger + + return nil + } +} + +// WithNATSConnectOptions configures the connection options for nats. +func WithNATSConnectOptions(options ...nats.Option) NATSOption { + return func(c *NATSConfig) error { + c.connectOptions = append(c.connectOptions, options...) + + return nil + } +} + +// WithNATSJetStreamOptions configures the jetstream connection options. +func WithNATSJetStreamOptions(options ...nats.JSOpt) NATSOption { + return func(c *NATSConfig) error { + c.jetStreamOptions = append(c.jetStreamOptions, options...) + + return nil + } +} + +// WithNATSSubscribeOptions configures the subscribe options for new subscriptions. +func WithNATSSubscribeOptions(options ...nats.SubOpt) NATSOption { + return func(c *NATSConfig) error { + c.subscribeOptions = append(c.subscribeOptions, options...) + + return nil + } +} + +// MustViperFlagsForNATS returns the cobra flags and viper config for a nats handler. +func MustViperFlagsForNATS(v *viper.Viper, flags *pflag.FlagSet, appName string) { + flags.String("events-nats-url", "nats://nats:4222", "nats server connection url") + viperx.MustBindFlag(v, "events.nats.url", flags.Lookup("events-nats-url")) + + v.MustBindEnv("events.nats.subscribePrefix") + v.MustBindEnv("events.nats.publishPrefix") + v.MustBindEnv("events.nats.queueGroup") + v.MustBindEnv("events.nats.token") + v.MustBindEnv("events.nats.credsFile") + v.MustBindEnv("events.nats.source") + v.MustBindEnv("events.nats.connectTimeout") + v.MustBindEnv("events.nats.shutdownTimeout") + v.MustBindEnv("events.nats.subscriberFetchBatchSize") + v.MustBindEnv("events.nats.subscriberFetchTimeout") + v.MustBindEnv("events.nats.subscriberFetchBackoff") + v.MustBindEnv("events.nats.subscriberNoAckExplicit") + v.MustBindEnv("events.nats.subscriberNoManualAck") + v.MustBindEnv("events.nats.subscriberDeliveryPolicy") + v.MustBindEnv("events.nats.subscriberStartSequence") + + v.SetDefault("events.nats.connectTimeout", defaultTimeout) + v.SetDefault("events.nats.source", appName) +} diff --git a/events/nats_connection.go b/events/nats_connection.go new file mode 100644 index 00000000..da0c1edf --- /dev/null +++ b/events/nats_connection.go @@ -0,0 +1,150 @@ +package events + +import ( + "context" + "crypto/md5" + "encoding/hex" + "encoding/json" + "strings" + + "github.com/nats-io/nats.go" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" +) + +const ( + base10 = 10 + + natsTracerName = tracerName + ":nats" +) + +var _ Connection = (*NATSConnection)(nil) + +// NATSConnection implements Connection. +type NATSConnection struct { + logger *zap.SugaredLogger + tracer trace.Tracer + conn *nats.Conn + jetstream nats.JetStreamContext + cfg NATSConfig +} + +// Shutdown gracefully drains the connection. +func (c *NATSConnection) Shutdown(ctx context.Context) error { + ctx, cancelTimeout := context.WithTimeout(ctx, c.cfg.ShutdownTimeout) + ctx, cancel := context.WithCancelCause(ctx) + + defer cancelTimeout() + + closedCB := c.conn.Opts.ClosedCB + + c.conn.Opts.ClosedCB = func(c *nats.Conn) { + defer cancel(nil) + + if closedCB != nil { + closedCB(c) + } + } + + if err := c.conn.Drain(); err != nil { + cancel(err) + + return err + } + + <-ctx.Done() + + return ctx.Err() +} + +// Source returns the underlying NATS Connection. +func (c *NATSConnection) Source() any { + return c.conn +} + +func (c *NATSConnection) durableName(topic string) string { + return NATSConsumerDurableName(c.cfg.QueueGroup, topic) +} + +func (c *NATSConnection) buildSubscribeSubject(parts ...string) string { + var subjectParts []string + + if c.cfg.SubscribePrefix != "" { + subjectParts = append(subjectParts, c.cfg.SubscribePrefix) + } + + subjectParts = append(subjectParts, parts...) + + return strings.Join(subjectParts, ".") +} + +func (c *NATSConnection) buildPublishSubject(parts ...string) string { + var subjectParts []string + + if c.cfg.PublishPrefix != "" { + subjectParts = append(subjectParts, c.cfg.PublishPrefix) + } + + subjectParts = append(subjectParts, parts...) + + return strings.Join(subjectParts, ".") +} + +func newNATSMessage[T any](conn *NATSConnection, subject string, message T) (*NATSMessage[T], error) { + data, err := json.Marshal(message) + if err != nil { + return nil, err + } + + return &NATSMessage[T]{ + conn: conn, + source: &nats.Msg{ + Subject: subject, + Data: data, + }, + message: message, + }, nil +} + +// NewNATSConnection creates a new nats jetstream connection. +func NewNATSConnection(config NATSConfig, options ...NATSOption) (*NATSConnection, error) { + nc := config.WithDefaults() + + if err := nc.Validate(); err != nil { + return nil, err + } + + for _, opt := range options { + if err := opt(&nc); err != nil { + return nil, err + } + } + + conn, err := nats.Connect(config.URL, nc.connectOptions...) + if err != nil { + return nil, err + } + + js, err := conn.JetStream(nc.jetStreamOptions...) + if err != nil { + conn.Close() + + return nil, err + } + + return &NATSConnection{ + logger: nc.logger, + tracer: otel.GetTracerProvider().Tracer(natsTracerName), + conn: conn, + jetstream: js, + cfg: nc, + }, nil +} + +// NATSConsumerDurableName is the generator function to create a new durable consumer name. +func NATSConsumerDurableName(queueGroup, subject string) string { + hash := md5.Sum([]byte(subject)) + + return queueGroup + hex.EncodeToString(hash[:]) +} diff --git a/events/nats_errors.go b/events/nats_errors.go new file mode 100644 index 00000000..c59051f0 --- /dev/null +++ b/events/nats_errors.go @@ -0,0 +1,14 @@ +package events + +import "errors" + +var ( + // ErrNATSInvalidAuthConfiguration is returned when the config has both Tokena nd CredsFile specified. + ErrNATSInvalidAuthConfiguration = errors.New("invalid nats confinguration, both token and creds file are specified") + + // ErrNATSInvalidDeliveryPolicy is returned when an incorrect delivery policy is provided. + ErrNATSInvalidDeliveryPolicy = errors.New("invalid delivery policy") + + // ErrNATSMessageNoReplySubject is returned when calling ReplyAuthRelationshipRequest when the request has no reply subject defined. + ErrNATSMessageNoReplySubject = errors.New("unable to reply to auth relationship request, no reply subject specified") +) diff --git a/events/nats_message.go b/events/nats_message.go new file mode 100644 index 00000000..cf086c31 --- /dev/null +++ b/events/nats_message.go @@ -0,0 +1,235 @@ +package events + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strconv" + "time" + + "github.com/nats-io/nats.go" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/propagation" +) + +func natsSubscriptionMessageChan[T any](ctx context.Context, conn *NATSConnection, batchSize int, natsCh <-chan *nats.Msg) chan Message[T] { + msgCh := make(chan Message[T], batchSize) + + go func() { + defer close(msgCh) + + for nMsg := range natsCh { + msg := natsDecodeMessage[T](conn, nMsg) + + select { + case msgCh <- msg: + case <-ctx.Done(): + return + } + } + }() + + return msgCh +} + +func natsSubscriptionAuthRelationshipRequestChan(ctx context.Context, conn *NATSConnection, batchSize int, natsCh <-chan *nats.Msg) chan Request[AuthRelationshipRequest, AuthRelationshipResponse] { + msgCh := make(chan Request[AuthRelationshipRequest, AuthRelationshipResponse], batchSize) + + go func() { + defer close(msgCh) + + for nMsg := range natsCh { + msg := natsDecodeMessage[AuthRelationshipRequest](conn, nMsg) + + req := &NATSAuthRelationshipRequest{ + NATSMessage: msg.(*NATSMessage[AuthRelationshipRequest]), + } + + select { + case msgCh <- req: + case <-ctx.Done(): + return + } + } + }() + + return msgCh +} + +func natsDecodeMessage[T any](conn *NATSConnection, nMsg *nats.Msg) Message[T] { + msg := &NATSMessage[T]{ + conn: conn, + source: nMsg, + } + + if err := json.Unmarshal(nMsg.Data, &msg.message); err != nil { + msg.err = err + } + + return msg +} + +var _ Message[any] = (*NATSMessage[any])(nil) + +// NATSMessage implements Message +type NATSMessage[T any] struct { + conn *NATSConnection + source *nats.Msg + sourceMetadata *nats.MsgMetadata + message T + err error +} + +// Connection returns the underlying Connection. +func (m *NATSMessage[T]) Connection() Connection { + return m.conn +} + +func (m *NATSMessage[T]) metadata() nats.MsgMetadata { + if m.sourceMetadata != nil { + return *m.sourceMetadata + } + + metadata, err := m.source.Metadata() + if err != nil { + m.conn.logger.Errorw("failed to load metadata for nats message", "nats.subject", m.source.Subject) + + return nats.MsgMetadata{} + } + + m.sourceMetadata = metadata + + return *m.sourceMetadata +} + +// ID returns the nats message sequence number for the consumer. +func (m *NATSMessage[T]) ID() string { + return strconv.FormatUint(m.metadata().Sequence.Consumer, base10) +} + +// Topic returns the nats subject. +func (m *NATSMessage[T]) Topic() string { + return m.source.Subject +} + +// Message returns the decoded message object. +func (m *NATSMessage[T]) Message() T { + return m.message +} + +// Ack acks the message. +func (m *NATSMessage[T]) Ack() error { + return m.source.Ack() +} + +// Nak calls a Nak with the provided delay. +func (m *NATSMessage[T]) Nak(delay time.Duration) error { + return m.source.NakWithDelay(delay) +} + +// Term terminates the message from being processed again. +func (m *NATSMessage[T]) Term() error { + return m.source.Term() +} + +// Timestamp returns the timestamp of the message. +func (m *NATSMessage[T]) Timestamp() time.Time { + return m.metadata().Timestamp +} + +// Deliveries returns the number of times the message was delivered. +func (m *NATSMessage[T]) Deliveries() uint64 { + return m.metadata().NumDelivered +} + +// Error returns any error with the message. +func (m *NATSMessage[T]) Error() error { + if m.err != nil { + return m.err + } + + return nil +} + +// Source returns the underlying nats message. +func (m *NATSMessage[T]) Source() any { + return m.source +} + +func (m *NATSMessage[T]) publish() error { + return m.conn.conn.PublishMsg(m.source) +} + +func (m *NATSMessage[T]) request(ctx context.Context) (Message[AuthRelationshipResponse], error) { + if m.source.Reply == "" { + m.source.Reply = m.conn.conn.NewRespInbox() + } + + nMsg, err := m.conn.conn.RequestMsgWithContext(ctx, m.source) + if err != nil { + // ensure we wrap no responder errors with ErrRequestNoResponders. + if errors.Is(err, nats.ErrNoResponders) { + return nil, fmt.Errorf("%w: %w", ErrRequestNoResponders, err) + } + + return nil, err + } + + respMsg := natsDecodeMessage[AuthRelationshipResponse](m.conn, nMsg) + + return respMsg, nil +} + +var _ Request[AuthRelationshipRequest, AuthRelationshipResponse] = (*NATSAuthRelationshipRequest)(nil) + +// NATSAuthRelationshipRequest implements Request for AuthRelationshipRequest / AuthRelationshipResponse +type NATSAuthRelationshipRequest struct { + *NATSMessage[AuthRelationshipRequest] +} + +// Reply responds to an AuthRelationshipRequest with an AuthRelationshipResponse. +func (r *NATSAuthRelationshipRequest) Reply(ctx context.Context, message AuthRelationshipResponse) (Message[AuthRelationshipResponse], error) { + ctx, span := r.conn.tracer.Start(ctx, "events.Reply") + + defer span.End() + + if r.source.Reply == "" { + span.RecordError(ErrNATSMessageNoReplySubject) + span.SetStatus(codes.Error, ErrNATSMessageNoReplySubject.Error()) + + return nil, ErrNATSMessageNoReplySubject + } + + if err := message.Validate(); err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + + return nil, err + } + + // Propagate trace context into the message for the subscriber + var mapCarrier propagation.MapCarrier = make(map[string]string) + + otel.GetTextMapPropagator().Inject(ctx, mapCarrier) + + message.TraceContext = mapCarrier + + respMsg, err := newNATSMessage(r.conn, r.source.Reply, message) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + + return nil, err + } + + if err := r.source.RespondMsg(respMsg.source); err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + + return respMsg, err + } + + return respMsg, nil +} diff --git a/events/nats_publish.go b/events/nats_publish.go new file mode 100644 index 00000000..77191866 --- /dev/null +++ b/events/nats_publish.go @@ -0,0 +1,173 @@ +package events + +import ( + "context" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/propagation" + "go.opentelemetry.io/otel/trace" + + "go.infratographer.com/x/echojwtx" + "go.infratographer.com/x/gidx" +) + +// PublishAuthRelationshipRequest publishes an AuthRelationshipRequest message and blocks until an AuthRelationshipResponse is provided. +func (c *NATSConnection) PublishAuthRelationshipRequest(ctx context.Context, topic string, message AuthRelationshipRequest) (Message[AuthRelationshipResponse], error) { + ctx, span := c.tracer.Start(ctx, "events.nats.PublishAuthRelationshipRequest", trace.WithAttributes( + attribute.String("events.subject_type", topic), + attribute.String("events.subject_id", message.ObjectID.String()), + attribute.String("events.event_type", string(message.Action)), + )) + + defer span.End() + + if err := message.Validate(); err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + + return nil, err + } + + // Propagate trace context into the message for the subscriber + var mapCarrier propagation.MapCarrier = make(map[string]string) + + otel.GetTextMapPropagator().Inject(ctx, mapCarrier) + + message.TraceContext = mapCarrier + + topic = c.buildPublishSubject("auth", "relationships", string(message.Action), topic) + + reqMsg, err := newNATSMessage(c, topic, message) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + + return nil, err + } + + c.logger.Debugf("publishing auth relation request message to topic %s", topic) + + respMsg, err := reqMsg.request(ctx) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + + return nil, err + } + + return respMsg, nil +} + +// PublishChange publishes a ChangeMessage. +func (c *NATSConnection) PublishChange(ctx context.Context, topic string, message ChangeMessage) (Message[ChangeMessage], error) { + ctx, span := c.tracer.Start(ctx, "events.nats.PublishChange", trace.WithAttributes( + attribute.String("events.subject_type", topic), + attribute.String("events.subject_id", message.SubjectID.String()), + attribute.String("events.event_type", message.EventType), + attribute.String("events.source", message.Source), + )) + + defer span.End() + + if err := message.Validate(); err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + + return nil, err + } + + // Propagate trace context into the message for the subscriber + var mapCarrier propagation.MapCarrier = make(map[string]string) + + otel.GetTextMapPropagator().Inject(ctx, mapCarrier) + + message.TraceContext = mapCarrier + + topic = c.buildPublishSubject("changes", message.EventType, topic) + + message.Source = c.cfg.Source + + if message.ActorID == gidx.NullPrefixedID { + id, ok := ctx.Value(echojwtx.ActorCtxKey).(string) + if ok { + message.ActorID = gidx.PrefixedID(id) + } else { + message.ActorID = "unknown-actor" + } + } + + span.SetAttributes( + attribute.String( + "events.actor_id", + message.ActorID.String(), + ), + ) + + msg, err := newNATSMessage(c, topic, message) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + + return nil, err + } + + c.logger.Debugf("publishing change message to topic %s", topic) + + if err = msg.publish(); err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + + return msg, err + } + + return msg, nil +} + +// PublishEvent publishes an EventMessage. +func (c *NATSConnection) PublishEvent(ctx context.Context, topic string, message EventMessage) (Message[EventMessage], error) { + ctx, span := c.tracer.Start(ctx, "events.nats.PublishEvent", trace.WithAttributes( + attribute.String("events.subject_type", topic), + attribute.String("events.subject_id", message.SubjectID.String()), + attribute.String("events.event_type", message.EventType), + attribute.String("events.source", message.Source), + )) + + defer span.End() + + if err := message.Validate(); err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + + return nil, err + } + + // Propagate trace context into the message for the subscriber + var mapCarrier propagation.MapCarrier = make(map[string]string) + + otel.GetTextMapPropagator().Inject(ctx, mapCarrier) + + message.TraceContext = mapCarrier + + topic = c.buildPublishSubject("events", message.EventType, topic) + + msg, err := newNATSMessage(c, topic, message) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + + return nil, err + } + + c.logger.Debugf("publishing event message to topic %s", topic) + + if err = msg.publish(); err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + + return msg, err + } + + return msg, nil +} diff --git a/events/nats_subscribe.go b/events/nats_subscribe.go new file mode 100644 index 00000000..29a37cc4 --- /dev/null +++ b/events/nats_subscribe.go @@ -0,0 +1,180 @@ +package events + +import ( + "context" + "errors" + "time" + + "github.com/nats-io/nats.go" +) + +func (c *NATSConnection) coreSubscribe(ctx context.Context, subject string) (<-chan *nats.Msg, error) { + logger := c.logger.With( + "nats.provider", "core", + "nats.subject", subject, + ) + + sub, err := c.conn.SubscribeSync(subject) + if err != nil { + return nil, err + } + + msgCh := make(chan *nats.Msg, c.cfg.SubscriberFetchBatchSize) + + go func() { + for { + if err := c.nextMessage(ctx, sub, msgCh); err != nil { + if !errors.Is(err, context.DeadlineExceeded) { + logger.Errorw("error fetching messages", "error", err) + } + + select { + case <-ctx.Done(): + case <-time.After(c.cfg.SubscriberFetchBackoff): + } + } + + select { + case <-ctx.Done(): + close(msgCh) + + if err := sub.Unsubscribe(); err != nil { + logger.Warnw("error unsubscribing", "error", err) + } + + return + default: + } + } + }() + + return msgCh, nil +} + +func (c *NATSConnection) jsSubscribe(ctx context.Context, subject string) (<-chan *nats.Msg, error) { + durableName := c.durableName(subject) + + logger := c.logger.With( + "nats.provider", "jetstream", + "nats.subject", subject, + "nats.durable_name", durableName, + ) + + sub, err := c.jetstream.PullSubscribe(subject, durableName, c.cfg.subscribeOptions...) + if err != nil { + return nil, err + } + + msgCh := make(chan *nats.Msg, c.cfg.SubscriberFetchBatchSize) + + go func() { + for { + if err := c.fetchMessages(ctx, sub, msgCh); err != nil { + if !errors.Is(err, context.DeadlineExceeded) { + logger.Errorw("error fetching messages", "error", err) + } + + select { + case <-ctx.Done(): + case <-time.After(c.cfg.SubscriberFetchBackoff): + } + } + + select { + case <-ctx.Done(): + close(msgCh) + + if err := sub.Unsubscribe(); err != nil { + logger.Warnw("error unsubscribing", "error", err) + } + + return + default: + } + } + }() + + return msgCh, nil +} + +func (c *NATSConnection) fetchMessages(ctx context.Context, sub *nats.Subscription, msgCh chan<- *nats.Msg) error { + ctx, cancel := context.WithTimeout(ctx, c.cfg.SubscriberFetchTimeout) + + defer cancel() + + batch, err := sub.FetchBatch(c.cfg.SubscriberFetchBatchSize, nats.Context(ctx)) + if err != nil { + return err + } + + for msg := range batch.Messages() { + select { + case msgCh <- msg: + case <-ctx.Done(): + return ctx.Err() + } + } + + return batch.Error() +} + +func (c *NATSConnection) nextMessage(ctx context.Context, sub *nats.Subscription, msgCh chan<- *nats.Msg) error { + ctx, cancel := context.WithTimeout(ctx, c.cfg.SubscriberFetchTimeout) + + defer cancel() + + msg, err := sub.NextMsgWithContext(ctx) + if err != nil { + return err + } + + select { + case msgCh <- msg: + case <-ctx.Done(): + return ctx.Err() + } + + return nil +} + +// SubscribeAuthRelationshipRequests creates a new pull subscription parsing incoming messages as AuthRelationshipRequest messages and returning a new Message channel. +func (c *NATSConnection) SubscribeAuthRelationshipRequests(ctx context.Context, topic string) (<-chan Request[AuthRelationshipRequest, AuthRelationshipResponse], error) { + topic = c.buildSubscribeSubject("auth", "relationships", topic) + + natsCh, err := c.coreSubscribe(ctx, topic) + if err != nil { + return nil, err + } + + c.logger.Debugf("subscribing to auth relation request message on topic %s", topic) + + return natsSubscriptionAuthRelationshipRequestChan(ctx, c, c.cfg.SubscriberFetchBatchSize, natsCh), nil +} + +// SubscribeChanges creates a new pull subscription parsing incoming messages as ChangeMessage messages and returning a new Message channel. +func (c *NATSConnection) SubscribeChanges(ctx context.Context, topic string) (<-chan Message[ChangeMessage], error) { + topic = c.buildSubscribeSubject("changes", topic) + + natsCh, err := c.jsSubscribe(ctx, topic) + if err != nil { + return nil, err + } + + c.logger.Debugf("subscribing to changes message on topic %s", topic) + + return natsSubscriptionMessageChan[ChangeMessage](ctx, c, c.cfg.SubscriberFetchBatchSize, natsCh), nil +} + +// SubscribeEvents creates a new pull subscription parsing incoming messages as EventMessage messages and returning a new Message channel. +func (c *NATSConnection) SubscribeEvents(ctx context.Context, topic string) (<-chan Message[EventMessage], error) { + topic = c.buildSubscribeSubject("events", topic) + + natsCh, err := c.jsSubscribe(ctx, topic) + if err != nil { + return nil, err + } + + c.logger.Debugf("subscribing to events message on topic %s", topic) + + return natsSubscriptionMessageChan[EventMessage](ctx, c, c.cfg.SubscriberFetchBatchSize, natsCh), nil +} diff --git a/events/nats_test.go b/events/nats_test.go index b9913f26..39135c76 100644 --- a/events/nats_test.go +++ b/events/nats_test.go @@ -1,17 +1,3 @@ -// Copyright 2023 The Infratographer Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - package events_test import ( @@ -20,164 +6,174 @@ import ( "testing" "time" - "github.com/ThreeDotsLabs/watermill/message" "github.com/brianvoe/gofakeit/v6" + nc "github.com/nats-io/nats.go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.infratographer.com/x/gidx" - "go.infratographer.com/x/events" + "go.infratographer.com/x/gidx" "go.infratographer.com/x/testing/eventtools" ) var errTimeout = errors.New("timeout waiting for event") -func TestNatsPublishAndSubscribe(t *testing.T) { +func TestNATSPublishAndSubscribe(t *testing.T) { ctx := context.Background() nats, err := eventtools.NewNatsServer() require.NoError(t, err) defer nats.Close() - publisher, err := events.NewPublisher(nats.PublisherConfig) + conn, err := events.NewNATSConnection(nats.Config.NATS) require.NoError(t, err) + defer conn.Shutdown(ctx) //nolint:errcheck // within test + change := testCreateChange() - err = publisher.PublishChange(ctx, "test", change) + msg, err := conn.PublishChange(ctx, "test", change) require.NoError(t, err) + require.Equal(t, change, msg.Message()) change2 := testCreateChange() - err = publisher.PublishChange(ctx, "test", change2) + msg, err = conn.PublishChange(ctx, "test", change2) require.NoError(t, err) + require.Equal(t, change2, msg.Message()) change3 := testCreateChange() change3.ActorID = "" - err = publisher.PublishChange(ctx, "test", change3) - require.NoError(t, err) - - sub, err := events.NewSubscriber(nats.SubscriberConfig) + msg, err = conn.PublishChange(ctx, "test", change3) require.NoError(t, err) + require.NotEqual(t, change3, msg.Message()) - messages, err := sub.SubscribeChanges(context.Background(), ">") + messages, err := conn.SubscribeChanges(ctx, ">") require.NoError(t, err) receivedMsg, err := getSingleMessage(messages, time.Second*1) require.NoError(t, err) - - chgMsg, err := events.UnmarshalChangeMessage(receivedMsg.Payload) - require.NoError(t, err) - assert.EqualValues(t, change, chgMsg) - assert.True(t, receivedMsg.Ack()) + require.NoError(t, receivedMsg.Error()) + assert.EqualValues(t, change, receivedMsg.Message()) receivedMsg, err = getSingleMessage(messages, time.Second*1) require.NoError(t, err) - - chgMsg, err = events.UnmarshalChangeMessage(receivedMsg.Payload) - require.NoError(t, err) - assert.EqualValues(t, change2, chgMsg) - assert.True(t, receivedMsg.Ack()) + require.NoError(t, receivedMsg.Error()) + assert.EqualValues(t, change2, receivedMsg.Message()) + assert.NoError(t, receivedMsg.Ack()) receivedMsg, err = getSingleMessage(messages, time.Second*1) require.NoError(t, err) - - chgMsg, err = events.UnmarshalChangeMessage(receivedMsg.Payload) - require.NoError(t, err) - assert.NotEqualValues(t, change3, chgMsg) - assert.Equal(t, "unknown-actor", chgMsg.ActorID.String()) - assert.True(t, receivedMsg.Ack()) + require.NoError(t, receivedMsg.Error()) + assert.NotEqualValues(t, change3, receivedMsg.Message()) + assert.Equal(t, "unknown-actor", receivedMsg.Message().ActorID.String()) + assert.NoError(t, receivedMsg.Ack()) } -func TestNatsMultipleSubscribers(t *testing.T) { +func TestNATSRequestReply(t *testing.T) { ctx := context.Background() nats, err := eventtools.NewNatsServer() require.NoError(t, err) defer nats.Close() - publisher, err := events.NewPublisher(nats.PublisherConfig) + conn, err := events.NewNATSConnection(nats.Config.NATS) require.NoError(t, err) - change := testCreateChange() + defer conn.Shutdown(ctx) //nolint:errcheck // within test - err = publisher.PublishChange(ctx, "test", change) - require.NoError(t, err) + authRequest := events.AuthRelationshipRequest{ + Action: events.WriteAuthRelationshipAction, + ObjectID: gidx.PrefixedID("prntobj-abc123"), + Relations: []events.AuthRelationshipRelation{ + { + Relation: "owner", + SubjectID: gidx.PrefixedID("chldobj-abc123"), + }, + }, + TraceContext: map[string]string{}, + } - sub, err := events.NewSubscriber(nats.SubscriberConfig) - require.NoError(t, err) + authResponse := events.AuthRelationshipResponse{ + TraceID: "some-id", + TraceContext: map[string]string{}, + } - messages, err := sub.SubscribeChanges(context.Background(), ">") - require.NoError(t, err) + resp, err := conn.PublishAuthRelationshipRequest(ctx, "test", authRequest) + require.Error(t, err) + require.ErrorIs(t, err, events.ErrRequestNoResponders) + require.ErrorIs(t, err, nc.ErrNoResponders) + require.Nil(t, resp) - receivedMsg, err := getSingleMessage(messages, time.Second*1) - require.NoError(t, err) + reqGot := make(chan events.Message[events.AuthRelationshipRequest], 1) + respGot := make(chan events.Message[events.AuthRelationshipResponse], 1) - chgMsg, err := events.UnmarshalChangeMessage(receivedMsg.Payload) - require.NoError(t, err) - assert.EqualValues(t, change, chgMsg) - assert.True(t, receivedMsg.Ack()) + authSubscribed := make(chan bool, 1) - sub2, err := events.NewSubscriber(nats.SubscriberConfig) - require.NoError(t, err) + go func() { + ctx, cancel := context.WithCancel(ctx) - messages, err = sub2.SubscribeChanges(context.Background(), ">") - require.NoError(t, err) + defer cancel() - receivedMsg, err = getSingleMessage(messages, time.Second*1) - require.NoError(t, err) + msgs, err := conn.SubscribeAuthRelationshipRequests(ctx, "*.test") - chgMsg, err = events.UnmarshalChangeMessage(receivedMsg.Payload) - require.NoError(t, err) - assert.EqualValues(t, change, chgMsg) - assert.True(t, receivedMsg.Ack()) -} + close(authSubscribed) -func TestNatsGroupedSubscribers(t *testing.T) { - ctx := context.Background() - nats, err := eventtools.NewNatsServer() - require.NoError(t, err) + require.NoError(t, err) - publisher, err := events.NewPublisher(nats.PublisherConfig) - require.NoError(t, err) + select { + case reqMsg, ok := <-msgs: + if !ok { + return + } - change := testCreateChange() + reqGot <- reqMsg - err = publisher.PublishChange(ctx, "test", change) - require.NoError(t, err) + respMsg, err := reqMsg.Reply(ctx, authResponse) + assert.NoError(t, err) + assert.NotNil(t, respMsg) + case <-time.After(time.Second * 2): + } + }() - // put both subscribers in the same queue group so that combined the message is only delivered once - nats.SubscriberConfig.QueueGroup = "queue-test" + <-authSubscribed - sub, err := events.NewSubscriber(nats.SubscriberConfig) - require.NoError(t, err) + go func() { + ctx, cancel := context.WithTimeout(ctx, time.Second*2) - messages, err := sub.SubscribeChanges(context.Background(), ">") - require.NoError(t, err) + defer cancel() - receivedMsg, err := getSingleMessage(messages, time.Second*1) - require.NoError(t, err) + resp, err := conn.PublishAuthRelationshipRequest(ctx, "test", authRequest) + assert.NoError(t, err) - chgMsg, err := events.UnmarshalChangeMessage(receivedMsg.Payload) - require.NoError(t, err) - assert.EqualValues(t, change, chgMsg) - assert.True(t, receivedMsg.Ack()) + respGot <- resp + }() - sub2, err := events.NewSubscriber(nats.SubscriberConfig) - require.NoError(t, err) + select { + case authRequestGot := <-reqGot: + require.NotNil(t, authRequestGot) + require.NoError(t, authRequestGot.Error()) + require.EqualValues(t, authRequest, authRequestGot.Message()) + case <-time.After(time.Second * 2): + t.Error("timed out waiting for auth relationship request") + } - messages, err = sub2.SubscribeChanges(context.Background(), ">") - require.NoError(t, err) + close(reqGot) - receivedMsg, err = getSingleMessage(messages, time.Second*1) - assert.Error(t, err, "this should fail since the other subscriber in the group already received the message") - assert.ErrorContains(t, err, "timeout") - assert.Nil(t, receivedMsg) + select { + case authResponseGot := <-respGot: + require.NotNil(t, authResponseGot) + require.NoError(t, authResponseGot.Error()) + require.EqualValues(t, authResponse, authResponseGot.Message()) + case <-time.After(time.Second * 2): + t.Error("timed out waiting for auth relationship response") + } + + close(respGot) } -func getSingleMessage(messages <-chan *message.Message, timeout time.Duration) (*message.Message, error) { +func getSingleMessage[T any](messages <-chan events.Message[T], timeout time.Duration) (events.Message[T], error) { select { case message := <-messages: return message, nil diff --git a/events/publisher.go b/events/publisher.go deleted file mode 100644 index 5d92e2ef..00000000 --- a/events/publisher.go +++ /dev/null @@ -1,260 +0,0 @@ -// Copyright 2023 The Infratographer Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package events - -import ( - "context" - "encoding/json" - "errors" - "strings" - - "github.com/ThreeDotsLabs/watermill" - "github.com/ThreeDotsLabs/watermill/message" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/codes" - "go.opentelemetry.io/otel/propagation" - "go.opentelemetry.io/otel/trace" - "go.uber.org/zap" - - "go.infratographer.com/x/echojwtx" - "go.infratographer.com/x/gidx" -) - -const instrumentationName = "go.infratographer.com/x/events" - -// ErrUnsupportedPubsub is returned when the pubsub URL is not a supported provider -var ErrUnsupportedPubsub = errors.New("unsupported pubsub provider") - -// ErrMissingEventType is returned when attempting to publish an event without an event type specified -var ErrMissingEventType = errors.New("event type missing") - -// Publisher provides a pubsub publisher that uses the watermill pubsub package -type Publisher struct { - prefix string - source string - publisher message.Publisher - logger *zap.SugaredLogger - tracer trace.Tracer -} - -// NewPublisherWithLogger returns a publisher for the given config provided -func NewPublisherWithLogger(cfg PublisherConfig, logger *zap.SugaredLogger) (*Publisher, error) { - tracer := otel.GetTracerProvider().Tracer(instrumentationName) - - p := &Publisher{ - prefix: cfg.Prefix, - source: cfg.Source, - logger: logger, - tracer: tracer, - } - - switch { - case strings.HasPrefix(cfg.URL, "nats://"): - np, err := newNATSPublisher(cfg, p.logger) - if err != nil { - return nil, err - } - - p.publisher = np - default: - return nil, ErrUnsupportedPubsub - } - - return p, nil -} - -// NewPublisher returns a publisher for the given config provided -func NewPublisher(cfg PublisherConfig) (*Publisher, error) { - return NewPublisherWithLogger(cfg, zap.NewNop().Sugar()) -} - -// PublishChange will publish a ChangeMessage to the topic for the change -func (p *Publisher) PublishChange(ctx context.Context, subjectType string, change ChangeMessage) error { - ctx, span := p.tracer.Start( - ctx, - "events.publishChange", - trace.WithAttributes( - attribute.String( - "events.subject_type", - subjectType, - ), - attribute.String( - "events.subject_id", - change.SubjectID.String(), - ), - attribute.String( - "events.event_type", - change.EventType, - ), - attribute.String( - "events.source", - change.Source, - ), - ), - ) - - defer span.End() - - // Propagate trace context into the message for the subscriber - var mapCarrier propagation.MapCarrier = make(map[string]string) - - otel.GetTextMapPropagator().Inject(ctx, mapCarrier) - - change.TraceContext = mapCarrier - - if change.EventType == "" { - span.RecordError(ErrMissingEventType) - span.SetStatus(codes.Error, ErrMissingEventType.Error()) - - return ErrMissingEventType - } - - topic := strings.Join([]string{p.prefix, "changes", change.EventType, subjectType}, ".") - - span.SetAttributes( - attribute.String( - "events.topic", - topic, - ), - ) - - change.Source = p.source - if change.ActorID == gidx.NullPrefixedID { - id, ok := ctx.Value(echojwtx.ActorCtxKey).(string) - if ok { - change.ActorID = gidx.PrefixedID(id) - } else { - change.ActorID = "unknown-actor" - } - } - - span.SetAttributes( - attribute.String( - "events.actor_id", - change.ActorID.String(), - ), - ) - - v, err := json.Marshal(change) - if err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) - - return err - } - - msg := message.NewMessage(watermill.NewUUID(), v) - - span.SetAttributes( - attribute.String( - "events.message_id", - msg.UUID, - ), - ) - - if err := p.publisher.Publish(topic, msg); err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) - - return err - } - - return nil -} - -// PublishEvent will publish an EventMessage to the proper topic for that event -func (p *Publisher) PublishEvent(ctx context.Context, subjectType string, event EventMessage) error { - ctx, span := p.tracer.Start( - ctx, - "events.publishEvent", - trace.WithAttributes( - attribute.String( - "events.subject_type", - subjectType, - ), - attribute.String( - "events.subject_id", - event.SubjectID.String(), - ), - attribute.String( - "events.event_type", - event.EventType, - ), - attribute.String( - "events.source", - event.Source, - ), - ), - ) - - defer span.End() - - // Propagate trace context into the message for the subscriber - var mapCarrier propagation.MapCarrier = make(map[string]string) - - otel.GetTextMapPropagator().Inject(ctx, mapCarrier) - - event.TraceContext = mapCarrier - - if event.EventType == "" { - span.RecordError(ErrMissingEventType) - span.SetStatus(codes.Error, ErrMissingEventType.Error()) - - return ErrMissingEventType - } - - topic := strings.Join([]string{p.prefix, "events", subjectType, event.EventType}, ".") - - span.SetAttributes( - attribute.String( - "events.topic", - topic, - ), - ) - - event.Source = p.source - - v, err := json.Marshal(event) - if err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) - - return err - } - - msg := message.NewMessage(watermill.NewUUID(), v) - - span.SetAttributes( - attribute.String( - "events.message_id", - msg.UUID, - ), - ) - - if err := p.publisher.Publish(topic, msg); err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) - - return err - } - - return nil -} - -// Close will close the publisher -func (p *Publisher) Close() error { - return p.publisher.Close() -} diff --git a/events/subscriber.go b/events/subscriber.go deleted file mode 100644 index dcf34d35..00000000 --- a/events/subscriber.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2023 The Infratographer Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package events - -import ( - "context" - "strings" - - "github.com/ThreeDotsLabs/watermill/message" - "github.com/nats-io/nats.go" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/propagation" - "go.uber.org/zap" -) - -// Subscriber provides a pubsub subscriber that uses the watermill pubsub package -type Subscriber struct { - prefix string - subscriber message.Subscriber - logger *zap.SugaredLogger -} - -// NewSubscriberWithLogger returns a subscriber for the given config provided -func NewSubscriberWithLogger(cfg SubscriberConfig, logger *zap.SugaredLogger, options ...nats.SubOpt) (*Subscriber, error) { - s := &Subscriber{ - prefix: cfg.Prefix, - logger: logger, - } - - switch { - case strings.HasPrefix(cfg.URL, "nats://"): - ns, err := newNATSSubscriber(cfg, s.logger, options...) - if err != nil { - return nil, err - } - - s.subscriber = ns - default: - return nil, ErrUnsupportedPubsub - } - - return s, nil -} - -// NewSubscriber returns a subscriber for the given config provided -func NewSubscriber(cfg SubscriberConfig, options ...nats.SubOpt) (*Subscriber, error) { - return NewSubscriberWithLogger(cfg, zap.NewNop().Sugar(), options...) -} - -// SubscribeChanges will subscribe you to the changes for a given topic. To receive all changes of any kind you can -// pass in ">". -func (s *Subscriber) SubscribeChanges(ctx context.Context, topic string) (<-chan *message.Message, error) { - topic = strings.Join([]string{s.prefix, "changes", topic}, ".") - - return s.subscriber.Subscribe(ctx, topic) -} - -// SubscribeEvents will subscribe you to the events for a given topic. To receive all changes of any kind you can -// pass in ">". -func (s *Subscriber) SubscribeEvents(ctx context.Context, topic string) (<-chan *message.Message, error) { - topic = strings.Join([]string{s.prefix, "events", topic}, ".") - - return s.subscriber.Subscribe(ctx, topic) -} - -// Close will close the subscriber -func (s *Subscriber) Close() error { - return s.subscriber.Close() -} - -// TraceContextFromChangeMessage creates a new OpenTelemetry context from the given ChangeMessage. -func TraceContextFromChangeMessage(ctx context.Context, msg ChangeMessage) context.Context { - tp := otel.GetTextMapPropagator() - - return tp.Extract(ctx, propagation.MapCarrier(msg.TraceContext)) -} - -// TraceContextFromEventMessage creates a new OpenTelemetry context from the given ChangeMessage. -func TraceContextFromEventMessage(ctx context.Context, msg EventMessage) context.Context { - tp := otel.GetTextMapPropagator() - - return tp.Extract(ctx, propagation.MapCarrier(msg.TraceContext)) -} diff --git a/go.mod b/go.mod index 9b1e6f19..d8e0c298 100644 --- a/go.mod +++ b/go.mod @@ -7,13 +7,10 @@ require ( entgo.io/ent v0.12.3 github.com/99designs/gqlgen v0.17.34 github.com/MicahParks/keyfunc/v2 v2.0.3 - github.com/ThreeDotsLabs/watermill v1.2.0 - github.com/ThreeDotsLabs/watermill-nats/v2 v2.0.0 github.com/XSAM/otelsql v0.23.0 github.com/brianvoe/gofakeit/v6 v6.23.0 github.com/cockroachdb/cockroach-go/v2 v2.3.5 github.com/docker/go-connections v0.4.0 - github.com/garsue/watermillzap v1.2.0 github.com/gin-contrib/requestid v0.0.6 github.com/gin-contrib/zap v0.1.0 github.com/gin-gonic/gin v1.9.1 @@ -76,7 +73,6 @@ require ( github.com/imdario/mergo v0.3.16 // indirect github.com/klauspost/compress v1.16.7 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect - github.com/lithammer/shortuuid/v3 v3.0.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/minio/highwayhash v1.0.2 // indirect github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7 // indirect @@ -87,12 +83,12 @@ require ( github.com/nats-io/jwt/v2 v2.4.1 // indirect github.com/nats-io/nkeys v0.4.4 // indirect github.com/nats-io/nuid v1.0.1 // indirect - github.com/oklog/ulid v1.3.1 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0-rc4 // indirect github.com/opencontainers/runc v1.1.7 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasttemplate v1.2.2 // indirect @@ -158,7 +154,7 @@ require ( go.opentelemetry.io/proto/otlp v0.19.0 // indirect go.step.sm/crypto v0.31.2 go.uber.org/atomic v1.10.0 // indirect - go.uber.org/multierr v1.9.0 // indirect + go.uber.org/multierr v1.9.0 golang.org/x/crypto v0.11.0 // indirect golang.org/x/net v0.12.0 // indirect golang.org/x/sys v0.10.0 // indirect diff --git a/go.sum b/go.sum index 400545f1..96fef40a 100644 --- a/go.sum +++ b/go.sum @@ -59,10 +59,6 @@ github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migc github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= github.com/Microsoft/hcsshim v0.9.7 h1:mKNHW/Xvv1aFH87Jb6ERDzXTJTLPlmzfZ28VBFD/bfg= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= -github.com/ThreeDotsLabs/watermill v1.2.0 h1:TU3TML1dnQ/ifK09F2+4JQk2EKhmhXe7Qv7eb5ZpTS8= -github.com/ThreeDotsLabs/watermill v1.2.0/go.mod h1:IuVxGk/kgCN0cex2S94BLglUiB0PwOm8hbUhm6g2Nx4= -github.com/ThreeDotsLabs/watermill-nats/v2 v2.0.0 h1:ZbdQ+cHwOZmXByEoKUH8SS6qR/erNQfrsNpvH5z/gfk= -github.com/ThreeDotsLabs/watermill-nats/v2 v2.0.0/go.mod h1:X6pcl579pScj4mII3KM/WJ+bcOqORqiCToy92f4gqJ4= github.com/XSAM/otelsql v0.23.0 h1:NsJQS9YhI1+RDsFqE9mW5XIQmPmdF/qa8qQOLZN8XEA= github.com/XSAM/otelsql v0.23.0/go.mod h1:oX4LXMsb+9lAZhvHjUS61oQP/hbcJRadWHnBKNL+LuM= github.com/agext/levenshtein v1.2.1 h1:QmvMAjj2aEICytGiWzmxoE0x2KZvE0fvmqMOfy2tjT8= @@ -148,8 +144,6 @@ github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4 github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= -github.com/garsue/watermillzap v1.2.0 h1:IA0zGb5b7mIGLXN9P2/6CmP5+f7Qgb00BdL2VCAk2SA= -github.com/garsue/watermillzap v1.2.0/go.mod h1:uo3SDSGYaw6RBzUx9jcHMYqypOTqlQ4/vz+8r1olRto= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gin-contrib/requestid v0.0.6 h1:mGcxTnHQ45F6QU5HQRgQUDsAfHprD3P7g2uZ4cSZo9o= github.com/gin-contrib/requestid v0.0.6/go.mod h1:9i4vKATX/CdggbkY252dPVasgVucy/ggBeELXuQztm4= @@ -263,7 +257,6 @@ github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLe github.com/google/pprof v0.0.0-20201218002935-b9804c9f04c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= @@ -274,8 +267,6 @@ github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFb github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0/go.mod h1:hgWBS7lorOAVIJEQMi4ZsPv9hVvWI6+ch50m39Pf2Ks= github.com/grpc-ecosystem/grpc-gateway/v2 v2.14.0 h1:t7uX3JBHdVwAi3G7sSSdbsk8NfgA+LnUS88V/2EKaA0= github.com/grpc-ecosystem/grpc-gateway/v2 v2.14.0/go.mod h1:4OGVnY4qf2+gw+ssiHbW+pq4mo2yko94YxxMmXZ7jCA= -github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= -github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= @@ -387,8 +378,6 @@ github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/lithammer/shortuuid/v3 v3.0.7 h1:trX0KTHy4Pbwo/6ia8fscyHoGA+mf1jWbPJVuvyJQQ8= -github.com/lithammer/shortuuid/v3 v3.0.7/go.mod h1:vMk8ke37EmiewwolSO1NLW8vP4ZaKlRuDIi8tWWmAts= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= @@ -436,8 +425,6 @@ github.com/nats-io/nkeys v0.4.4 h1:xvBJ8d69TznjcQl9t6//Q5xXuVhyYiSos6RPtvQNTwA= github.com/nats-io/nkeys v0.4.4/go.mod h1:XUkxdLPTufzlihbamfzQ7mw/VGx6ObUs+0bN5sNvt64= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= -github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= -github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.0-rc4 h1:oOxKUJWnFC4YGHCCMNql1x4YaDfYBTS5Y4x/Cgeo1E0= diff --git a/testing/eventtools/mock_connection.go b/testing/eventtools/mock_connection.go new file mode 100644 index 00000000..8f213fb6 --- /dev/null +++ b/testing/eventtools/mock_connection.go @@ -0,0 +1,72 @@ +package eventtools + +import ( + "context" + + "github.com/stretchr/testify/mock" + + "go.infratographer.com/x/events" +) + +var _ events.Connection = (*MockConnection)(nil) + +// MockConnection implements events.Connection +type MockConnection struct { + mock.Mock +} + +// Shutdown implements events.Connection +func (c *MockConnection) Shutdown(_ context.Context) error { + args := c.Called() + + return args.Error(0) +} + +// PublishAuthRelationshipRequest implements events.Connection +func (c *MockConnection) PublishAuthRelationshipRequest(_ context.Context, topic string, message events.AuthRelationshipRequest) (events.Message[events.AuthRelationshipResponse], error) { + args := c.Called(topic, message) + + return args.Get(0).(events.Message[events.AuthRelationshipResponse]), args.Error(1) +} + +// PublishChange implements events.Connection +func (c *MockConnection) PublishChange(_ context.Context, topic string, message events.ChangeMessage) (events.Message[events.ChangeMessage], error) { + args := c.Called(topic, message) + + return args.Get(0).(events.Message[events.ChangeMessage]), args.Error(1) +} + +// PublishEvent implements events.Connection +func (c *MockConnection) PublishEvent(_ context.Context, topic string, message events.EventMessage) (events.Message[events.EventMessage], error) { + args := c.Called(topic, message) + + return args.Get(0).(events.Message[events.EventMessage]), args.Error(1) +} + +// Source implements events.Connection +func (c *MockConnection) Source() any { + args := c.Called() + + return args.Error(0) +} + +// SubscribeAuthRelationshipRequests implements events.Connection +func (c *MockConnection) SubscribeAuthRelationshipRequests(_ context.Context, topic string) (<-chan events.Request[events.AuthRelationshipRequest, events.AuthRelationshipResponse], error) { + args := c.Called(topic) + + return args.Get(0).(<-chan events.Request[events.AuthRelationshipRequest, events.AuthRelationshipResponse]), args.Error(1) +} + +// SubscribeChanges implements events.Connection +func (c *MockConnection) SubscribeChanges(_ context.Context, topic string) (<-chan events.Message[events.ChangeMessage], error) { + args := c.Called(topic) + + return args.Get(0).(<-chan events.Message[events.ChangeMessage]), args.Error(1) +} + +// SubscribeEvents implements events.Connection +func (c *MockConnection) SubscribeEvents(_ context.Context, topic string) (<-chan events.Message[events.EventMessage], error) { + args := c.Called(topic) + + return args.Get(0).(<-chan events.Message[events.EventMessage]), args.Error(1) +} diff --git a/testing/eventtools/mock_message.go b/testing/eventtools/mock_message.go new file mode 100644 index 00000000..7e49d90a --- /dev/null +++ b/testing/eventtools/mock_message.go @@ -0,0 +1,101 @@ +package eventtools + +import ( + "context" + "time" + + "github.com/stretchr/testify/mock" + + "go.infratographer.com/x/events" +) + +var _ events.Message[any] = (*MockMessage[interface{}])(nil) + +// MockMessage implements events.Message. +type MockMessage[T any] struct { + mock.Mock +} + +// Connection implements events.Message. +func (m *MockMessage[T]) Connection() events.Connection { + args := m.Called() + + return args.Get(0).(events.Connection) +} + +// ID implements events.Message. +func (m *MockMessage[T]) ID() string { + args := m.Called() + + return args.String(0) +} + +// Topic implements events.Message. +func (m *MockMessage[T]) Topic() string { + args := m.Called() + + return args.String(0) +} + +// Message implements events.Message. +func (m *MockMessage[T]) Message() T { + args := m.Called() + + return args.Get(0).(T) +} + +// Ack implements events.Message. +func (m *MockMessage[T]) Ack() error { + args := m.Called() + + return args.Error(0) +} + +// Nak implements events.Message. +func (m *MockMessage[T]) Nak(delay time.Duration) error { + args := m.Called(delay) + + return args.Error(0) +} + +// Term implements events.Message. +func (m *MockMessage[T]) Term() error { + args := m.Called() + + return args.Error(0) +} + +// Timestamp implements events.Message. +func (m *MockMessage[T]) Timestamp() time.Time { + args := m.Called() + + return args.Get(0).(time.Time) +} + +// Deliveries implements events.Message. +func (m *MockMessage[T]) Deliveries() uint64 { + args := m.Called() + + return args.Get(0).(uint64) +} + +// Error implements events.Message. +func (m *MockMessage[T]) Error() error { + args := m.Called() + + return args.Error(0) +} + +// ReplyAuthRelationshipRequest implements events.Message. +func (m *MockMessage[T]) ReplyAuthRelationshipRequest(_ context.Context, message events.AuthRelationshipResponse) (events.Message[events.AuthRelationshipResponse], error) { + args := m.Called(message) + + return args.Get(0).(events.Message[events.AuthRelationshipResponse]), args.Error(1) +} + +// Source implements events.Message. +func (m *MockMessage[T]) Source() any { + args := m.Called() + + return args.Error(0) +} diff --git a/testing/eventtools/nats.go b/testing/eventtools/nats.go index b3d8662e..c9fa264a 100644 --- a/testing/eventtools/nats.go +++ b/testing/eventtools/nats.go @@ -46,11 +46,10 @@ var ( // TestNats maintains the nats environment type TestNats struct { - Server *server.Server - Conn *nats.Conn - JetStream nats.JetStreamContext - PublisherConfig events.PublisherConfig - SubscriberConfig events.SubscriberConfig + Server *server.Server + Conn *nats.Conn + JetStream nats.JetStreamContext + Config events.Config } // Close closes the connection @@ -178,13 +177,12 @@ func NewNatsServer() (*TestNats, error) { Server: s, Conn: nc, JetStream: js, - PublisherConfig: events.PublisherConfig{ - URL: s.ClientURL(), - Prefix: Prefix, - }, - SubscriberConfig: events.SubscriberConfig{ - URL: s.ClientURL(), - Prefix: Prefix, + Config: events.Config{ + NATS: events.NATSConfig{ + URL: s.ClientURL(), + SubscribePrefix: Prefix, + PublishPrefix: Prefix, + }, }, }, nil } diff --git a/testing/eventtools/nats_test.go b/testing/eventtools/nats_test.go index f585558c..488384c0 100644 --- a/testing/eventtools/nats_test.go +++ b/testing/eventtools/nats_test.go @@ -6,9 +6,7 @@ import ( "testing" "time" - "github.com/ThreeDotsLabs/watermill/message" "github.com/brianvoe/gofakeit/v6" - nc "github.com/nats-io/nats.go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -22,64 +20,59 @@ var errTimeout = errors.New("timeout waiting for event") func TestNats(t *testing.T) { ctx := context.Background() + consumerName := events.NATSConsumerDurableName("", eventtools.Prefix+".changes.>") + nats, err := eventtools.NewNatsServer() require.NoError(t, err) defer nats.Close() - publisher, err := events.NewPublisher(nats.PublisherConfig) + conn, err := events.NewNATSConnection(nats.Config.NATS) require.NoError(t, err) change1 := testCreateChange() - err = publisher.PublishChange(ctx, "test", change1) + chgMsg, err := conn.PublishChange(ctx, "test", change1) require.NoError(t, err) + require.NoError(t, chgMsg.Error()) + require.Equal(t, change1, chgMsg.Message()) change2 := testCreateChange() - err = publisher.PublishChange(ctx, "test", change2) - require.NoError(t, err) - - sub, err := events.NewSubscriber(nats.SubscriberConfig, - nc.ManualAck(), - nc.AckExplicit(), - nc.Durable("test-consumer"), - ) + chgMsg, err = conn.PublishChange(ctx, "test", change2) require.NoError(t, err) + require.NoError(t, chgMsg.Error()) + require.Equal(t, change2, chgMsg.Message()) - messages, err := sub.SubscribeChanges(context.Background(), ">") + messages, err := conn.SubscribeChanges(context.Background(), ">") require.NoError(t, err) - err = nats.SetConsumerSampleFrequency("test-consumer", "100") + err = nats.SetConsumerSampleFrequency(consumerName, "100") require.NoError(t, err) receivedMsg, err := getSingleMessage(messages, time.Second*1) require.NoError(t, err) + require.NoError(t, receivedMsg.Error()) + assert.EqualValues(t, change1, receivedMsg.Message()) + assert.NoError(t, receivedMsg.Ack()) - chgMsg, err := events.UnmarshalChangeMessage(receivedMsg.Payload) - require.NoError(t, err) - assert.EqualValues(t, change1, chgMsg) - assert.True(t, receivedMsg.Ack()) - - err = nats.WaitForAck("test-consumer", time.Second) + err = nats.WaitForAck(consumerName, time.Second*2) require.NoError(t, err) receivedMsg, err = getSingleMessage(messages, time.Second*1) require.NoError(t, err) + require.NoError(t, receivedMsg.Error()) + assert.EqualValues(t, change2, receivedMsg.Message()) + assert.NoError(t, receivedMsg.Nak(0)) - chgMsg, err = events.UnmarshalChangeMessage(receivedMsg.Payload) - require.NoError(t, err) - assert.EqualValues(t, change2, chgMsg) - assert.True(t, receivedMsg.Nack()) - - err = nats.WaitForAck("test-consumer", time.Second) + err = nats.WaitForAck(consumerName, time.Second*2) require.ErrorIs(t, err, eventtools.ErrNack) - err = nats.WaitForAck("test-consumer", time.Second) + err = nats.WaitForAck(consumerName, time.Second*2) require.ErrorIs(t, err, eventtools.ErrNoAck) } -func getSingleMessage(messages <-chan *message.Message, timeout time.Duration) (*message.Message, error) { +func getSingleMessage[T any](messages <-chan events.Message[T], timeout time.Duration) (events.Message[T], error) { select { case message := <-messages: return message, nil