diff --git a/context.go b/context.go index c19673c18..382335e83 100644 --- a/context.go +++ b/context.go @@ -88,7 +88,7 @@ func (nc *Conn) oldRequestWithContext(ctx context.Context, subj string, hdr, dat inbox := nc.NewInbox() ch := make(chan *Msg, RequestChanLen) - s, err := nc.subscribe(inbox, _EMPTY_, nil, ch, true, nil) + s, err := nc.subscribe(inbox, _EMPTY_, nil, ch, nil, true, nil) if err != nil { return nil, err } diff --git a/enc.go b/enc.go index 78bcc219f..34a3fae7f 100644 --- a/enc.go +++ b/enc.go @@ -258,7 +258,7 @@ func (c *EncodedConn) subscribe(subject, queue string, cb Handler) (*Subscriptio cbValue.Call(oV) } - return c.Conn.subscribe(subject, queue, natsCB, nil, false, nil) + return c.Conn.subscribe(subject, queue, natsCB, nil, nil, false, nil) } // FlushTimeout allows a Flush operation to have an associated timeout. diff --git a/js.go b/js.go index e024fae0a..fd263f0aa 100644 --- a/js.go +++ b/js.go @@ -1839,7 +1839,7 @@ func (js *js) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg, isSync, ocb := cb cb = func(m *Msg) { ocb(m); m.Ack() } } - sub, err := nc.subscribe(deliver, queue, cb, ch, isSync, jsi) + sub, err := nc.subscribe(deliver, queue, cb, ch, nil, isSync, jsi) if err != nil { return nil, err } @@ -1910,7 +1910,7 @@ func (js *js) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg, isSync, jsi.hbi = info.Config.Heartbeat // Recreate the subscription here. - sub, err = nc.subscribe(jsi.deliver, queue, cb, ch, isSync, jsi) + sub, err = nc.subscribe(jsi.deliver, queue, cb, ch, nil, isSync, jsi) if err != nil { return nil, err } diff --git a/nats.go b/nats.go index 3f12d61e2..09a17ce39 100644 --- a/nats.go +++ b/nats.go @@ -109,6 +109,7 @@ var ( ErrAuthorization = errors.New("nats: authorization violation") ErrAuthExpired = errors.New("nats: authentication expired") ErrAuthRevoked = errors.New("nats: authentication revoked") + ErrPermissionViolation = errors.New("nats: permissions violation") ErrAccountAuthExpired = errors.New("nats: account authentication expired") ErrNoServers = errors.New("nats: no servers available for connection") ErrJsonParse = errors.New("nats: connect message, json parse error") @@ -510,6 +511,11 @@ type Options struct { // SkipHostLookup skips the DNS lookup for the server hostname. SkipHostLookup bool + + // PermissionErrOnSubscribe - if set to true, the client will return ErrPermissionViolation + // from SubscribeSync if the server returns a permissions error for a subscription. + // Defaults to false. + PermissionErrOnSubscribe bool } const ( @@ -618,17 +624,19 @@ type Subscription struct { // For holding information about a JetStream consumer. jsi *jsSub - delivered uint64 - max uint64 - conn *Conn - mcb MsgHandler - mch chan *Msg - closed bool - sc bool - connClosed bool - draining bool - status SubStatus - statListeners map[chan SubStatus][]SubStatus + delivered uint64 + max uint64 + conn *Conn + mcb MsgHandler + mch chan *Msg + errCh chan (error) + closed bool + sc bool + connClosed bool + draining bool + status SubStatus + statListeners map[chan SubStatus][]SubStatus + permissionsErr error // Type of Subscription typ SubscriptionType @@ -1401,6 +1409,13 @@ func SkipHostLookup() Option { } } +func PermissionErrOnSubscribe(enabled bool) Option { + return func(o *Options) error { + o.PermissionErrOnSubscribe = enabled + return nil + } +} + // TLSHandshakeFirst is an Option to perform the TLS handshake first, that is // before receiving the INFO protocol. This requires the server to also be // configured with such option, otherwise the connection will fail. @@ -3435,6 +3450,9 @@ slowConsumer: } } +var permissionsRe = regexp.MustCompile(`Subscription to "(\S+)"`) +var permissionsQueueRe = regexp.MustCompile(`using queue "(\S+)"`) + // processTransientError is called when the server signals a non terminal error // which does not close the connection or trigger a reconnect. // This will trigger the async error callback if set. @@ -3444,6 +3462,27 @@ slowConsumer: func (nc *Conn) processTransientError(err error) { nc.mu.Lock() nc.err = err + if errors.Is(err, ErrPermissionViolation) { + matches := permissionsRe.FindStringSubmatch(err.Error()) + if len(matches) >= 2 { + queueMatches := permissionsQueueRe.FindStringSubmatch(err.Error()) + var q string + if len(queueMatches) >= 2 { + q = queueMatches[1] + } + subject := matches[1] + for _, sub := range nc.subs { + if sub.Subject == subject && sub.Queue == q && sub.permissionsErr == nil { + sub.mu.Lock() + if sub.errCh != nil { + sub.errCh <- err + } + sub.permissionsErr = err + sub.mu.Unlock() + } + } + } + } if nc.Opts.AsyncErrorCB != nil { nc.ach.push(func() { nc.Opts.AsyncErrorCB(nc, nil, err) }) } @@ -3685,7 +3724,7 @@ func (nc *Conn) processErr(ie string) { } else if e == MAX_CONNECTIONS_ERR { close = nc.processOpErr(ErrMaxConnectionsExceeded) } else if strings.HasPrefix(e, PERMISSIONS_ERR) { - nc.processTransientError(fmt.Errorf("nats: %s", ne)) + nc.processTransientError(fmt.Errorf("%w: %s", ErrPermissionViolation, ne)) } else if strings.HasPrefix(e, MAX_SUBSCRIPTIONS_ERR) { nc.processTransientError(ErrMaxSubscriptionsExceeded) } else if authErr := checkAuthError(e); authErr != nil { @@ -4042,7 +4081,7 @@ func (nc *Conn) createNewRequestAndSend(subj string, hdr, data []byte) (chan *Ms // Create the response subscription we will use for all new style responses. // This will be on an _INBOX with an additional terminal token. The subscription // will be on a wildcard. - s, err := nc.subscribeLocked(nc.respSub, _EMPTY_, nc.respHandler, nil, false, nil) + s, err := nc.subscribeLocked(nc.respSub, _EMPTY_, nc.respHandler, nil, nil, false, nil) if err != nil { nc.mu.Unlock() return nil, token, err @@ -4140,7 +4179,7 @@ func (nc *Conn) oldRequest(subj string, hdr, data []byte, timeout time.Duration) inbox := nc.NewInbox() ch := make(chan *Msg, RequestChanLen) - s, err := nc.subscribe(inbox, _EMPTY_, nil, ch, true, nil) + s, err := nc.subscribe(inbox, _EMPTY_, nil, ch, nil, true, nil) if err != nil { return nil, err } @@ -4246,14 +4285,14 @@ func (nc *Conn) respToken(respInbox string) string { // since it can't match more than one token. // Messages will be delivered to the associated MsgHandler. func (nc *Conn) Subscribe(subj string, cb MsgHandler) (*Subscription, error) { - return nc.subscribe(subj, _EMPTY_, cb, nil, false, nil) + return nc.subscribe(subj, _EMPTY_, cb, nil, nil, false, nil) } // ChanSubscribe will express interest in the given subject and place // all messages received on the channel. // You should not close the channel until sub.Unsubscribe() has been called. func (nc *Conn) ChanSubscribe(subj string, ch chan *Msg) (*Subscription, error) { - return nc.subscribe(subj, _EMPTY_, nil, ch, false, nil) + return nc.subscribe(subj, _EMPTY_, nil, ch, nil, false, nil) } // ChanQueueSubscribe will express interest in the given subject. @@ -4263,7 +4302,7 @@ func (nc *Conn) ChanSubscribe(subj string, ch chan *Msg) (*Subscription, error) // You should not close the channel until sub.Unsubscribe() has been called. // Note: This is the same than QueueSubscribeSyncWithChan. func (nc *Conn) ChanQueueSubscribe(subj, group string, ch chan *Msg) (*Subscription, error) { - return nc.subscribe(subj, group, nil, ch, false, nil) + return nc.subscribe(subj, group, nil, ch, nil, false, nil) } // SubscribeSync will express interest on the given subject. Messages will @@ -4273,7 +4312,11 @@ func (nc *Conn) SubscribeSync(subj string) (*Subscription, error) { return nil, ErrInvalidConnection } mch := make(chan *Msg, nc.Opts.SubChanLen) - return nc.subscribe(subj, _EMPTY_, nil, mch, true, nil) + var errCh chan error + if nc.Opts.PermissionErrOnSubscribe { + errCh = make(chan error, 100) + } + return nc.subscribe(subj, _EMPTY_, nil, mch, errCh, true, nil) } // QueueSubscribe creates an asynchronous queue subscriber on the given subject. @@ -4281,7 +4324,7 @@ func (nc *Conn) SubscribeSync(subj string) (*Subscription, error) { // only one member of the group will be selected to receive any given // message asynchronously. func (nc *Conn) QueueSubscribe(subj, queue string, cb MsgHandler) (*Subscription, error) { - return nc.subscribe(subj, queue, cb, nil, false, nil) + return nc.subscribe(subj, queue, cb, nil, nil, false, nil) } // QueueSubscribeSync creates a synchronous queue subscriber on the given @@ -4290,7 +4333,11 @@ func (nc *Conn) QueueSubscribe(subj, queue string, cb MsgHandler) (*Subscription // given message synchronously using Subscription.NextMsg(). func (nc *Conn) QueueSubscribeSync(subj, queue string) (*Subscription, error) { mch := make(chan *Msg, nc.Opts.SubChanLen) - return nc.subscribe(subj, queue, nil, mch, true, nil) + var errCh chan error + if nc.Opts.PermissionErrOnSubscribe { + errCh = make(chan error, 100) + } + return nc.subscribe(subj, queue, nil, mch, errCh, true, nil) } // QueueSubscribeSyncWithChan will express interest in the given subject. @@ -4300,7 +4347,7 @@ func (nc *Conn) QueueSubscribeSync(subj, queue string) (*Subscription, error) { // You should not close the channel until sub.Unsubscribe() has been called. // Note: This is the same than ChanQueueSubscribe. func (nc *Conn) QueueSubscribeSyncWithChan(subj, queue string, ch chan *Msg) (*Subscription, error) { - return nc.subscribe(subj, queue, nil, ch, false, nil) + return nc.subscribe(subj, queue, nil, ch, nil, false, nil) } // badSubject will do quick test on whether a subject is acceptable. @@ -4324,16 +4371,16 @@ func badQueue(qname string) bool { } // subscribe is the internal subscribe function that indicates interest in a subject. -func (nc *Conn) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg, isSync bool, js *jsSub) (*Subscription, error) { +func (nc *Conn) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg, errCh chan (error), isSync bool, js *jsSub) (*Subscription, error) { if nc == nil { return nil, ErrInvalidConnection } nc.mu.Lock() defer nc.mu.Unlock() - return nc.subscribeLocked(subj, queue, cb, ch, isSync, js) + return nc.subscribeLocked(subj, queue, cb, ch, errCh, isSync, js) } -func (nc *Conn) subscribeLocked(subj, queue string, cb MsgHandler, ch chan *Msg, isSync bool, js *jsSub) (*Subscription, error) { +func (nc *Conn) subscribeLocked(subj, queue string, cb MsgHandler, ch chan *Msg, errCh chan (error), isSync bool, js *jsSub) (*Subscription, error) { if nc == nil { return nil, ErrInvalidConnection } @@ -4384,6 +4431,7 @@ func (nc *Conn) subscribeLocked(subj, queue string, cb MsgHandler, ch chan *Msg, } else { // Sync Subscription sub.typ = SyncSubscription sub.mch = ch + sub.errCh = errCh } nc.subsMu.Lock() @@ -4828,16 +4876,92 @@ func (s *Subscription) NextMsg(timeout time.Duration) (*Msg, error) { t := globalTimerPool.Get(timeout) defer globalTimerPool.Put(t) + if s.errCh != nil { + select { + case msg, ok = <-mch: + if !ok { + return nil, s.getNextMsgErr() + } + if err := s.processNextMsgDelivered(msg); err != nil { + return nil, err + } + case err := <-s.errCh: + return nil, err + case <-t.C: + return nil, ErrTimeout + } + } else { + select { + case msg, ok = <-mch: + if !ok { + return nil, s.getNextMsgErr() + } + if err := s.processNextMsgDelivered(msg); err != nil { + return nil, err + } + case <-t.C: + return nil, ErrTimeout + } + } + + return msg, nil +} + +// nextMsgNoTimeout works similarly to Subscription.NextMsg() but will not +// time out. It is only used internally for non-timeout subscription iterator. +func (s *Subscription) nextMsgNoTimeout() (*Msg, error) { + if s == nil { + return nil, ErrBadSubscription + } + + s.mu.Lock() + err := s.validateNextMsgState(false) + if err != nil { + s.mu.Unlock() + return nil, err + } + + // snapshot + mch := s.mch + s.mu.Unlock() + + var ok bool + var msg *Msg + + // If something is available right away, let's optimize that case. select { case msg, ok = <-mch: + if !ok { + return nil, s.getNextMsgErr() + } + if err := s.processNextMsgDelivered(msg); err != nil { + return nil, err + } else { + return msg, nil + } + default: + } + + if s.errCh != nil { + select { + case msg, ok = <-mch: + if !ok { + return nil, s.getNextMsgErr() + } + if err := s.processNextMsgDelivered(msg); err != nil { + return nil, err + } + case err := <-s.errCh: + return nil, err + } + } else { + msg, ok = <-mch if !ok { return nil, s.getNextMsgErr() } if err := s.processNextMsgDelivered(msg); err != nil { return nil, err } - case <-t.C: - return nil, ErrTimeout } return msg, nil @@ -4860,6 +4984,12 @@ func (s *Subscription) validateNextMsgState(pullSubInternal bool) error { if s.mcb != nil { return ErrSyncSubRequired } + // if this subscription previously had a permissions error + // and no reconnect has been attempted, return the permissions error + // since the subscription does not exist on the server + if s.conn.Opts.PermissionErrOnSubscribe && s.permissionsErr != nil { + return s.permissionsErr + } if s.sc { s.changeSubStatus(SubscriptionActive) s.sc = false @@ -5235,6 +5365,9 @@ func (nc *Conn) resendSubscriptions() { for _, s := range subs { adjustedMax := uint64(0) s.mu.Lock() + // when resending subscriptions, the permissions error should be cleared + // since the user may have fixed the permissions issue + s.permissionsErr = nil if s.max > 0 { if s.delivered < s.max { adjustedMax = s.max - s.delivered diff --git a/nats_iter.go b/nats_iter.go new file mode 100644 index 000000000..0f4d1e258 --- /dev/null +++ b/nats_iter.go @@ -0,0 +1,73 @@ +// Copyright 2012-2024 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build go1.23 +// +build go1.23 + +package nats + +import ( + "errors" + "iter" + "time" +) + +// Msgs returns an iter.Seq2[*Msg, error] that can be used to iterate over +// messages. It can only be used with a subscription that has been created with +// SubscribeSync or QueueSubscribeSync, otherwise it will return an error on the +// first iteration. +// +// The iterator will block until a message is available. The +// subscription will not be closed when the iterator is done. +func (sub *Subscription) Msgs() iter.Seq2[*Msg, error] { + return func(yield func(*Msg, error) bool) { + for { + msg, err := sub.nextMsgNoTimeout() + if err != nil { + yield(nil, err) + return + } + if !yield(msg, nil) { + return + } + + } + } +} + +// MsgsTimeout returns an iter.Seq2[*Msg, error] that can be used to iterate +// over messages. It can only be used with a subscription that has been created +// with SubscribeSync or QueueSubscribeSync, otherwise it will return an error +// on the first iteration. +// +// The iterator will block until a message is available or the timeout is +// reached. If the timeout is reached, the iterator will return nats.ErrTimeout +// but it will not be closed. +func (sub *Subscription) MsgsTimeout(timeout time.Duration) iter.Seq2[*Msg, error] { + return func(yield func(*Msg, error) bool) { + for { + msg, err := sub.NextMsg(timeout) + if err != nil { + if !yield(nil, err) { + return + } + if !errors.Is(err, ErrTimeout) { + return + } + } + if !yield(msg, nil) { + return + } + } + } +} diff --git a/netchan.go b/netchan.go index 3722d9f1b..35d92140c 100644 --- a/netchan.go +++ b/netchan.go @@ -113,5 +113,5 @@ func (c *EncodedConn) bindRecvChan(subject, queue string, channel any) (*Subscri chVal.Send(oPtr) } - return c.Conn.subscribe(subject, queue, cb, nil, false, nil) + return c.Conn.subscribe(subject, queue, cb, nil, nil, false, nil) } diff --git a/rand.go b/rand.go deleted file mode 100644 index 0cdee0acd..000000000 --- a/rand.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2023 The NATS Authors -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build !go1.20 -// +build !go1.20 - -// A Go client for the NATS messaging system (https://nats.io). -package nats - -import ( - "math/rand" - "time" -) - -func init() { - // This is not needed since Go 1.20 because now rand.Seed always happens - // by default (uses runtime.fastrand64 instead as source). - rand.Seed(time.Now().UnixNano()) -} diff --git a/test/nats_iter_test.go b/test/nats_iter_test.go new file mode 100644 index 000000000..a762aec19 --- /dev/null +++ b/test/nats_iter_test.go @@ -0,0 +1,337 @@ +// Copyright 2012-2024 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build go1.23 +// +build go1.23 + +package test + +import ( + "errors" + "fmt" + "os" + "sync" + "testing" + "time" + + "github.com/nats-io/nats.go" +) + +func TestSubscribeIterator(t *testing.T) { + t.Run("with timeout", func(t *testing.T) { + s := RunServerOnPort(-1) + defer s.Shutdown() + + nc, err := nats.Connect(s.ClientURL(), nats.PermissionErrOnSubscribe(true)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + sub, err := nc.SubscribeSync("foo") + if err != nil { + t.Fatal("Failed to subscribe: ", err) + } + defer sub.Unsubscribe() + + total := 100 + for i := 0; i < total/2; i++ { + if err := nc.Publish("foo", []byte("Hello")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + } + + // publish some more messages asynchronously + errCh := make(chan error, 1) + go func() { + for i := 0; i < total/2; i++ { + if err := nc.Publish("foo", []byte("Hello")); err != nil { + errCh <- err + return + } + time.Sleep(10 * time.Millisecond) + } + close(errCh) + }() + + received := 0 + for _, err := range sub.MsgsTimeout(100 * time.Millisecond) { + if err != nil { + if !errors.Is(err, nats.ErrTimeout) { + t.Fatalf("Error on subscribe: %v", err) + } + break + } else { + received++ + } + } + if received != total { + t.Fatalf("Expected %d messages, got %d", total, received) + } + }) + + t.Run("no timeout", func(t *testing.T) { + s := RunServerOnPort(-1) + defer s.Shutdown() + + nc, err := nats.Connect(s.ClientURL(), nats.PermissionErrOnSubscribe(true)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + sub, err := nc.SubscribeSync("foo") + if err != nil { + t.Fatal("Failed to subscribe: ", err) + } + defer sub.Unsubscribe() + + // Send some messages to ourselves. + total := 100 + for i := 0; i < total/2; i++ { + if err := nc.Publish("foo", []byte("Hello")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + } + + received := 0 + + // publish some more messages asynchronously + errCh := make(chan error, 1) + go func() { + for i := 0; i < total/2; i++ { + if err := nc.Publish("foo", []byte("Hello")); err != nil { + errCh <- err + return + } + time.Sleep(10 * time.Millisecond) + } + close(errCh) + }() + + for _, err := range sub.Msgs() { + if err != nil { + t.Fatalf("Error getting msg: %v", err) + } + received++ + if received >= total { + break + } + } + err = <-errCh + if err != nil { + t.Fatalf("Error on publish: %v", err) + } + _, err = sub.NextMsg(100 * time.Millisecond) + if !errors.Is(err, nats.ErrTimeout) { + t.Fatalf("Expected timeout waiting for next message, got %v", err) + } + }) + + t.Run("permissions violation", func(t *testing.T) { + conf := createConfFile(t, []byte(` + listen: 127.0.0.1:-1 + authorization: { + users = [ + { + user: test + password: test + permissions: { + subscribe: { + deny: "foo" + } + } + } + ] + } + `)) + defer os.Remove(conf) + + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + + nc, err := nats.Connect(s.ClientURL(), nats.UserInfo("test", "test"), nats.PermissionErrOnSubscribe(true)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + sub, err := nc.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + defer sub.Unsubscribe() + + errs := make(chan error) + go func() { + var err error + for _, err = range sub.Msgs() { + break + } + errs <- err + }() + + select { + case e := <-errs: + if !errors.Is(e, nats.ErrPermissionViolation) { + t.Fatalf("Expected permissions error, got %v", e) + } + case <-time.After(2 * time.Second): + t.Fatalf("Did not get the permission error") + } + + _, err = sub.NextMsg(100 * time.Millisecond) + if !errors.Is(err, nats.ErrPermissionViolation) { + t.Fatalf("Expected permissions violation error, got %v", err) + } + }) + + t.Run("attempt iterator on async sub", func(t *testing.T) { + s := RunServerOnPort(-1) + defer s.Shutdown() + + nc, err := nats.Connect(s.ClientURL(), nats.PermissionErrOnSubscribe(true)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + sub, err := nc.Subscribe("foo", func(msg *nats.Msg) {}) + if err != nil { + t.Fatal("Failed to subscribe: ", err) + } + defer sub.Unsubscribe() + + for _, err := range sub.MsgsTimeout(100 * time.Millisecond) { + if !errors.Is(err, nats.ErrSyncSubRequired) { + t.Fatalf("Error on subscribe: %v", err) + } + } + for _, err := range sub.Msgs() { + if !errors.Is(err, nats.ErrSyncSubRequired) { + t.Fatalf("Error on subscribe: %v", err) + } + } + }) +} + +func TestQueueSubscribeIterator(t *testing.T) { + t.Run("basic", func(t *testing.T) { + s := RunServerOnPort(-1) + defer s.Shutdown() + + nc, err := nats.Connect(s.ClientURL()) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + subs := make([]*nats.Subscription, 4) + for i := 0; i < 4; i++ { + sub, err := nc.QueueSubscribeSync("foo", "q") + if err != nil { + t.Fatal("Failed to subscribe: ", err) + } + subs[i] = sub + defer sub.Unsubscribe() + } + + // Send some messages to ourselves. + total := 100 + for i := 0; i < total; i++ { + if err := nc.Publish("foo", []byte(fmt.Sprintf("%d", i))); err != nil { + t.Fatalf("Error on publish: %v", err) + } + } + + wg := sync.WaitGroup{} + wg.Add(100) + startWg := sync.WaitGroup{} + startWg.Add(4) + + for i := range subs { + go func(i int) { + startWg.Done() + for _, err := range subs[i].MsgsTimeout(100 * time.Millisecond) { + if err != nil { + break + } + wg.Done() + } + }(i) + } + + startWg.Wait() + + wg.Wait() + + for _, sub := range subs { + if _, err = sub.NextMsg(100 * time.Millisecond); !errors.Is(err, nats.ErrTimeout) { + t.Fatalf("Expected timeout waiting for next message, got %v", err) + } + } + }) + + t.Run("permissions violation", func(t *testing.T) { + conf := createConfFile(t, []byte(` + listen: 127.0.0.1:-1 + authorization: { + users = [ + { + user: test + password: test + permissions: { + subscribe: { + deny: "foo" + } + } + } + ] + } + `)) + defer os.Remove(conf) + + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + + nc, err := nats.Connect(s.ClientURL(), nats.UserInfo("test", "test"), nats.PermissionErrOnSubscribe(true)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + sub, err := nc.QueueSubscribeSync("foo", "q") + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + defer sub.Unsubscribe() + + errs := make(chan error) + go func() { + var err error + for _, err = range sub.MsgsTimeout(2 * time.Second) { + break + } + errs <- err + }() + + select { + case e := <-errs: + if !errors.Is(e, nats.ErrPermissionViolation) { + t.Fatalf("Expected permissions error, got %v", e) + } + case <-time.After(2 * time.Second): + t.Fatalf("Did not get the permission error") + } + }) +} diff --git a/test/sub_test.go b/test/sub_test.go index 559efc50c..01b78d0f9 100644 --- a/test/sub_test.go +++ b/test/sub_test.go @@ -1770,3 +1770,92 @@ func TestMaxSubscriptionsExceeded(t *testing.T) { // wait for the server to process the SUBs time.Sleep(100 * time.Millisecond) } + +func TestSubscribeSyncPermissionError(t *testing.T) { + conf := createConfFile(t, []byte(` + listen: 127.0.0.1:-1 + authorization: { + users = [ + { + user: test + password: test + permissions: { + subscribe: { + deny: "foo" + } + } + } + ] + } +`)) + defer os.Remove(conf) + + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + + t.Run("PermissionErrOnSubscribe enabled", func(t *testing.T) { + + nc, err := nats.Connect(s.ClientURL(), + nats.UserInfo("test", "test"), + nats.PermissionErrOnSubscribe(true), + nats.ErrorHandler(func(*nats.Conn, *nats.Subscription, error) {})) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + subs := make([]*nats.Subscription, 0, 100) + for i := 0; i < 10; i++ { + var subject string + if i%2 == 0 { + subject = "foo" + } else { + subject = "bar" + } + sub, err := nc.SubscribeSync(subject) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + defer sub.Unsubscribe() + subs = append(subs, sub) + } + + for _, sub := range subs { + _, err = sub.NextMsg(100 * time.Millisecond) + if sub.Subject == "foo" { + if !errors.Is(err, nats.ErrPermissionViolation) { + t.Fatalf("Expected permissions violation error, got %v", err) + } + // subsequent calls should return the same error + _, err = sub.NextMsg(100 * time.Millisecond) + if !errors.Is(err, nats.ErrPermissionViolation) { + t.Fatalf("Expected permissions violation error, got %v", err) + } + } else { + if !errors.Is(err, nats.ErrTimeout) { + t.Fatalf("Expected timeout error, got %v", err) + } + } + } + }) + + t.Run("PermissionErrOnSubscribe disabled", func(t *testing.T) { + nc, err := nats.Connect(s.ClientURL(), nats.UserInfo("test", "test")) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + // Cause a subscribe error + sub, err := nc.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + defer sub.Unsubscribe() + + _, err = sub.NextMsg(100 * time.Millisecond) + if !errors.Is(err, nats.ErrTimeout) { + t.Fatalf("Expected timeout error, got %v", err) + } + }) +}