From 3eccb7486bf60feea5fd2d5703102f001b85277d Mon Sep 17 00:00:00 2001 From: Sergio Moya <1083296+smoya@users.noreply.github.com> Date: Thu, 22 Jul 2021 19:49:13 +0200 Subject: [PATCH 1/4] feat: add message validation based on schema --- .golangci.yml | 6 ++ asyncapi/document.go | 2 +- asyncapi/v2/decode_test.go | 2 +- asyncapi/v2/v2.go | 133 ++++++++++++++++++++------------- asyncapi/v2/validation.go | 50 +++++++++++++ asyncapi/v2/validation_test.go | 73 ++++++++++++++++++ config/config.go | 17 ++++- config/kafka.go | 79 ++++++++++++++++++-- config/kafka_test.go | 42 ++++++++--- go.mod | 1 + go.sum | 4 +- kafka/config.go | 14 ++++ kafka/proxy.go | 98 ++++++++++++++++++------ kafka/proxy_test.go | 52 +++++++++---- main.go | 46 +++++++++++- proxy/validation.go | 106 ++++++++++++++++++++++++++ proxy/validation_test.go | 128 +++++++++++++++++++++++++++++++ 17 files changed, 737 insertions(+), 116 deletions(-) create mode 100644 asyncapi/v2/validation.go create mode 100644 asyncapi/v2/validation_test.go create mode 100644 proxy/validation.go create mode 100644 proxy/validation_test.go diff --git a/.golangci.yml b/.golangci.yml index 6ea1438..ecf1c22 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -45,3 +45,9 @@ linters-settings: forbidigo: forbid: - fmt.Print.* # usually just used for debugging purpose + +issues: + exclude-rules: + - path: _test.go + linters: + - funlen diff --git a/asyncapi/document.go b/asyncapi/document.go index c1a1737..ea05418 100644 --- a/asyncapi/document.go +++ b/asyncapi/document.go @@ -128,7 +128,7 @@ type Schema interface { Property(name string) Schema PropertyNames() Schema ReadOnly() bool - Required() string // TODO string[] + Required() []string Then() Schema Title() string Type() []string // TODO // string | string[] diff --git a/asyncapi/v2/decode_test.go b/asyncapi/v2/decode_test.go index 654072b..197fd62 100644 --- a/asyncapi/v2/decode_test.go +++ b/asyncapi/v2/decode_test.go @@ -55,7 +55,7 @@ func TestDecodeFromFile(t *testing.T) { } } -//nolint:misspell,funlen +//nolint:misspell func TestDecodeFromPlainText(t *testing.T) { raw := []byte(` asyncapi: '2.0.0' diff --git a/asyncapi/v2/v2.go b/asyncapi/v2/v2.go index cf88c46..cd41d14 100644 --- a/asyncapi/v2/v2.go +++ b/asyncapi/v2/v2.go @@ -168,6 +168,13 @@ func (d Document) filterOperations(filter func(operation asyncapi.Operation) boo return operations } +// NewChannel creates a new Channel. Useful for testing. +func NewChannel(path string) *Channel { + return &Channel{ + PathField: path, + } +} + type Channel struct { Extendable Describable `mapstructure:",squash"` @@ -250,12 +257,22 @@ type SubscribeOperation struct { Operation } +// NewSubscribeOperation creates a new SubscribeOperation. Useful for testing. +func NewSubscribeOperation(msg *Message) *SubscribeOperation { + return &SubscribeOperation{Operation: *NewOperation(OperationTypeSubscribe, msg)} +} + func (o SubscribeOperation) MapStructureDefaults() map[string]interface{} { return map[string]interface{}{ "operationType": OperationTypeSubscribe, } } +// NewPublishOperation creates a new PublishOperation. Useful for testing. +func NewPublishOperation(msg *Message) *PublishOperation { + return &PublishOperation{Operation: *NewOperation(OperationTypePublish, msg)} +} + type PublishOperation struct { Operation } @@ -266,6 +283,19 @@ func (o PublishOperation) MapStructureDefaults() map[string]interface{} { } } +// NewOperation creates a new Operation. Useful for testing. +func NewOperation(operationType asyncapi.OperationType, msg *Message) *Operation { + op := &Operation{ + OperationType: operationType, + } + + if msg != nil { + op.MessageField = msg + } + + return op +} + type Operation struct { Extendable Describable `mapstructure:",squash"` @@ -377,57 +407,57 @@ func (s Schemas) ToInterface(dst map[string]asyncapi.Schema) map[string]asyncapi type Schema struct { Extendable - AdditionalItemsField *Schema `mapstructure:"additionalItems"` - AdditionalPropertiesField *Schema `mapstructure:"additionalProperties"` - AllOfField []asyncapi.Schema `mapstructure:"allOf"` - AnyOfField []asyncapi.Schema `mapstructure:"anyOf"` - ConstField interface{} `mapstructure:"const"` - ContainsField *Schema `mapstructure:"contains"` - ContentEncodingField string `mapstructure:"contentEncoding"` - ContentMediaTypeField string `mapstructure:"contentMediaType"` - DefaultField interface{} `mapstructure:"default"` - DefinitionsField Schemas `mapstructure:"definitions"` - DependenciesField Schemas `mapstructure:"dependencies"` - DeprecatedField bool `mapstructure:"deprecated"` - DescriptionField string `mapstructure:"description"` - DiscriminatorField string `mapstructure:"discriminator"` - ElseField *Schema `mapstructure:"else"` - EnumField []interface{} `mapstructure:"enum"` - ExamplesField []interface{} `mapstructure:"examples"` - ExclusiveMaximumField *float64 `mapstructure:"exclusiveMaximum"` - ExclusiveMinimumField *float64 `mapstructure:"exclusiveMinimum"` - FormatField string `mapstructure:"format"` - IDField string `mapstructure:"$id"` - IfField *Schema `mapstructure:"if"` - ItemsField []asyncapi.Schema `mapstructure:"items"` - MaximumField *float64 `mapstructure:"maximum"` - MaxItemsField *float64 `mapstructure:"maxItems"` - MaxLengthField *float64 `mapstructure:"maxLength"` - MaxPropertiesField *float64 `mapstructure:"maxProperties"` - MinimumField *float64 `mapstructure:"minimum"` - MinItemsField *float64 `mapstructure:"minItems"` - MinLengthField *float64 `mapstructure:"minLength"` - MinPropertiesField *float64 `mapstructure:"minProperties"` - MultipleOfField *float64 `mapstructure:"multipleOf"` - NotField *Schema `mapstructure:"not"` - OneOfField *Schema `mapstructure:"oneOf"` - PatternField string `mapstructure:"pattern"` - PatternPropertiesField Schemas `mapstructure:"patternProperties"` - PropertiesField Schemas `mapstructure:"properties"` - PropertyNamesField *Schema `mapstructure:"propertyNames"` - ReadOnlyField bool `mapstructure:"readOnly"` - RequiredField string `mapstructure:"required"` - ThenField *Schema `mapstructure:"then"` - TitleField string `mapstructure:"title"` - TypeField interface{} `mapstructure:"type"` // string | []string - UniqueItemsField bool `mapstructure:"uniqueItems"` - WriteOnlyField bool `mapstructure:"writeOnly"` + AdditionalItemsField *Schema `mapstructure:"additionalItems" json:"additionalItems,omitempty"` + AdditionalPropertiesField *Schema `mapstructure:"additionalProperties" json:"additionalProperties,omitempty"` + AllOfField []asyncapi.Schema `mapstructure:"allOf" json:"allOf,omitempty"` + AnyOfField []asyncapi.Schema `mapstructure:"anyOf" json:"anyOf,omitempty"` + ConstField interface{} `mapstructure:"const" json:"const,omitempty"` + ContainsField *Schema `mapstructure:"contains" json:"contains,omitempty"` + ContentEncodingField string `mapstructure:"contentEncoding" json:"contentEncoding,omitempty"` + ContentMediaTypeField string `mapstructure:"contentMediaType" json:"contentMediaType,omitempty"` + DefaultField interface{} `mapstructure:"default" json:"default,omitempty"` + DefinitionsField Schemas `mapstructure:"definitions" json:"definitions,omitempty"` + DependenciesField Schemas `mapstructure:"dependencies" json:"dependencies,omitempty"` + DeprecatedField bool `mapstructure:"deprecated" json:"deprecated,omitempty"` + DescriptionField string `mapstructure:"description" json:"description,omitempty"` + DiscriminatorField string `mapstructure:"discriminator" json:"discriminator,omitempty"` + ElseField *Schema `mapstructure:"else" json:"else,omitempty"` + EnumField []interface{} `mapstructure:"enum" json:"enum,omitempty"` + ExamplesField []interface{} `mapstructure:"examples" json:"examples,omitempty"` + ExclusiveMaximumField *float64 `mapstructure:"exclusiveMaximum" json:"exclusiveMaximum,omitempty"` + ExclusiveMinimumField *float64 `mapstructure:"exclusiveMinimum" json:"exclusiveMinimum,omitempty"` + FormatField string `mapstructure:"format" json:"format,omitempty"` + IDField string `mapstructure:"$id" json:"$id,omitempty"` + IfField *Schema `mapstructure:"if" json:"if,omitempty"` + ItemsField []asyncapi.Schema `mapstructure:"items" json:"items,omitempty"` + MaximumField *float64 `mapstructure:"maximum" json:"maximum,omitempty"` + MaxItemsField *float64 `mapstructure:"maxItems" json:"maxItems,omitempty"` + MaxLengthField *float64 `mapstructure:"maxLength" json:"maxLength,omitempty"` + MaxPropertiesField *float64 `mapstructure:"maxProperties" json:"maxProperties,omitempty"` + MinimumField *float64 `mapstructure:"minimum" json:"minimum,omitempty"` + MinItemsField *float64 `mapstructure:"minItems" json:"minItems,omitempty"` + MinLengthField *float64 `mapstructure:"minLength" json:"minLength,omitempty"` + MinPropertiesField *float64 `mapstructure:"minProperties" json:"minProperties,omitempty"` + MultipleOfField *float64 `mapstructure:"multipleOf" json:"multipleOf,omitempty"` + NotField *Schema `mapstructure:"not" json:"not,omitempty"` + OneOfField *Schema `mapstructure:"oneOf" json:"oneOf,omitempty"` + PatternField string `mapstructure:"pattern" json:"pattern,omitempty"` + PatternPropertiesField Schemas `mapstructure:"patternProperties" json:"patternProperties,omitempty"` + PropertiesField Schemas `mapstructure:"properties" json:"properties,omitempty"` + PropertyNamesField *Schema `mapstructure:"propertyNames" json:"propertyNames,omitempty"` + ReadOnlyField bool `mapstructure:"readOnly" json:"readOnly,omitempty"` + RequiredField []string `mapstructure:"required" json:"required,omitempty"` + ThenField *Schema `mapstructure:"then" json:"then,omitempty"` + TitleField string `mapstructure:"title" json:"title,omitempty"` + TypeField interface{} `mapstructure:"type" json:"type"` // string | []string + UniqueItemsField bool `mapstructure:"uniqueItems" json:"uniqueItems,omitempty"` + WriteOnlyField bool `mapstructure:"writeOnly" json:"writeOnly,omitempty"` // cached converted map[string]asyncapi.Schema from map[string]*Schema - propertiesFieldMap map[string]asyncapi.Schema - patternPropertiesFieldMap map[string]asyncapi.Schema - DefinitionsFieldMap map[string]asyncapi.Schema - DependenciesFieldMap map[string]asyncapi.Schema + propertiesFieldMap map[string]asyncapi.Schema `json:"-"` + patternPropertiesFieldMap map[string]asyncapi.Schema `json:"-"` + DefinitionsFieldMap map[string]asyncapi.Schema `json:"-"` + DependenciesFieldMap map[string]asyncapi.Schema `json:"-"` } func (s *Schema) AdditionalItems() asyncapi.Schema { @@ -617,8 +647,7 @@ func (s *Schema) ReadOnly() bool { return s.ReadOnlyField } -func (s *Schema) Required() string { - // TODO string[] +func (s *Schema) Required() []string { return s.RequiredField } @@ -759,7 +788,7 @@ func (d Describable) HasDescription() bool { } type Extendable struct { - Raw map[string]interface{} `mapstructure:",remain"` + Raw map[string]interface{} `mapstructure:",remain" json:"-"` } func (e Extendable) HasExtension(name string) bool { diff --git a/asyncapi/v2/validation.go b/asyncapi/v2/validation.go new file mode 100644 index 0000000..114904c --- /dev/null +++ b/asyncapi/v2/validation.go @@ -0,0 +1,50 @@ +package v2 + +import ( + "encoding/json" + "fmt" + + "github.com/asyncapi/event-gateway/asyncapi" + "github.com/asyncapi/event-gateway/proxy" + "github.com/xeipuuv/gojsonschema" +) + +func FromDocJSONSchemaMessageValidator(doc asyncapi.Document) (proxy.MessageValidator, error) { + channels := doc.ApplicationSubscribableChannels() + messageSchemas := make(map[string]gojsonschema.JSONLoader) + for _, c := range channels { + for _, o := range c.Operations() { + if !o.IsApplicationSubscribing() { + continue + } + + // Assuming there is only one message per operation as per Asyncapi 2.x.x. + // See https://github.com/asyncapi/event-gateway/issues/10 + if len(o.Messages()) > 1 { + return nil, fmt.Errorf("can not generate message validation for operation %s. Reason: the operation has more than one message and we can't correlate which one is it", o.ID()) + } + + if len(o.Messages()) == 0 { + return nil, fmt.Errorf("can not generate message validation for operation %s. Reason:. Operation has no message. This is totally unexpected", o.ID()) + } + + // Assuming there is only one message per operation and one operation of a particular type per Channel. + // See https://github.com/asyncapi/event-gateway/issues/10 + msg := o.Messages()[0] + + raw, err := json.Marshal(msg.Payload()) + if err != nil { + return nil, fmt.Errorf("error marshaling message payload for generating json schema for validation. Operation: %s, Message: %s", o.ID(), msg.Name()) + } + + messageSchemas[c.ID()] = gojsonschema.NewBytesLoader(raw) + } + } + + idProvider := func(msg *proxy.Message) string { + // messageSchemas map is indexed by Channel name, so we need to tell the validator. + return msg.Context.Channel + } + + return proxy.JSONSchemaMessageValidator(messageSchemas, idProvider) +} diff --git a/asyncapi/v2/validation_test.go b/asyncapi/v2/validation_test.go new file mode 100644 index 0000000..cefaa0f --- /dev/null +++ b/asyncapi/v2/validation_test.go @@ -0,0 +1,73 @@ +package v2 + +import ( + "testing" + + "github.com/asyncapi/event-gateway/proxy" + "github.com/stretchr/testify/assert" +) + +func TestFromDocJsonSchemaMessageValidator(t *testing.T) { + msg := &Message{ + PayloadField: &Schema{ + TypeField: "object", + PropertiesField: Schemas{ + "AnIntergerField": &Schema{ + Extendable: Extendable{}, + MaximumField: refFloat64(10), + MinimumField: refFloat64(3), + RequiredField: []string{"AnIntergerField"}, + TypeField: "number", + }, + }, + }, + } + channel := NewChannel("test") + channel.Subscribe = NewSubscribeOperation(msg) + + doc := Document{ + Extendable: Extendable{}, + ChannelsField: map[string]Channel{ + "test": *channel, + }, + } + + tests := []struct { + name string + valid bool + payload []byte + }{ + { + name: "Valid payload", + payload: []byte(`{"AnIntergerField": 5}`), + valid: true, + }, + { + name: "Invalid payload", + payload: []byte(`{"AnIntergerField": 1}`), + valid: false, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + validator, err := FromDocJSONSchemaMessageValidator(doc) + assert.NoError(t, err) + + msg := &proxy.Message{ + Context: proxy.MessageContext{ + Channel: "test", + }, + Value: test.payload, + } + validationErr, err := validator(msg) + assert.NoError(t, err) + + if test.valid { + assert.Nil(t, validationErr) + } else { + assert.NotNil(t, validationErr) + assert.False(t, validationErr.Result.Valid()) + } + }) + } +} diff --git a/config/config.go b/config/config.go index c4fd8d5..3921232 100644 --- a/config/config.go +++ b/config/config.go @@ -9,8 +9,21 @@ import ( // App holds the config for the whole application. type App struct { Debug bool - AsyncAPIDoc []byte `split_words:"true"` - KafkaProxy KafkaProxy `split_words:"true"` + AsyncAPIDoc []byte `split_words:"true"` + KafkaProxy *KafkaProxy `split_words:"true"` +} + +// Opt is a functional option used for configuring an App. +type Opt func(*App) + +// NewApp creates a App config with defaults. +func NewApp(opts ...Opt) *App { + c := &App{KafkaProxy: NewKafkaProxy()} + for _, opt := range opts { + opt(c) + } + + return c } // ProxyConfig creates a config struct for the Kafka Proxy. diff --git a/config/kafka.go b/config/kafka.go index cb2bea1..923d9cb 100644 --- a/config/kafka.go +++ b/config/kafka.go @@ -5,6 +5,8 @@ import ( "net" "strings" + "github.com/asyncapi/event-gateway/proxy" + "github.com/asyncapi/event-gateway/asyncapi" v2 "github.com/asyncapi/event-gateway/asyncapi/v2" "github.com/asyncapi/event-gateway/kafka" @@ -17,10 +19,31 @@ type KafkaProxy struct { BrokersMapping pipeSeparatedValues `split_words:"true"` BrokersDialMapping pipeSeparatedValues `split_words:"true"` ExtraFlags pipeSeparatedValues `split_words:"true"` + MessageValidation MessageValidation `split_words:"true"` +} + +// MessageValidation holds the config about message validation. +type MessageValidation struct { + Enabled bool + Notifier proxy.ValidationErrorNotifier +} + +// NotifyValidationErrorOnChan sets a channel as ValidationError notifier. +func NotifyValidationErrorOnChan(errChan chan *proxy.ValidationError) Opt { + return func(app *App) { + app.KafkaProxy.MessageValidation.Notifier = proxy.ValidationErrorToChanNotifier(errChan) + } +} + +// NewKafkaProxy creates a KafkaProxy with defaults. +func NewKafkaProxy() *KafkaProxy { + return &KafkaProxy{MessageValidation: MessageValidation{ + Enabled: true, + }} } // ProxyConfig creates a config struct for the Kafka Proxy based on a given AsyncAPI doc (if provided). -func (c *KafkaProxy) ProxyConfig(doc []byte, debug bool) (*kafka.ProxyConfig, error) { +func (c *KafkaProxy) ProxyConfig(doc []byte, debug bool, messageHandlers ...kafka.MessageHandler) (*kafka.ProxyConfig, error) { if len(doc) == 0 && len(c.BrokersMapping.Values) == 0 { return nil, errors.New("either AsyncAPIDoc or KafkaProxyBrokersMapping config should be provided") } @@ -42,6 +65,7 @@ func (c *KafkaProxy) ProxyConfig(doc []byte, debug bool) (*kafka.ProxyConfig, er } kafkaProxyConfig.Debug = debug + kafkaProxyConfig.MessageHandlers = append(kafkaProxyConfig.MessageHandlers, messageHandlers...) return kafkaProxyConfig, nil } @@ -52,18 +76,32 @@ func (c *KafkaProxy) configFromDoc(d []byte) (*kafka.ProxyConfig, error) { return nil, errors.Wrap(err, "error decoding AsyncAPI json doc to Document struct") } + var opts []kafka.Option + if c.MessageValidation.Enabled { + validator, err := v2.FromDocJSONSchemaMessageValidator(doc) + if err != nil { + return nil, errors.Wrap(err, "error creating message validator") + } + + if notifier := c.MessageValidation.Notifier; notifier != nil { + validator = proxy.NotifyOnValidationError(validator, notifier) + } + + opts = append(opts, kafka.WithMessageHandlers(validateMessageHandler(validator))) + } + if c.BrokerFromServer != "" { - return kafkaProxyConfigFromServer(c.BrokerFromServer, doc) + return kafkaProxyConfigFromServer(c.BrokerFromServer, doc, opts...) } - return kafkaProxyConfigFromAllServers(doc.Servers()) + return kafkaProxyConfigFromAllServers(doc.Servers(), opts...) } func isValidKafkaProtocol(s asyncapi.Server) bool { return strings.HasPrefix(s.Protocol(), "kafka") } -func kafkaProxyConfigFromAllServers(servers []asyncapi.Server) (*kafka.ProxyConfig, error) { +func kafkaProxyConfigFromAllServers(servers []asyncapi.Server, opts ...kafka.Option) (*kafka.ProxyConfig, error) { var brokersMapping []string var dialAddressMapping []string for _, s := range servers { @@ -82,10 +120,12 @@ func kafkaProxyConfigFromAllServers(servers []asyncapi.Server) (*kafka.ProxyConf } } - return kafka.NewProxyConfig(brokersMapping, kafka.WithDialAddressMapping(dialAddressMapping)) + opts = append(opts, kafka.WithDialAddressMapping(dialAddressMapping)) + + return kafka.NewProxyConfig(brokersMapping, opts...) } -func kafkaProxyConfigFromServer(name string, doc asyncapi.Document) (*kafka.ProxyConfig, error) { +func kafkaProxyConfigFromServer(name string, doc asyncapi.Document, opts ...kafka.Option) (*kafka.ProxyConfig, error) { s, ok := doc.Server(name) if !ok { return nil, fmt.Errorf("server %s not found in the provided AsyncAPI doc", name) @@ -101,10 +141,35 @@ func kafkaProxyConfigFromServer(name string, doc asyncapi.Document) (*kafka.Prox return nil, errors.Wrapf(err, "error getting port from broker %s. URL:%s", s.Name(), s.URL()) } - var opts []kafka.Option if dialMapping := s.Extension(asyncapi.ExtensionEventGatewayDialMapping); dialMapping != nil { opts = append(opts, kafka.WithDialAddressMapping([]string{fmt.Sprintf("%s,%s", s.URL(), dialMapping)})) } return kafka.NewProxyConfig([]string{fmt.Sprintf("%s,:%s", s.URL(), port)}, opts...) } + +func validateMessageHandler(validator proxy.MessageValidator) kafka.MessageHandler { + return func(msg kafka.Message) error { + pMsg := &proxy.Message{ + Context: proxy.MessageContext{ + Channel: msg.Context.Topic, + }, + Key: msg.Key, + Value: msg.Value, + } + + if len(msg.Headers) > 0 { + pMsg.Headers = make([]proxy.MessageHeader, len(msg.Headers)) + for i := 0; i < len(msg.Headers); i++ { + pMsg.Headers[i] = proxy.MessageHeader{ + Key: msg.Headers[i].Key, + Value: msg.Headers[i].Value, + } + } + } + + _, err := validator(pMsg) + + return err + } +} diff --git a/config/kafka_test.go b/config/kafka_test.go index d228f2f..a5b87eb 100644 --- a/config/kafka_test.go +++ b/config/kafka_test.go @@ -9,13 +9,12 @@ import ( "github.com/stretchr/testify/assert" ) -//nolint:funlen func TestKafkaProxy_ProxyConfig(t *testing.T) { tests := []struct { name string config *KafkaProxy doc []byte - expectedProxyConfig *kafka.ProxyConfig + expectedProxyConfig func(*testing.T, *kafka.ProxyConfig) *kafka.ProxyConfig expectedErr error }{ { @@ -23,8 +22,25 @@ func TestKafkaProxy_ProxyConfig(t *testing.T) { config: &KafkaProxy{ BrokerFromServer: "test", }, - expectedProxyConfig: &kafka.ProxyConfig{ - BrokersMapping: []string{"broker.mybrokers.org:9092,:9092"}, + expectedProxyConfig: func(_ *testing.T, _ *kafka.ProxyConfig) *kafka.ProxyConfig { + return &kafka.ProxyConfig{ + BrokersMapping: []string{"broker.mybrokers.org:9092,:9092"}, + } + }, + doc: []byte(`testdata/simple-kafka.yaml`), + }, + { + name: "Valid config. Only one broker from doc + enable message validation", + config: &KafkaProxy{ + BrokerFromServer: "test", + MessageValidation: MessageValidation{ + Enabled: true, + }, + }, + expectedProxyConfig: func(t *testing.T, c *kafka.ProxyConfig) *kafka.ProxyConfig { + assert.Equal(t, []string{"broker.mybrokers.org:9092,:9092"}, c.BrokersMapping) + assert.Len(t, c.MessageHandlers, 1) + return nil }, doc: []byte(`testdata/simple-kafka.yaml`), }, @@ -33,8 +49,10 @@ func TestKafkaProxy_ProxyConfig(t *testing.T) { config: &KafkaProxy{ BrokersMapping: pipeSeparatedValues{Values: []string{"broker.mybrokers.org:9092,:9092"}}, }, - expectedProxyConfig: &kafka.ProxyConfig{ - BrokersMapping: []string{"broker.mybrokers.org:9092,:9092"}, + expectedProxyConfig: func(_ *testing.T, _ *kafka.ProxyConfig) *kafka.ProxyConfig { + return &kafka.ProxyConfig{ + BrokersMapping: []string{"broker.mybrokers.org:9092,:9092"}, + } }, }, { @@ -43,9 +61,11 @@ func TestKafkaProxy_ProxyConfig(t *testing.T) { BrokersMapping: pipeSeparatedValues{Values: []string{"broker.mybrokers.org:9092,:9092"}}, BrokersDialMapping: pipeSeparatedValues{Values: []string{"broker.mybrokers.org:9092,192.168.1.10:9092"}}, }, - expectedProxyConfig: &kafka.ProxyConfig{ - BrokersMapping: []string{"broker.mybrokers.org:9092,:9092"}, - DialAddressMapping: []string{"broker.mybrokers.org:9092,192.168.1.10:9092"}, + expectedProxyConfig: func(_ *testing.T, _ *kafka.ProxyConfig) *kafka.ProxyConfig { + return &kafka.ProxyConfig{ + BrokersMapping: []string{"broker.mybrokers.org:9092,:9092"}, + DialAddressMapping: []string{"broker.mybrokers.org:9092,192.168.1.10:9092"}, + } }, }, { @@ -78,7 +98,9 @@ func TestKafkaProxy_ProxyConfig(t *testing.T) { } if test.expectedProxyConfig != nil { - assert.EqualValues(t, test.expectedProxyConfig, proxyConfig) + if expectedConf := test.expectedProxyConfig(t, proxyConfig); expectedConf != nil { + assert.EqualValues(t, expectedConf, proxyConfig) + } } }) } diff --git a/go.mod b/go.mod index 1c098d5..3fe6f39 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/stretchr/testify v1.7.0 github.com/xdg/scram v1.0.3 // indirect github.com/xdg/stringprep v1.0.3 // indirect + github.com/xeipuuv/gojsonschema v1.2.0 golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b // indirect golang.org/x/net v0.0.0-20210427231257-85d9c07bbe3a // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect diff --git a/go.sum b/go.sum index 2124216..aa3cf79 100644 --- a/go.sum +++ b/go.sum @@ -171,12 +171,14 @@ github.com/xdg/scram v1.0.3/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49 github.com/xdg/stringprep v1.0.0/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y= github.com/xdg/stringprep v1.0.3 h1:cmL5Enob4W83ti/ZHuZLuKD/xqJfus4fVPwE+/BDm+4= github.com/xdg/stringprep v1.0.3/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonpointer v0.0.0-20190809123943-df4f5c81cb3b h1:6cLsL+2FW6dRAdl5iMtHgRogVCff0QpRi9653YmdcJA= github.com/xeipuuv/gojsonpointer v0.0.0-20190809123943-df4f5c81cb3b/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= -github.com/xeipuuv/gojsonschema v1.1.0 h1:ngVtJC9TY/lg0AA/1k48FYhBrhRoFlEmWzsehpNAaZg= github.com/xeipuuv/gojsonschema v1.1.0/go.mod h1:5yf86TLmAcydyeJq5YvxkGPE2fm/u4myDekKRoLuqhs= +github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= +github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= diff --git a/kafka/config.go b/kafka/config.go index 4347d4e..0c81e81 100644 --- a/kafka/config.go +++ b/kafka/config.go @@ -6,6 +6,8 @@ import ( "regexp" "strings" + "github.com/sirupsen/logrus" + "github.com/pkg/errors" ) @@ -16,11 +18,19 @@ type ProxyConfig struct { BrokersMapping []string DialAddressMapping []string ExtraConfig []string + MessageHandlers []MessageHandler Debug bool } type Option func(*ProxyConfig) error +func WithMessageHandlers(messageHandlers ...MessageHandler) Option { + return func(c *ProxyConfig) error { + c.MessageHandlers = append(c.MessageHandlers, messageHandlers...) + return nil + } +} + // WithDebug enables debug. func WithDebug() Option { return func(c *ProxyConfig) error { @@ -85,6 +95,10 @@ func (c *ProxyConfig) Validate() error { } } + if len(c.MessageHandlers) == 0 { + logrus.Warn("There are no message handlers configured") + } + return nil } diff --git a/kafka/proxy.go b/kafka/proxy.go index be6b6d0..8286a80 100644 --- a/kafka/proxy.go +++ b/kafka/proxy.go @@ -15,6 +15,24 @@ import ( "github.com/sirupsen/logrus" ) +// Context is the context that surrounds a Message. +type Context struct { + Topic string +} + +// Message is a message flowing through a Kafka topic. +type Message struct { + Context Context + Key []byte + Value []byte + Headers []*protocol.RecordHeader +} + +// MessageHandler handles a Kafka message. +// If error is returned, kafka request will fail. +// Note: Message manipulation is not allowed. +type MessageHandler func(Message) error + // NewProxy creates a new Kafka Proxy based on a given configuration. func NewProxy(c *ProxyConfig) (proxy.Proxy, error) { if c == nil { @@ -26,7 +44,7 @@ func NewProxy(c *ProxyConfig) (proxy.Proxy, error) { } // Yeah, not a good practice at all but I guess it's fine for now. - kafkaproxy.ActualDefaultRequestHandler.RequestKeyHandlers.Set(protocol.RequestAPIKeyProduce, &requestKeyHandler{}) + kafkaproxy.ActualDefaultRequestHandler.RequestKeyHandlers.Set(protocol.RequestAPIKeyProduce, NewProduceRequestHandler(c.MessageHandlers...)) if c.BrokersMapping == nil { return nil, errors.New("Brokers mapping is required") @@ -54,13 +72,28 @@ func NewProxy(c *ProxyConfig) (proxy.Proxy, error) { }, nil } -type requestKeyHandler struct{} +// NewProduceRequestHandler creates a new request key handler for the Produce Request. +func NewProduceRequestHandler(msgHandlers ...MessageHandler) kafkaproxy.KeyHandler { + return &produceRequestHandler{ + msgHandlers: msgHandlers, + } +} + +type produceRequestHandler struct { + msgHandlers []MessageHandler +} + +func (h *produceRequestHandler) Handle(requestKeyVersion *kafkaprotocol.RequestKeyVersion, src io.Reader, ctx *kafkaproxy.RequestsLoopContext, bufferRead *bytes.Buffer) (shouldReply bool, err error) { + if len(h.msgHandlers) == 0 { + logrus.Infoln("No message handlers were set. Skipping produceRequestHandler") + return true, nil + } -func (r *requestKeyHandler) Handle(requestKeyVersion *kafkaprotocol.RequestKeyVersion, src io.Reader, ctx *kafkaproxy.RequestsLoopContext, bufferRead *bytes.Buffer) (shouldReply bool, err error) { if requestKeyVersion.ApiKey != protocol.RequestAPIKeyProduce { return true, nil } + // TODO error handling should be responsibility of an error handler instead of being just logged. shouldReply, err = kafkaproxy.DefaultProduceKeyHandlerFunc(requestKeyVersion, src, ctx, bufferRead) if err != nil { return @@ -70,41 +103,58 @@ func (r *requestKeyHandler) Handle(requestKeyVersion *kafkaprotocol.RequestKeyVe if _, err = io.ReadFull(io.TeeReader(src, bufferRead), msg); err != nil { return } + var req protocol.ProduceRequest if err = protocol.VersionedDecode(msg, &req, requestKeyVersion.ApiVersion); err != nil { - logrus.Errorln(errors.Wrap(err, "error decoding ProduceRequest")) - // TODO notify error to a given notifier - - // Do not return an error but log it. + logrus.WithError(err).Error("error decoding ProduceRequest") return shouldReply, nil } - for _, r := range req.Records { + msgs := h.extractMessages(req) + if len(msgs) == 0 { + logrus.Error("The produce request has no messages") + return + } + + for _, m := range msgs { + for _, h := range h.msgHandlers { + if err := h(m); err != nil { + logrus.WithError(err).Error("error handling message") + return shouldReply, nil + } + } + } + + return shouldReply, nil +} + +func (h *produceRequestHandler) extractMessages(req protocol.ProduceRequest) []Message { + var msgs []Message + for topic, r := range req.Records { for _, s := range r { if s.RecordBatch != nil { for _, r := range s.RecordBatch.Records { - if !isValid(r.Value) { - logrus.Debugln("Message is not valid") - } else { - logrus.Debugln("Message is valid") - } + msgs = append(msgs, Message{ + Context: Context{ + Topic: topic, + }, + Value: r.Value, + Headers: r.Headers, + }) } } if s.MsgSet != nil { for _, mb := range s.MsgSet.Messages { - if !isValid(mb.Msg.Value) { - logrus.Debugln("Message is not valid") - } else { - logrus.Debugln("Message is valid") - } + msgs = append(msgs, Message{ + Context: Context{ + Topic: topic, + }, + Value: mb.Msg.Value, + Key: mb.Msg.Key, + }) } } } } - - return shouldReply, nil -} - -func isValid(msg []byte) bool { - return string(msg) != "invalid message" + return msgs } diff --git a/kafka/proxy_test.go b/kafka/proxy_test.go index f5f33e8..35c7855 100644 --- a/kafka/proxy_test.go +++ b/kafka/proxy_test.go @@ -6,12 +6,15 @@ import ( "hash/crc32" "testing" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + logrustest "github.com/sirupsen/logrus/hooks/test" + "github.com/asyncapi/event-gateway/proxy" kafkaproxy "github.com/grepplabs/kafka-proxy/proxy" kafkaprotocol "github.com/grepplabs/kafka-proxy/proxy/protocol" "github.com/pkg/errors" - "github.com/sirupsen/logrus" - logrustest "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/assert" ) @@ -51,13 +54,14 @@ func TestNewKafka(t *testing.T) { } } -func TestRequestKeyHandler_Handle(t *testing.T) { +func TestProduceRequestHandler_Handle(t *testing.T) { tests := []struct { name string request []byte shouldReply bool apiKey int16 shouldSkipRequest bool + expectedLoggedErr error }{ { name: "Valid message", @@ -65,9 +69,10 @@ func TestRequestKeyHandler_Handle(t *testing.T) { shouldReply: true, }, { - name: "Invalid message", - request: generateProduceRequestV8("invalid message"), - shouldReply: true, + name: "Invalid message", + request: generateProduceRequestV8("invalid message"), + shouldReply: true, + expectedLoggedErr: errors.New("message is invalid"), }, { name: "Other Requests (different than Produce type) are skipped", @@ -86,9 +91,17 @@ func TestRequestKeyHandler_Handle(t *testing.T) { Length: int32(len(test.request) + 4), // 4 bytes are ApiKey + Version located in all request headers (already read by the time of validating the msg). } - readBytes := bytes.NewBuffer(nil) - var h requestKeyHandler + simpleMessageValidationHandler := func(m Message) error { + if string(m.Value) == "invalid message" { + return errors.New("message is invalid") + } + + return nil + } + h := NewProduceRequestHandler(simpleMessageValidationHandler) + + readBytes := bytes.NewBuffer(nil) shouldReply, err := h.Handle(kv, bytes.NewReader(test.request), &kafkaproxy.RequestsLoopContext{}, readBytes) assert.NoError(t, err) assert.Equal(t, test.shouldReply, shouldReply) @@ -99,14 +112,20 @@ func TestRequestKeyHandler_Handle(t *testing.T) { assert.Equal(t, readBytes.Len(), len(test.request)) } - for _, l := range log.AllEntries() { - assert.NotEqualf(t, l.Level, logrus.ErrorLevel, "%q logged error unexpected", l.Message) // We don't have a notification mechanism for errors yet + if test.expectedLoggedErr != nil { + entry := log.LastEntry() + require.NotEmpty(t, entry) + require.Contains(t, entry.Data, logrus.ErrorKey) + assert.EqualError(t, entry.Data[logrus.ErrorKey].(error), test.expectedLoggedErr.Error()) + } else { + for _, l := range log.AllEntries() { + assert.NotEqualf(t, l.Level, logrus.ErrorLevel, "%q logged error unexpected", l.Message) // We don't have a notification mechanism for errors yet + } } }) } } -//nolint:funlen func generateProduceRequestV8(payload string) []byte { // Note: Taking V8 as random version. buf := bytes.NewBuffer(nil) @@ -140,8 +159,7 @@ func generateProduceRequestV8(payload string) []byte { // batch len batchLen := make([]byte, 4) - binary.BigEndian.Uint32(requestSize) - binary.BigEndian.PutUint32(batchLen, requestSizeInt+uint32(len(baseOffset)+len(payload))) + binary.BigEndian.PutUint32(batchLen, requestSizeInt-uint32(len(baseOffset)+len(batchLen))) buf.Write(batchLen) // partition leader epoch: 255, 255, 255, 255 @@ -165,8 +183,10 @@ func generateProduceRequestV8(payload string) []byte { buf.Write([]byte{0, 0, 0, 0, 0, 0, 0, 0, 1, 122, 129, 58, 129, 47, 0, 0, 1, 122, 129, 58, 129, 47, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 1}) // record len - recordLenInt := 27 + len(payload) - buf.WriteByte(byte(recordLenInt)) + // attributes + timestamp delta + offset + key + payload len field + actual payload len + recordLen := make([]byte, 1) + binary.PutVarint(recordLen, int64(4+1+len(payload)+1)) + buf.Write(recordLen) // attributes: 0 // timestamp delta: 0 @@ -187,7 +207,7 @@ func generateProduceRequestV8(payload string) []byte { table := crc32.MakeTable(crc32.Castagnoli) crc32Calculator := crc32.New(table) - crc32Calculator.Write(buf.Bytes()[crc32ReservationStart+4:]) + _, _ = crc32Calculator.Write(buf.Bytes()[crc32ReservationStart+4:]) hash := crc32Calculator.Sum(make([]byte, 0)) for i := 0; i < len(hash); i++ { diff --git a/main.go b/main.go index c9c7c58..eaaaaf8 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,12 @@ package main import ( "context" + "os" + "os/signal" + "syscall" + "time" + + "github.com/asyncapi/event-gateway/proxy" "github.com/asyncapi/event-gateway/config" "github.com/asyncapi/event-gateway/kafka" @@ -10,8 +16,10 @@ import ( ) func main() { - var c config.App - if err := envconfig.Process("eventgateway", &c); err != nil { + validationErrChan := make(chan *proxy.ValidationError) + c := config.NewApp(config.NotifyValidationErrorOnChan(validationErrChan)) + + if err := envconfig.Process("eventgateway", c); err != nil { logrus.WithError(err).Fatal() } @@ -29,7 +37,41 @@ func main() { logrus.WithError(err).Fatal() } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + handleInterruptions(cancel) + + // At this moment, we do nothing else. + go logValidationErrors(ctx, validationErrChan) + if err := kafkaProxy(context.Background()); err != nil { logrus.WithError(err).Fatal() } } + +func logValidationErrors(ctx context.Context, validationErrChan chan *proxy.ValidationError) { + for { + select { + case validationErr, ok := <-validationErrChan: + if !ok { + return + } + + logrus.WithField("validation_errors", validationErr.String()).Errorf("error validating message") + case <-ctx.Done(): + return + } + } +} + +func handleInterruptions(cancel context.CancelFunc) { + c := make(chan os.Signal, 2) + signal.Notify(c, os.Interrupt, syscall.SIGTERM, syscall.SIGINT) + go func() { + s := <-c + logrus.WithField("signal", s).Info("Stopping AsyncAPI Event-Gateway...") + cancel() + time.Sleep(time.Second) + os.Exit(0) + }() +} diff --git a/proxy/validation.go b/proxy/validation.go new file mode 100644 index 0000000..5bdddbe --- /dev/null +++ b/proxy/validation.go @@ -0,0 +1,106 @@ +package proxy + +import ( + "fmt" + "strings" + + "github.com/pkg/errors" + "github.com/xeipuuv/gojsonschema" +) + +// Message represents a message flowing through the wire. For example, a Kafka message. +type Message struct { + Context MessageContext + Key []byte + Value []byte + Headers []MessageHeader +} + +// MessageContext contains information about the context that surrounds a message. +type MessageContext struct { + Channel string +} + +// MessageHeader represents a header of a message, if there are any. +type MessageHeader struct { + Key []byte + Value []byte +} + +// ValidationError represents a message validation error. +type ValidationError struct { + Msg *Message + Result *gojsonschema.Result +} + +func (v ValidationError) String() string { + errs := make([]string, len(v.Result.Errors())) + for i, err := range v.Result.Errors() { + errs[i] = err.String() + } + + return strings.Join(errs, " | ") +} + +// ValidationErrorNotifier notifies whenever a ValidationError happens. +type ValidationErrorNotifier func(validationError *ValidationError) error + +// ValidationErrorToChanNotifier notifies to a given chan when a ValidationError happens. +func ValidationErrorToChanNotifier(errChan chan *ValidationError) ValidationErrorNotifier { + return func(validationError *ValidationError) error { + // TODO Blocking or non blocking? Shall we just fire and forget via goroutine instead? + errChan <- validationError + + return nil + } +} + +// MessageValidator validates a message. +// Returns a boolean indicating if the message is valid, and an error if something went wrong. +type MessageValidator func(*Message) (*ValidationError, error) + +// NotifyOnValidationError is a MessageValidator that notifies ValidationError from a given MessageValidator output to the given channel. +func NotifyOnValidationError(validator MessageValidator, notifier ValidationErrorNotifier) MessageValidator { + return func(msg *Message) (*ValidationError, error) { + validationErr, err := validator(msg) + if err != nil { + return nil, err + } + + if validationErr != nil { + if err := notifier(validationErr); err != nil { + return nil, errors.Wrap(err, "error notifying validation error") + } + + return validationErr, nil + } + + return nil, nil + } +} + +// JSONSchemaMessageValidator validates a message payload based on a map of Json Schema, where the key can be any identifier (depends on who implements it). +// For example, the identifier can be it's channel name, message ID, etc. +func JSONSchemaMessageValidator(messageSchemas map[string]gojsonschema.JSONLoader, idProvider func(msg *Message) string) (MessageValidator, error) { + return func(msg *Message) (*ValidationError, error) { + msgID := idProvider(msg) + msgSchema, ok := messageSchemas[msgID] + if !ok { + return nil, nil + } + + result, err := gojsonschema.Validate(msgSchema, gojsonschema.NewBytesLoader(msg.Value)) + if err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("error validating JSON Schema for message %s", msgID)) + } + + if !result.Valid() { + return &ValidationError{ + Msg: msg, + Result: result, + }, nil + } + + return nil, nil + }, nil +} diff --git a/proxy/validation_test.go b/proxy/validation_test.go new file mode 100644 index 0000000..74d42f5 --- /dev/null +++ b/proxy/validation_test.go @@ -0,0 +1,128 @@ +package proxy + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/xeipuuv/gojsonschema" +) + +func TestValidationError_String(t *testing.T) { + validationErr := generateTestValidationError(nil) + assert.Equal(t, "AnIntegerField: Invalid type. Expected: integer, given: string | AStringField: Invalid type. Expected: string, given: integer", validationErr.String()) +} + +func TestNotifyOnValidationError(t *testing.T) { + expectedMessage := generateTestMessage() + validator := func(msg *Message) (*ValidationError, error) { + assert.Equal(t, expectedMessage, msg) + return generateTestValidationError(msg), nil + } + + var notified bool + notifier := func(validationError *ValidationError) error { + notified = true + return nil + } + + validationErr, err := NotifyOnValidationError(validator, notifier)(expectedMessage) + assert.NoError(t, err) + assert.False(t, validationErr.Result.Valid()) + assert.True(t, notified) +} + +func generateTestMessage() *Message { + return &Message{ + Context: MessageContext{Channel: "test"}, + Value: []byte(`Hello World!`), + } +} + +func generateTestValidationError(msg *Message) *ValidationError { + validationErr := &ValidationError{ + Msg: msg, + Result: &gojsonschema.Result{}, + } + + addTestErrors(validationErr) + return validationErr +} + +func addTestErrors(validationErr *ValidationError) { + badTypeErr := &gojsonschema.InvalidTypeError{} + badTypeErr.SetContext(gojsonschema.NewJsonContext("AnIntegerField", nil)) + badTypeErr.SetDetails(gojsonschema.ErrorDetails{ + "expected": gojsonschema.TYPE_INTEGER, + "given": gojsonschema.TYPE_STRING, + }) + badTypeErr.SetDescriptionFormat(gojsonschema.Locale.InvalidType()) + validationErr.Result.AddError(badTypeErr, badTypeErr.Details()) + + badTypeErr2 := &gojsonschema.InvalidTypeError{} + badTypeErr2.SetContext(gojsonschema.NewJsonContext("AStringField", nil)) + badTypeErr2.SetDetails(gojsonschema.ErrorDetails{ + "expected": gojsonschema.TYPE_STRING, + "given": gojsonschema.TYPE_INTEGER, + }) + badTypeErr2.SetDescriptionFormat(gojsonschema.Locale.InvalidType()) + validationErr.Result.AddError(badTypeErr2, badTypeErr2.Details()) +} + +func TestJsonSchemaMessageValidator(t *testing.T) { + schema := `{ + "properties":{ + "command":{ + "description":"Whether to turn on or off the light.", + "enum":[ + "on", + "off" + ], + "type":"string" + } + }, + "type":"object" +}` + + tests := []struct { + name string + valid bool + payload []byte + }{ + { + name: "Valid payload", + payload: []byte(`{"command": "on"}`), + valid: true, + }, + { + name: "Invalid payload", + payload: []byte(`{"command": 123123}`), + valid: false, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + expectedMsg := generateTestMessage() + expectedMsg.Value = test.payload + + idToSchemaMap := map[string]gojsonschema.JSONLoader{ + expectedMsg.Context.Channel: gojsonschema.NewStringLoader(schema), + } + + validator, err := JSONSchemaMessageValidator(idToSchemaMap, func(msg *Message) string { + assert.Equal(t, expectedMsg, msg) + return msg.Context.Channel + }) + assert.NoError(t, err) + + validationErr, err := validator(expectedMsg) + assert.NoError(t, err) + + if test.valid { + assert.Nil(t, validationErr) + } else { + assert.NotNil(t, validationErr) + assert.False(t, validationErr.Result.Valid()) + } + }) + } +} From bfca5fb9a906b2f606059912aba6a31f30239b14 Mon Sep 17 00:00:00 2001 From: Sergio Moya <1083296+smoya@users.noreply.github.com> Date: Fri, 23 Jul 2021 14:16:44 +0200 Subject: [PATCH 2/4] Type field in schema is not required. --- asyncapi/v2/v2.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncapi/v2/v2.go b/asyncapi/v2/v2.go index cd41d14..f542f50 100644 --- a/asyncapi/v2/v2.go +++ b/asyncapi/v2/v2.go @@ -449,7 +449,7 @@ type Schema struct { RequiredField []string `mapstructure:"required" json:"required,omitempty"` ThenField *Schema `mapstructure:"then" json:"then,omitempty"` TitleField string `mapstructure:"title" json:"title,omitempty"` - TypeField interface{} `mapstructure:"type" json:"type"` // string | []string + TypeField interface{} `mapstructure:"type" json:"type,omitempty"` // string | []string UniqueItemsField bool `mapstructure:"uniqueItems" json:"uniqueItems,omitempty"` WriteOnlyField bool `mapstructure:"writeOnly" json:"writeOnly,omitempty"` From 49da04a7aed410a5d83242aece7899b5173c251d Mon Sep 17 00:00:00 2001 From: Sergio Moya <1083296+smoya@users.noreply.github.com> Date: Fri, 23 Jul 2021 18:48:32 +0200 Subject: [PATCH 3/4] operation.Messages is now aware of oneOf field --- asyncapi/document.go | 14 ++++-- asyncapi/v2/v2.go | 78 ++++++++++++++++++++++------- asyncapi/v2/v2_test.go | 84 +++++++++++++++++++++++++++++++ asyncapi/v2/validation.go | 32 ++++++++---- asyncapi/v2/validation_test.go | 91 ++++++++++++++++++++++++---------- 5 files changed, 240 insertions(+), 59 deletions(-) create mode 100644 asyncapi/v2/v2_test.go diff --git a/asyncapi/document.go b/asyncapi/document.go index ea05418..63e0b0a 100644 --- a/asyncapi/document.go +++ b/asyncapi/document.go @@ -80,14 +80,22 @@ type Message interface { Payload() Schema } +// FalsifiableSchema is a variadic type used for some Schema fields. +// For example, additionalProperties value can be either `false` or a Schema. +type FalsifiableSchema interface { + IsFalse() bool + IsSchema() bool + Schema() Schema +} + // Schema is an object that allows the definition of input and output data types. // These types can be objects, but also primitives and arrays. // This object is a superset of the JSON Schema Specification Draft 07. type Schema interface { Extendable ID() string - AdditionalItems() Schema - AdditionalProperties() Schema // TODO (boolean | Schema) + AdditionalItems() FalsifiableSchema + AdditionalProperties() FalsifiableSchema // TODO (boolean | Schema) AllOf() []Schema AnyOf() []Schema CircularProps() []string @@ -121,7 +129,7 @@ type Schema interface { MinProperties() *float64 MultipleOf() *float64 Not() Schema - OneOf() Schema + OneOf() []Schema Pattern() string PatternProperties() map[string]Schema Properties() map[string]Schema diff --git a/asyncapi/v2/v2.go b/asyncapi/v2/v2.go index f542f50..4890c87 100644 --- a/asyncapi/v2/v2.go +++ b/asyncapi/v2/v2.go @@ -327,8 +327,21 @@ func (o Operation) IsClientSubscribing() bool { func (o Operation) Messages() []asyncapi.Message { if o.MessageField != nil { - return []asyncapi.Message{o.MessageField} + if len(o.MessageField.Payload().OneOf()) == 0 { + return []asyncapi.Message{o.MessageField} + } + + var msgs []asyncapi.Message + for _, payload := range o.MessageField.Payload().OneOf() { + p := payload.(*Schema) + msgs = append(msgs, Message{ + PayloadField: p, + }) + } + + return msgs } + return nil } @@ -405,10 +418,40 @@ func (s Schemas) ToInterface(dst map[string]asyncapi.Schema) map[string]asyncapi return dst } +type FalsifiableSchema struct { + val interface{} +} + +// NewFalsifiableSchema creates a new FalsifiableSchema. +func NewFalsifiableSchema(val interface{}) *FalsifiableSchema { + if val == nil { + return nil + } + return &FalsifiableSchema{val: val} +} + +func (f FalsifiableSchema) IsFalse() bool { + _, ok := f.val.(bool) + return ok +} + +func (f FalsifiableSchema) IsSchema() bool { + _, ok := f.val.(*Schema) + return ok +} + +func (f FalsifiableSchema) Schema() asyncapi.Schema { + if f.IsSchema() { + return f.val.(*Schema) + } + + return nil +} + type Schema struct { Extendable - AdditionalItemsField *Schema `mapstructure:"additionalItems" json:"additionalItems,omitempty"` - AdditionalPropertiesField *Schema `mapstructure:"additionalProperties" json:"additionalProperties,omitempty"` + AdditionalItemsField interface{} `mapstructure:"additionalItems" json:"additionalItems,omitempty"` + AdditionalPropertiesField interface{} `mapstructure:"additionalProperties" json:"additionalProperties,omitempty"` // Schema || false AllOfField []asyncapi.Schema `mapstructure:"allOf" json:"allOf,omitempty"` AnyOfField []asyncapi.Schema `mapstructure:"anyOf" json:"anyOf,omitempty"` ConstField interface{} `mapstructure:"const" json:"const,omitempty"` @@ -440,7 +483,7 @@ type Schema struct { MinPropertiesField *float64 `mapstructure:"minProperties" json:"minProperties,omitempty"` MultipleOfField *float64 `mapstructure:"multipleOf" json:"multipleOf,omitempty"` NotField *Schema `mapstructure:"not" json:"not,omitempty"` - OneOfField *Schema `mapstructure:"oneOf" json:"oneOf,omitempty"` + OneOfField []asyncapi.Schema `mapstructure:"oneOf" json:"oneOf,omitempty"` PatternField string `mapstructure:"pattern" json:"pattern,omitempty"` PatternPropertiesField Schemas `mapstructure:"patternProperties" json:"patternProperties,omitempty"` PropertiesField Schemas `mapstructure:"properties" json:"properties,omitempty"` @@ -454,18 +497,18 @@ type Schema struct { WriteOnlyField bool `mapstructure:"writeOnly" json:"writeOnly,omitempty"` // cached converted map[string]asyncapi.Schema from map[string]*Schema - propertiesFieldMap map[string]asyncapi.Schema `json:"-"` - patternPropertiesFieldMap map[string]asyncapi.Schema `json:"-"` - DefinitionsFieldMap map[string]asyncapi.Schema `json:"-"` - DependenciesFieldMap map[string]asyncapi.Schema `json:"-"` + propertiesFieldMap map[string]asyncapi.Schema + patternPropertiesFieldMap map[string]asyncapi.Schema + definitionsFieldMap map[string]asyncapi.Schema + dependenciesFieldMap map[string]asyncapi.Schema } -func (s *Schema) AdditionalItems() asyncapi.Schema { - return s.AdditionalItemsField +func (s *Schema) AdditionalItems() asyncapi.FalsifiableSchema { + return NewFalsifiableSchema(s.AdditionalItemsField) } -func (s *Schema) AdditionalProperties() asyncapi.Schema { - return s.AdditionalPropertiesField +func (s *Schema) AdditionalProperties() asyncapi.FalsifiableSchema { + return NewFalsifiableSchema(s.AdditionalPropertiesField) } func (s *Schema) AllOf() []asyncapi.Schema { @@ -505,14 +548,14 @@ func (s *Schema) Default() interface{} { } func (s *Schema) Definitions() map[string]asyncapi.Schema { - s.DefinitionsFieldMap = s.DefinitionsField.ToInterface(s.DefinitionsFieldMap) - return s.DefinitionsFieldMap + s.definitionsFieldMap = s.DefinitionsField.ToInterface(s.definitionsFieldMap) + return s.definitionsFieldMap } func (s *Schema) Dependencies() map[string]asyncapi.Schema { // TODO Map[string, Schema|string[]] - s.DependenciesFieldMap = s.DependenciesField.ToInterface(s.DependenciesFieldMap) - return s.DependenciesFieldMap + s.dependenciesFieldMap = s.DependenciesField.ToInterface(s.dependenciesFieldMap) + return s.dependenciesFieldMap } func (s *Schema) Deprecated() bool { @@ -616,8 +659,7 @@ func (s *Schema) Not() asyncapi.Schema { return s.NotField } -func (s *Schema) OneOf() asyncapi.Schema { - // TODO Schema[] +func (s *Schema) OneOf() []asyncapi.Schema { return s.OneOfField } diff --git a/asyncapi/v2/v2_test.go b/asyncapi/v2/v2_test.go new file mode 100644 index 0000000..cadb599 --- /dev/null +++ b/asyncapi/v2/v2_test.go @@ -0,0 +1,84 @@ +package v2 + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/asyncapi/event-gateway/asyncapi" +) + +func TestMessage_OneOfPayload_MultipleMessages(t *testing.T) { + oneOfSchemas := []asyncapi.Schema{ + &Schema{ + PropertiesField: map[string]*Schema{ + "schemaOnefieldOne": {}, + "schemaOnefieldTwo": {}, + }, + }, + &Schema{ + PropertiesField: map[string]*Schema{ + "schemaTwofieldOne": {}, + "schemaTWofieldTwo": {}, + }, + }, + } + msg := &Message{ + PayloadField: &Schema{ + OneOfField: oneOfSchemas, + }, + } + + msgs := NewSubscribeOperation(msg).Messages() + require.Len(t, msgs, 2) + assert.Equal(t, oneOfSchemas[0], msgs[0].Payload()) + assert.Equal(t, oneOfSchemas[1], msgs[1].Payload()) +} + +func TestMessage_PlainPayload_OneMessage(t *testing.T) { + msg := &Message{ + PayloadField: &Schema{ + PropertiesField: map[string]*Schema{ + "schemaFieldOne": {}, + "schemaFieldfieldTwo": {}, + }, + }, + } + + msgs := NewSubscribeOperation(msg).Messages() + require.Len(t, msgs, 1) + assert.Equal(t, msg, msgs[0]) +} + +func TestSchema_AdditionalProperties(t *testing.T) { + schema := &Schema{} + assert.Nil(t, schema.AdditionalProperties()) + + schema = &Schema{AdditionalPropertiesField: false} + assert.True(t, schema.AdditionalProperties().IsFalse()) + assert.False(t, schema.AdditionalProperties().IsSchema()) + assert.Nil(t, schema.AdditionalProperties().Schema()) + + field := &Schema{TypeField: "string"} + schema = &Schema{AdditionalPropertiesField: field} + assert.False(t, schema.AdditionalProperties().IsFalse()) + assert.True(t, schema.AdditionalProperties().IsSchema()) + assert.Equal(t, field, schema.AdditionalProperties().Schema()) +} + +func TestSchema_AdditionalItems(t *testing.T) { + schema := &Schema{} + assert.Nil(t, schema.AdditionalItems()) + + schema = &Schema{AdditionalItemsField: false} + assert.True(t, schema.AdditionalItems().IsFalse()) + assert.False(t, schema.AdditionalItems().IsSchema()) + assert.Nil(t, schema.AdditionalItems().Schema()) + + field := &Schema{TypeField: "string"} + schema = &Schema{AdditionalItemsField: field} + assert.False(t, schema.AdditionalItems().IsFalse()) + assert.True(t, schema.AdditionalItems().IsSchema()) + assert.Equal(t, field, schema.AdditionalItems().Schema()) +} diff --git a/asyncapi/v2/validation.go b/asyncapi/v2/validation.go index 114904c..fcb00b8 100644 --- a/asyncapi/v2/validation.go +++ b/asyncapi/v2/validation.go @@ -3,6 +3,7 @@ package v2 import ( "encoding/json" "fmt" + "strings" "github.com/asyncapi/event-gateway/asyncapi" "github.com/asyncapi/event-gateway/proxy" @@ -18,23 +19,32 @@ func FromDocJSONSchemaMessageValidator(doc asyncapi.Document) (proxy.MessageVali continue } - // Assuming there is only one message per operation as per Asyncapi 2.x.x. - // See https://github.com/asyncapi/event-gateway/issues/10 - if len(o.Messages()) > 1 { - return nil, fmt.Errorf("can not generate message validation for operation %s. Reason: the operation has more than one message and we can't correlate which one is it", o.ID()) - } - if len(o.Messages()) == 0 { return nil, fmt.Errorf("can not generate message validation for operation %s. Reason:. Operation has no message. This is totally unexpected", o.ID()) } - // Assuming there is only one message per operation and one operation of a particular type per Channel. - // See https://github.com/asyncapi/event-gateway/issues/10 - msg := o.Messages()[0] + var payload asyncapi.Schema + var messageNames string + if len(o.Messages()) > 1 { + // Meaning message payload is a Schema containing several payloads as `oneOf`. + // Generating back just one Schema adding all payloads to oneOf field. + msgs := o.Messages() + oneOfSchemas := make([]asyncapi.Schema, len(msgs)) + names := make([]string, len(msgs)) + for i, msg := range msgs { + oneOfSchemas[i] = msg.Payload() + names[i] = msg.Name() + } + payload = &Schema{OneOfField: oneOfSchemas} + messageNames = strings.Join(names, ", ") + } else { + payload = o.Messages()[0].Payload() + messageNames = o.Messages()[0].Name() + } - raw, err := json.Marshal(msg.Payload()) + raw, err := json.Marshal(payload) if err != nil { - return nil, fmt.Errorf("error marshaling message payload for generating json schema for validation. Operation: %s, Message: %s", o.ID(), msg.Name()) + return nil, fmt.Errorf("error marshaling message payload for generating json schema for validation. Operation: %s, Messages: %s", o.ID(), messageNames) } messageSchemas[c.ID()] = gojsonschema.NewBytesLoader(raw) diff --git a/asyncapi/v2/validation_test.go b/asyncapi/v2/validation_test.go index cefaa0f..a69b674 100644 --- a/asyncapi/v2/validation_test.go +++ b/asyncapi/v2/validation_test.go @@ -3,59 +3,96 @@ package v2 import ( "testing" + "github.com/asyncapi/event-gateway/asyncapi" + "github.com/asyncapi/event-gateway/proxy" "github.com/stretchr/testify/assert" ) func TestFromDocJsonSchemaMessageValidator(t *testing.T) { - msg := &Message{ - PayloadField: &Schema{ - TypeField: "object", - PropertiesField: Schemas{ - "AnIntergerField": &Schema{ - Extendable: Extendable{}, - MaximumField: refFloat64(10), - MinimumField: refFloat64(3), - RequiredField: []string{"AnIntergerField"}, - TypeField: "number", - }, - }, - }, - } - channel := NewChannel("test") - channel.Subscribe = NewSubscribeOperation(msg) - - doc := Document{ - Extendable: Extendable{}, - ChannelsField: map[string]Channel{ - "test": *channel, - }, - } - tests := []struct { name string valid bool + schema *Schema payload []byte }{ { - name: "Valid payload", + name: "Valid payload", + schema: &Schema{ + TypeField: "object", + PropertiesField: Schemas{ + "AnIntergerField": &Schema{ + MaximumField: refFloat64(10), + MinimumField: refFloat64(3), + RequiredField: []string{"AnIntergerField"}, + TypeField: "number", + }, + }, + }, payload: []byte(`{"AnIntergerField": 5}`), valid: true, }, { - name: "Invalid payload", + name: "Valid multiple payloads", + schema: &Schema{ + TypeField: "object", + OneOfField: []asyncapi.Schema{ + &Schema{ + PropertiesField: Schemas{ + "AnIntergerField": &Schema{ + MaximumField: refFloat64(10), + MinimumField: refFloat64(3), + RequiredField: []string{"AnIntergerField"}, + TypeField: "number", + }, + }, + AdditionalPropertiesField: false, + }, + &Schema{ + PropertiesField: Schemas{ + "AStringField": &Schema{ + RequiredField: []string{"AStringField"}, + TypeField: "string", + }, + }, + AdditionalPropertiesField: false, + }, + }, + }, + payload: []byte(`{"AStringField": "hello!"}`), + valid: true, + }, + { + name: "Invalid payload", + schema: &Schema{ + TypeField: "object", + PropertiesField: Schemas{ + "AnIntergerField": &Schema{ + MaximumField: refFloat64(10), + MinimumField: refFloat64(3), + RequiredField: []string{"AnIntergerField"}, + TypeField: "number", + }, + }, + }, payload: []byte(`{"AnIntergerField": 1}`), valid: false, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { + // Doc generation + channel := NewChannel(t.Name()) + channel.Subscribe = NewSubscribeOperation(&Message{PayloadField: test.schema}) + doc := Document{ChannelsField: map[string]Channel{t.Name(): *channel}} + + // Test validator, err := FromDocJSONSchemaMessageValidator(doc) assert.NoError(t, err) msg := &proxy.Message{ Context: proxy.MessageContext{ - Channel: "test", + Channel: t.Name(), }, Value: test.payload, } From b4c815ec48c3a62d11be38894b67493b524bd8cb Mon Sep 17 00:00:00 2001 From: Sergio Moya <1083296+smoya@users.noreply.github.com> Date: Mon, 26 Jul 2021 15:26:55 +0200 Subject: [PATCH 4/4] support Operation message to be oneOf --- asyncapi/v2/decode_test.go | 118 +++++++++++++++++++++---------------- asyncapi/v2/v2.go | 62 +++++++++++-------- asyncapi/v2/v2_test.go | 64 +++++++++----------- 3 files changed, 135 insertions(+), 109 deletions(-) diff --git a/asyncapi/v2/decode_test.go b/asyncapi/v2/decode_test.go index 197fd62..a0f4a84 100644 --- a/asyncapi/v2/decode_test.go +++ b/asyncapi/v2/decode_test.go @@ -78,22 +78,35 @@ channels: summary: Inform about environmental lighting conditions for a particular streetlight. operationId: onLightMeasured message: - name: LightMeasured - payload: - type: object - properties: - id: - type: integer - minimum: 0 - description: Id of the streetlight. - lumens: - type: integer - minimum: 0 - description: Light intensity measured in lumens. - sentAt: - type: string - format: date-time - description: Date and time when the message was sent.`) + oneOf: + - $ref: '#/components/messages/lightMeasured' + - $ref: '#/components/messages/lightMeasured2' +components: + messages: + lightMeasured: + name: LightMeasured + payload: + $ref: "#/components/schemas/lightMeasuredPayload" + lightMeasured2: + name: LightMeasured + payload: + $ref: "#/components/schemas/lightMeasuredPayload" + schemas: + lightMeasuredPayload: + type: object + properties: + id: + type: integer + minimum: 0 + description: Id of the streetlight. + lumens: + type: integer + minimum: 0 + description: Light intensity measured in lumens. + sentAt: + type: string + format: date-time + description: Date and time when the message was sent.`) doc := new(Document) require.NoError(t, Decode(raw, doc)) @@ -116,7 +129,7 @@ channels: assert.Len(t, doc.ApplicationPublishableChannels(), 1) assert.Len(t, doc.ApplicationPublishOperations(), 1) - assert.Len(t, doc.ApplicationPublishableMessages(), 1) + assert.Len(t, doc.ApplicationPublishableMessages(), 2) assert.Empty(t, doc.ApplicationSubscribableChannels()) assert.Empty(t, doc.ApplicationSubscribeOperations()) @@ -124,7 +137,7 @@ channels: assert.Len(t, doc.ClientSubscribableChannels(), 1) assert.Len(t, doc.ClientSubscribeOperations(), 1) - assert.Len(t, doc.ClientSubscribableMessages(), 1) + assert.Len(t, doc.ClientSubscribableMessages(), 2) assert.Empty(t, doc.ClientPublishableChannels()) assert.Empty(t, doc.ClientPublishOperations()) @@ -149,40 +162,43 @@ channels: assert.Equal(t, "onLightMeasured", operations[0].ID()) messages := operations[0].Messages() - require.Len(t, messages, 1) - - assert.Equal(t, "LightMeasured", messages[0].Name()) - assert.False(t, messages[0].HasSummary()) - assert.False(t, messages[0].HasDescription()) - assert.False(t, messages[0].HasTitle()) - assert.Empty(t, messages[0].ContentType()) - - payload := messages[0].Payload() - require.NotNil(t, payload) - - assert.Equal(t, []string{"object"}, payload.Type()) - properties := payload.Properties() - require.Len(t, properties, 3) - - expectedProperties := map[string]asyncapi.Schema{ - "id": &Schema{ - DescriptionField: "Id of the streetlight.", - MinimumField: refFloat64(0), - TypeField: "integer", - }, - "lumens": &Schema{ - DescriptionField: "Light intensity measured in lumens.", - MinimumField: refFloat64(0), - TypeField: "integer", - }, - "sentAt": &Schema{ - DescriptionField: "Date and time when the message was sent.", - FormatField: "date-time", - TypeField: "string", - }, - } + require.Len(t, messages, 2) + + for i := 0; i < 2; i++ { + msg := messages[i] + assert.Equal(t, "LightMeasured", msg.Name()) + assert.False(t, msg.HasSummary()) + assert.False(t, msg.HasDescription()) + assert.False(t, msg.HasTitle()) + assert.Empty(t, msg.ContentType()) + + payload := msg.Payload() + require.NotNil(t, payload) + + assert.Equal(t, []string{"object"}, payload.Type()) + properties := payload.Properties() + require.Len(t, properties, 3) + + expectedProperties := map[string]asyncapi.Schema{ + "id": &Schema{ + DescriptionField: "Id of the streetlight.", + MinimumField: refFloat64(0), + TypeField: "integer", + }, + "lumens": &Schema{ + DescriptionField: "Light intensity measured in lumens.", + MinimumField: refFloat64(0), + TypeField: "integer", + }, + "sentAt": &Schema{ + DescriptionField: "Date and time when the message was sent.", + FormatField: "date-time", + TypeField: "string", + }, + } - assert.Equal(t, expectedProperties, properties) + assert.Equal(t, expectedProperties, properties) + } } func refFloat64(v float64) *float64 { diff --git a/asyncapi/v2/v2.go b/asyncapi/v2/v2.go index 4890c87..7fe0c72 100644 --- a/asyncapi/v2/v2.go +++ b/asyncapi/v2/v2.go @@ -258,8 +258,8 @@ type SubscribeOperation struct { } // NewSubscribeOperation creates a new SubscribeOperation. Useful for testing. -func NewSubscribeOperation(msg *Message) *SubscribeOperation { - return &SubscribeOperation{Operation: *NewOperation(OperationTypeSubscribe, msg)} +func NewSubscribeOperation(msgs ...*Message) *SubscribeOperation { + return &SubscribeOperation{Operation: *NewOperation(OperationTypeSubscribe, msgs...)} } func (o SubscribeOperation) MapStructureDefaults() map[string]interface{} { @@ -269,8 +269,8 @@ func (o SubscribeOperation) MapStructureDefaults() map[string]interface{} { } // NewPublishOperation creates a new PublishOperation. Useful for testing. -func NewPublishOperation(msg *Message) *PublishOperation { - return &PublishOperation{Operation: *NewOperation(OperationTypePublish, msg)} +func NewPublishOperation(msgs ...*Message) *PublishOperation { + return &PublishOperation{Operation: *NewOperation(OperationTypePublish, msgs...)} } type PublishOperation struct { @@ -284,15 +284,17 @@ func (o PublishOperation) MapStructureDefaults() map[string]interface{} { } // NewOperation creates a new Operation. Useful for testing. -func NewOperation(operationType asyncapi.OperationType, msg *Message) *Operation { +func NewOperation(operationType asyncapi.OperationType, msgs ...*Message) *Operation { op := &Operation{ OperationType: operationType, } - if msg != nil { - op.MessageField = msg + if len(msgs) == 0 { + return op } + op.MessageField = *NewMessages(msgs) + return op } @@ -300,7 +302,7 @@ type Operation struct { Extendable Describable `mapstructure:",squash"` OperationIDField string `mapstructure:"operationId"` - MessageField *Message `mapstructure:"message"` + MessageField Messages `mapstructure:"message"` OperationType asyncapi.OperationType `mapstructure:"operationType"` // set by hook SummaryField string `mapstructure:"summary"` } @@ -326,23 +328,13 @@ func (o Operation) IsClientSubscribing() bool { } func (o Operation) Messages() []asyncapi.Message { - if o.MessageField != nil { - if len(o.MessageField.Payload().OneOf()) == 0 { - return []asyncapi.Message{o.MessageField} - } - - var msgs []asyncapi.Message - for _, payload := range o.MessageField.Payload().OneOf() { - p := payload.(*Schema) - msgs = append(msgs, Message{ - PayloadField: p, - }) - } + msgs := o.MessageField.Messages() - return msgs + convertedMsgs := make([]asyncapi.Message, len(msgs)) // Go lack of covariance :/ + for i, m := range msgs { + convertedMsgs[i] = m } - - return nil + return convertedMsgs } func (o Operation) Type() asyncapi.OperationType { @@ -358,6 +350,30 @@ func (o Operation) HasSummary() bool { return o.SummaryField != "" } +// Messages is a variadic type for Message object, which can be either one message or oneOf. +// See https://www.asyncapi.com/docs/specifications/v2.0.0#operationObject. +type Messages struct { + Message `mapstructure:",squash"` + OneOfField []*Message `mapstructure:"oneOf"` +} + +func (m *Messages) Messages() []*Message { + if len(m.OneOfField) > 0 { + return m.OneOfField + } + + return []*Message{&m.Message} +} + +// NewMessages creates Messages. +func NewMessages(msgs []*Message) *Messages { + if len(msgs) == 1 { + return &Messages{Message: *msgs[0]} + } + + return &Messages{OneOfField: msgs} +} + type Message struct { Extendable Describable `mapstructure:",squash"` diff --git a/asyncapi/v2/v2_test.go b/asyncapi/v2/v2_test.go index cadb599..6dcb265 100644 --- a/asyncapi/v2/v2_test.go +++ b/asyncapi/v2/v2_test.go @@ -3,52 +3,33 @@ package v2 import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/asyncapi/event-gateway/asyncapi" + "github.com/stretchr/testify/require" - "github.com/asyncapi/event-gateway/asyncapi" + "github.com/stretchr/testify/assert" ) -func TestMessage_OneOfPayload_MultipleMessages(t *testing.T) { - oneOfSchemas := []asyncapi.Schema{ - &Schema{ - PropertiesField: map[string]*Schema{ - "schemaOnefieldOne": {}, - "schemaOnefieldTwo": {}, - }, - }, - &Schema{ - PropertiesField: map[string]*Schema{ - "schemaTwofieldOne": {}, - "schemaTWofieldTwo": {}, - }, - }, - } - msg := &Message{ - PayloadField: &Schema{ - OneOfField: oneOfSchemas, - }, +func TestMessage_OneOf(t *testing.T) { + expectedMessages := []*Message{ + generateTestMessage(), + generateTestMessage(), + generateTestMessage(), } - msgs := NewSubscribeOperation(msg).Messages() - require.Len(t, msgs, 2) - assert.Equal(t, oneOfSchemas[0], msgs[0].Payload()) - assert.Equal(t, oneOfSchemas[1], msgs[1].Payload()) + op := NewSubscribeOperation(expectedMessages...) + assert.IsType(t, Messages{}, op.MessageField) + assert.EqualValues(t, []asyncapi.Message{expectedMessages[0], expectedMessages[1], expectedMessages[2]}, op.Messages()) // Go lack of covariance :/ } func TestMessage_PlainPayload_OneMessage(t *testing.T) { - msg := &Message{ - PayloadField: &Schema{ - PropertiesField: map[string]*Schema{ - "schemaFieldOne": {}, - "schemaFieldfieldTwo": {}, - }, - }, - } + expectedMsg := generateTestMessage() - msgs := NewSubscribeOperation(msg).Messages() + op := NewSubscribeOperation(expectedMsg) + assert.IsType(t, Messages{}, op.MessageField) + msgs := op.Messages() require.Len(t, msgs, 1) - assert.Equal(t, msg, msgs[0]) + assert.Equal(t, expectedMsg, msgs[0]) } func TestSchema_AdditionalProperties(t *testing.T) { @@ -82,3 +63,16 @@ func TestSchema_AdditionalItems(t *testing.T) { assert.True(t, schema.AdditionalItems().IsSchema()) assert.Equal(t, field, schema.AdditionalItems().Schema()) } + +func generateTestMessage() *Message { + return &Message{ + PayloadField: &Schema{ + TypeField: "object", + PropertiesField: Schemas{ + "fieldOne": &Schema{ + TypeField: "string", + }, + }, + }, + } +}