diff --git a/broadcast/legacy_nats.go b/broadcast/legacy_nats.go index c651659a..324d5578 100644 --- a/broadcast/legacy_nats.go +++ b/broadcast/legacy_nats.go @@ -2,24 +2,46 @@ package broadcast import ( "context" + "fmt" "log/slog" + "strings" "github.com/nats-io/nats.go" nconfig "github.com/anycable/anycable-go/nats" ) +type LegacyNATSConfig struct { + Channel string `toml:"channel"` + NATS *nconfig.NATSConfig `toml:"nats"` +} + +func NewLegacyNATSConfig() LegacyNATSConfig { + return LegacyNATSConfig{ + Channel: "__anycable__", + } +} + +func (c LegacyNATSConfig) ToToml() string { + var result strings.Builder + result.WriteString(fmt.Sprintf("channel = \"%s\"\n", c.Channel)) + + result.WriteString("\n") + + return result.String() +} + type LegacyNATSBroadcaster struct { conn *nats.Conn handler Handler - config *nconfig.NATSConfig + config *LegacyNATSConfig log *slog.Logger } var _ Broadcaster = (*LegacyNATSBroadcaster)(nil) -func NewLegacyNATSBroadcaster(node Handler, c *nconfig.NATSConfig, l *slog.Logger) *LegacyNATSBroadcaster { +func NewLegacyNATSBroadcaster(node Handler, c *LegacyNATSConfig, l *slog.Logger) *LegacyNATSBroadcaster { return &LegacyNATSBroadcaster{ config: c, handler: node, @@ -34,7 +56,7 @@ func (LegacyNATSBroadcaster) IsFanout() bool { func (s *LegacyNATSBroadcaster) Start(done chan (error)) error { connectOptions := []nats.Option{ nats.RetryOnFailedConnect(true), - nats.MaxReconnects(s.config.MaxReconnectAttempts), + nats.MaxReconnects(s.config.NATS.MaxReconnectAttempts), nats.DisconnectErrHandler(func(nc *nats.Conn, err error) { if err != nil { s.log.Warn("connection failed", "error", err) @@ -45,11 +67,11 @@ func (s *LegacyNATSBroadcaster) Start(done chan (error)) error { }), } - if s.config.DontRandomizeServers { + if s.config.NATS.DontRandomizeServers { connectOptions = append(connectOptions, nats.DontRandomize()) } - nc, err := nats.Connect(s.config.Servers, connectOptions...) + nc, err := nats.Connect(s.config.NATS.Servers, connectOptions...) if err != nil { return err diff --git a/broadcast/legacy_nats_test.go b/broadcast/legacy_nats_test.go new file mode 100644 index 00000000..bffd503d --- /dev/null +++ b/broadcast/legacy_nats_test.go @@ -0,0 +1,26 @@ +package broadcast + +import ( + "testing" + + "github.com/BurntSushi/toml" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLegacyNATSConfig__ToToml(t *testing.T) { + conf := NewLegacyNATSConfig() + conf.Channel = "_test_" + + tomlStr := conf.ToToml() + + assert.Contains(t, tomlStr, "channel = \"_test_\"") + + // Round-trip test + conf2 := NewLegacyNATSConfig() + + _, err := toml.Decode(tomlStr, &conf2) + require.NoError(t, err) + + assert.Equal(t, conf, conf2) +} diff --git a/broadcast/legacy_redis.go b/broadcast/legacy_redis.go index 97e3e689..608e3b97 100644 --- a/broadcast/legacy_redis.go +++ b/broadcast/legacy_redis.go @@ -16,6 +16,26 @@ import ( "github.com/gomodule/redigo/redis" ) +type LegacyRedisConfig struct { + Channel string `toml:"channel"` + Redis *rconfig.RedisConfig `toml:"redis"` +} + +func NewLegacyRedisConfig() LegacyRedisConfig { + return LegacyRedisConfig{ + Channel: "__anycable__", + } +} + +func (c LegacyRedisConfig) ToToml() string { + var result strings.Builder + result.WriteString(fmt.Sprintf("channel = \"%s\"\n", c.Channel)) + + result.WriteString("\n") + + return result.String() +} + // LegacyRedisBroadcaster contains information about Redis pubsub connection type LegacyRedisBroadcaster struct { node Handler @@ -33,18 +53,18 @@ type LegacyRedisBroadcaster struct { } // NewLegacyRedisBroadcaster returns new RedisSubscriber struct -func NewLegacyRedisBroadcaster(node Handler, config *rconfig.RedisConfig, l *slog.Logger) *LegacyRedisBroadcaster { +func NewLegacyRedisBroadcaster(node Handler, config *LegacyRedisConfig, l *slog.Logger) *LegacyRedisBroadcaster { return &LegacyRedisBroadcaster{ node: node, - url: config.URL, - sentinels: config.Sentinels, - sentinelDiscoveryInterval: time.Duration(config.SentinelDiscoveryInterval), + url: config.Redis.URL, + sentinels: config.Redis.Sentinels, + sentinelDiscoveryInterval: time.Duration(config.Redis.SentinelDiscoveryInterval), channel: config.Channel, - pingInterval: time.Duration(config.KeepalivePingInterval), + pingInterval: time.Duration(config.Redis.KeepalivePingInterval), reconnectAttempt: 0, - maxReconnectAttempts: config.MaxReconnectAttempts, + maxReconnectAttempts: config.Redis.MaxReconnectAttempts, log: l.With("context", "broadcast").With("provider", "redis"), - tlsVerify: config.TLSVerify, + tlsVerify: config.Redis.TLSVerify, } } diff --git a/broadcast/legacy_redis_test.go b/broadcast/legacy_redis_test.go new file mode 100644 index 00000000..2ebe9120 --- /dev/null +++ b/broadcast/legacy_redis_test.go @@ -0,0 +1,26 @@ +package broadcast + +import ( + "testing" + + "github.com/BurntSushi/toml" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLegacyRedisConfig__ToToml(t *testing.T) { + conf := NewLegacyRedisConfig() + conf.Channel = "_test_" + + tomlStr := conf.ToToml() + + assert.Contains(t, tomlStr, "channel = \"_test_\"") + + // Round-trip test + conf2 := NewLegacyRedisConfig() + + _, err := toml.Decode(tomlStr, &conf2) + require.NoError(t, err) + + assert.Equal(t, conf, conf2) +} diff --git a/broadcast/redis.go b/broadcast/redis.go index b2b289d9..7e453eec 100644 --- a/broadcast/redis.go +++ b/broadcast/redis.go @@ -16,10 +16,44 @@ import ( "github.com/redis/rueidis" ) +type RedisConfig struct { + Stream string `toml:"stream"` + Group string `toml:"group"` + // Redis stream read wait time in milliseconds + StreamReadBlockMilliseconds int64 `toml:"stream_read_block_milliseconds"` + + Redis *rconfig.RedisConfig `toml:"redis"` +} + +func NewRedisConfig() RedisConfig { + return RedisConfig{ + Stream: "__anycable__", + Group: "bx", + StreamReadBlockMilliseconds: 2000, + } +} + +func (c RedisConfig) ToToml() string { + var result strings.Builder + + result.WriteString("# Redis stream name for broadcasts\n") + result.WriteString(fmt.Sprintf("stream = \"%s\"\n", c.Stream)) + + result.WriteString("# Stream consumer group name\n") + result.WriteString(fmt.Sprintf("group = \"%s\"\n", c.Group)) + + result.WriteString("# Streams read wait time in milliseconds\n") + result.WriteString(fmt.Sprintf("stream_read_block_milliseconds = %d\n", c.StreamReadBlockMilliseconds)) + + result.WriteString("\n") + + return result.String() +} + // RedisBroadcaster represents Redis broadcaster using Redis streams type RedisBroadcaster struct { node Handler - config *rconfig.RedisConfig + config *RedisConfig // Unique consumer identifier consumerName string @@ -39,7 +73,7 @@ type RedisBroadcaster struct { var _ Broadcaster = (*RedisBroadcaster)(nil) // NewRedisBroadcaster builds a new RedisSubscriber struct -func NewRedisBroadcaster(node Handler, config *rconfig.RedisConfig, l *slog.Logger) *RedisBroadcaster { +func NewRedisBroadcaster(node Handler, config *RedisConfig, l *slog.Logger) *RedisBroadcaster { name, _ := nanoid.Nanoid(6) return &RedisBroadcaster{ @@ -57,18 +91,18 @@ func (s *RedisBroadcaster) IsFanout() bool { } func (s *RedisBroadcaster) Start(done chan error) error { - options, err := s.config.ToRueidisOptions() + options, err := s.config.Redis.ToRueidisOptions() if err != nil { return err } - if s.config.IsSentinel() { //nolint:gocritic - s.log.With("stream", s.config.Channel).With("consumer", s.consumerName).Info(fmt.Sprintf("Starting Redis broadcaster at %v (sentinels)", s.config.Hostnames())) - } else if s.config.IsCluster() { - s.log.With("stream", s.config.Channel).With("consumer", s.consumerName).Info(fmt.Sprintf("Starting Redis broadcaster at %v (cluster)", s.config.Hostnames())) + if s.config.Redis.IsSentinel() { //nolint:gocritic + s.log.With("stream", s.config.Stream).With("consumer", s.consumerName).Info(fmt.Sprintf("Starting Redis broadcaster at %v (sentinels)", s.config.Redis.Hostnames())) + } else if s.config.Redis.IsCluster() { + s.log.With("stream", s.config.Stream).With("consumer", s.consumerName).Info(fmt.Sprintf("Starting Redis broadcaster at %v (cluster)", s.config.Redis.Hostnames())) } else { - s.log.With("stream", s.config.Channel).With("consumer", s.consumerName).Info(fmt.Sprintf("Starting Redis broadcaster at %s", s.config.Hostname())) + s.log.With("stream", s.config.Stream).With("consumer", s.consumerName).Info(fmt.Sprintf("Starting Redis broadcaster at %s", s.config.Redis.Hostname())) } s.clientOptions = options @@ -94,7 +128,7 @@ func (s *RedisBroadcaster) Shutdown(ctx context.Context) error { res := s.client.Do( context.Background(), - s.client.B().XgroupDelconsumer().Key(s.config.Channel).Group(s.config.Group).Consumername(s.consumerName).Build(), + s.client.B().XgroupDelconsumer().Key(s.config.Stream).Group(s.config.Group).Consumername(s.consumerName).Build(), ) err := res.Error() @@ -144,7 +178,7 @@ func (s *RedisBroadcaster) runReader(done chan (error)) { // First, create a consumer group for the stream err = s.client.Do(context.Background(), - s.client.B().XgroupCreate().Key(s.config.Channel).Group(s.config.Group).Id("$").Mkstream().Build(), + s.client.B().XgroupCreate().Key(s.config.Stream).Group(s.config.Group).Id("$").Mkstream().Build(), ).Error() if err != nil { @@ -204,7 +238,7 @@ func (s *RedisBroadcaster) runReader(done chan (error)) { func (s *RedisBroadcaster) readFromStream(blockTime int64) ([]rueidis.XRangeEntry, error) { streamRes := s.client.Do(context.Background(), - s.client.B().Xreadgroup().Group(s.config.Group, s.consumerName).Block(blockTime).Streams().Key(s.config.Channel).Id(">").Build(), + s.client.B().Xreadgroup().Group(s.config.Group, s.consumerName).Block(blockTime).Streams().Key(s.config.Stream).Id(">").Build(), ) res, _ := streamRes.AsXRead() @@ -218,7 +252,7 @@ func (s *RedisBroadcaster) readFromStream(blockTime int64) ([]rueidis.XRangeEntr return nil, nil } - if messages, ok := res[s.config.Channel]; ok { + if messages, ok := res[s.config.Stream]; ok { return messages, nil } @@ -227,7 +261,7 @@ func (s *RedisBroadcaster) readFromStream(blockTime int64) ([]rueidis.XRangeEntr func (s *RedisBroadcaster) autoclaimMessages(blockTime int64) ([]rueidis.XRangeEntry, error) { claimRes := s.client.Do(context.Background(), - s.client.B().Xautoclaim().Key(s.config.Channel).Group(s.config.Group).Consumer(s.consumerName).MinIdleTime(fmt.Sprintf("%d", blockTime)).Start("0-0").Build(), + s.client.B().Xautoclaim().Key(s.config.Stream).Group(s.config.Group).Consumer(s.consumerName).MinIdleTime(fmt.Sprintf("%d", blockTime)).Start("0-0").Build(), ) arr, err := claimRes.ToArray() @@ -260,8 +294,8 @@ func (s *RedisBroadcaster) broadcastXrange(messages []rueidis.XRangeEntry) { s.node.HandleBroadcast([]byte(payload)) ackRes := s.client.DoMulti(context.Background(), - s.client.B().Xack().Key(s.config.Channel).Group(s.config.Group).Id(message.ID).Build(), - s.client.B().Xdel().Key(s.config.Channel).Id(message.ID).Build(), + s.client.B().Xack().Key(s.config.Stream).Group(s.config.Group).Id(message.ID).Build(), + s.client.B().Xdel().Key(s.config.Stream).Id(message.ID).Build(), ) err := ackRes[0].Error() @@ -274,7 +308,7 @@ func (s *RedisBroadcaster) broadcastXrange(messages []rueidis.XRangeEntry) { } func (s *RedisBroadcaster) maybeReconnect(done chan (error)) { - if s.reconnectAttempt >= s.config.MaxReconnectAttempts { + if s.reconnectAttempt >= s.config.Redis.MaxReconnectAttempts { close(s.finishedCh) done <- errors.New("failed to reconnect to Redis: attempts exceeded") //nolint:stylecheck return diff --git a/broadcast/redis_test.go b/broadcast/redis_test.go index 0d5137fe..a84e6ee7 100644 --- a/broadcast/redis_test.go +++ b/broadcast/redis_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "github.com/BurntSushi/toml" "github.com/anycable/anycable-go/mocks" rconfig "github.com/anycable/anycable-go/redis" "github.com/anycable/anycable-go/utils" @@ -64,10 +65,12 @@ func TestRedisBroadcaster(t *testing.T) { return } - config := rconfig.NewRedisConfig() + rconfig := rconfig.NewRedisConfig() + config := NewRedisConfig() + config.Redis = &rconfig if redisURL != "" { - config.URL = redisURL + rconfig.URL = redisURL } config.StreamReadBlockMilliseconds = 500 @@ -154,12 +157,14 @@ func TestRedisBroadcasterAcksClaims(t *testing.T) { return } - config := rconfig.NewRedisConfig() + rconfig := rconfig.NewRedisConfig() + config := NewRedisConfig() + config.Redis = &rconfig // Make it short to avoid sleeping for too long in tests config.StreamReadBlockMilliseconds = 100 if redisURL != "" { - config.URL = redisURL + rconfig.URL = redisURL } handler := &mocks.Handler{} @@ -181,7 +186,7 @@ func TestRedisBroadcasterAcksClaims(t *testing.T) { closed = true // Close the connection to prevent consumer from ack-ing the message broadcaster.client.Close() - broadcaster.reconnectAttempt = config.MaxReconnectAttempts + 1 + broadcaster.reconnectAttempt = config.Redis.MaxReconnectAttempts + 1 } }) @@ -285,3 +290,24 @@ func waitRedisStreamConsumers(client rueidis.Client, count int) error { attempts++ } } + +func TestRedisConfig__ToToml(t *testing.T) { + config := NewRedisConfig() + config.Stream = "test_stream" + config.Group = "test_group" + config.StreamReadBlockMilliseconds = 3000 + + tomlStr := config.ToToml() + + assert.Contains(t, tomlStr, "stream = \"test_stream\"") + assert.Contains(t, tomlStr, "group = \"test_group\"") + assert.Contains(t, tomlStr, "stream_read_block_milliseconds = 3000") + + // Round-trip test + config2 := NewRedisConfig() + + _, err := toml.Decode(tomlStr, &config2) + require.NoError(t, err) + + assert.Equal(t, config, config2) +} diff --git a/cli/options.go b/cli/options.go index 22185f82..a9417f15 100644 --- a/cli/options.go +++ b/cli/options.go @@ -247,6 +247,27 @@ Use shutdown_timeout instead.`) c.SSE.AllowedOrigins = c.Server.AllowedOrigins c.HTTPBroadcast.CORSHosts = c.Server.AllowedOrigins + // Propagate Redis and NATS configs to components + if c.RedisBroadcast.Redis == nil { + c.RedisBroadcast.Redis = &c.Redis + } + + if c.LegacyRedisBroadcast.Redis == nil { + c.LegacyRedisBroadcast.Redis = &c.Redis + } + + if c.NATSBroadcast.NATS == nil { + c.NATSBroadcast.NATS = &c.NATS + } + + if c.RedisPubSub.Redis == nil { + c.RedisPubSub.Redis = &c.Redis + } + + if c.NATSPubSub.NATS == nil { + c.NATSPubSub.NATS = &c.NATS + } + if turboRailsKey != "" { fmt.Println(`DEPRECATION WARNING: turbo_rails_key option is deprecated and will be removed in the next major release of anycable-go. @@ -568,6 +589,21 @@ func broadcastCLIFlags(c *config.Config, adapters *string) []cli.Flag { Value: c.PubSubAdapter, Destination: &c.PubSubAdapter, }, + + &cli.StringFlag{ + Name: "redis_channel", + Usage: "Redis channel for broadcasts", + Value: c.LegacyRedisBroadcast.Channel, + Destination: &c.LegacyRedisBroadcast.Channel, + }, + + &cli.StringFlag{ + Name: "nats_channel", + Usage: "NATS channel for broadcasts", + Value: c.NATSBroadcast.Channel, + Destination: &c.NATSBroadcast.Channel, + }, + &cli.IntFlag{ Name: "hub_gopool_size", Usage: "The size of the goroutines pool to broadcast messages", @@ -612,13 +648,6 @@ func redisCLIFlags(c *config.Config) []cli.Flag { Destination: &c.Redis.URL, }, - &cli.StringFlag{ - Name: "redis_channel", - Usage: "Redis channel for broadcasts", - Value: c.Redis.Channel, - Destination: &c.Redis.Channel, - }, - &cli.StringFlag{ Name: "redis_sentinels", Usage: "Comma separated list of sentinel hosts, format: 'hostname:port,..'", @@ -701,13 +730,6 @@ func natsCLIFlags(c *config.Config) []cli.Flag { Destination: &c.NATS.Servers, }, - &cli.StringFlag{ - Name: "nats_channel", - Usage: "NATS channel for broadcasts", - Value: c.NATS.Channel, - Destination: &c.NATS.Channel, - }, - &cli.BoolFlag{ Name: "nats_dont_randomize_servers", Usage: "Pass this option to disable NATS servers randomization during (re-)connect", diff --git a/cli/runner_options.go b/cli/runner_options.go index a278126f..3c4f1030 100644 --- a/cli/runner_options.go +++ b/cli/runner_options.go @@ -75,13 +75,13 @@ func WithDefaultBroadcaster() Option { hb := broadcast.NewHTTPBroadcaster(h, &c.HTTPBroadcast, l) broadcasters = append(broadcasters, hb) case "redis": - rb := broadcast.NewLegacyRedisBroadcaster(h, &c.Redis, l) + rb := broadcast.NewLegacyRedisBroadcaster(h, &c.LegacyRedisBroadcast, l) broadcasters = append(broadcasters, rb) case "redisx": - rb := broadcast.NewRedisBroadcaster(h, &c.Redis, l) + rb := broadcast.NewRedisBroadcaster(h, &c.RedisBroadcast, l) broadcasters = append(broadcasters, rb) case "nats": - nb := broadcast.NewLegacyNATSBroadcaster(h, &c.NATS, l) + nb := broadcast.NewLegacyNATSBroadcaster(h, &c.NATSBroadcast, l) broadcasters = append(broadcasters, nb) default: return broadcasters, errorx.IllegalArgument.New("Unsupported broadcast adapter: %s", adapter) @@ -110,9 +110,9 @@ func WithDefaultSubscriber() Option { case "": return pubsub.NewLegacySubscriber(h), nil case "redis": - return pubsub.NewRedisSubscriber(h, &c.Redis, l) + return pubsub.NewRedisSubscriber(h, &c.RedisPubSub, l) case "nats": - return pubsub.NewNATSSubscriber(h, &c.NATS, l) + return pubsub.NewNATSSubscriber(h, &c.NATSPubSub, l) } return nil, errorx.IllegalArgument.New("Unsupported subscriber adapter: %s", c.PubSubAdapter) diff --git a/config/config.go b/config/config.go index a8dd2efb..58f31406 100644 --- a/config/config.go +++ b/config/config.go @@ -14,6 +14,7 @@ import ( "github.com/anycable/anycable-go/metrics" nconfig "github.com/anycable/anycable-go/nats" "github.com/anycable/anycable-go/node" + "github.com/anycable/anycable-go/pubsub" rconfig "github.com/anycable/anycable-go/redis" "github.com/anycable/anycable-go/rpc" "github.com/anycable/anycable-go/server" @@ -43,7 +44,12 @@ type Config struct { RPC rpc.Config Broker broker.Config Redis rconfig.RedisConfig + LegacyRedisBroadcast broadcast.LegacyRedisConfig + RedisBroadcast broadcast.RedisConfig + NATSBroadcast broadcast.LegacyNATSConfig HTTPBroadcast broadcast.HTTPConfig + RedisPubSub pubsub.RedisConfig + NATSPubSub pubsub.NATSConfig NATS nconfig.NATSConfig DisconnectorDisabled bool DisconnectQueue node.DisconnectQueueConfig @@ -64,21 +70,26 @@ func NewConfig() Config { ID: id, Server: server.NewConfig(), // TODO(v2.0): Make HTTP default - BroadcastAdapters: []string{"http", "redis"}, - Broker: broker.NewConfig(), - Log: logger.NewConfig(), - App: node.NewConfig(), - WS: ws.NewConfig(), - Metrics: metrics.NewConfig(), - RPC: rpc.NewConfig(), - Redis: rconfig.NewRedisConfig(), - HTTPBroadcast: broadcast.NewHTTPConfig(), - NATS: nconfig.NewNATSConfig(), - DisconnectQueue: node.NewDisconnectQueueConfig(), - JWT: identity.NewJWTConfig(""), - EmbeddedNats: enats.NewConfig(), - SSE: sse.NewConfig(), - Streams: streams.NewConfig(), + BroadcastAdapters: []string{"http", "redis"}, + Broker: broker.NewConfig(), + Log: logger.NewConfig(), + App: node.NewConfig(), + WS: ws.NewConfig(), + Metrics: metrics.NewConfig(), + RPC: rpc.NewConfig(), + Redis: rconfig.NewRedisConfig(), + RedisBroadcast: broadcast.NewRedisConfig(), + LegacyRedisBroadcast: broadcast.NewLegacyRedisConfig(), + NATSBroadcast: broadcast.NewLegacyNATSConfig(), + HTTPBroadcast: broadcast.NewHTTPConfig(), + RedisPubSub: pubsub.NewRedisConfig(), + NATSPubSub: pubsub.NewNATSConfig(), + NATS: nconfig.NewNATSConfig(), + DisconnectQueue: node.NewDisconnectQueueConfig(), + JWT: identity.NewJWTConfig(""), + EmbeddedNats: enats.NewConfig(), + SSE: sse.NewConfig(), + Streams: streams.NewConfig(), } return config @@ -198,8 +209,19 @@ func (c Config) ToToml() string { result.WriteString("# NATS configuration\n[nats]\n") result.WriteString(c.NATS.ToToml()) - result.WriteString("# Broadcasting configuration\n[http_broadcast]\n") + result.WriteString("# Broadcast adapters configuration\n[http_broadcast]\n") result.WriteString(c.HTTPBroadcast.ToToml()) + result.WriteString("[redis_stream_broadcast]\n") + result.WriteString(c.RedisBroadcast.ToToml()) + result.WriteString("[redis_pubsub_broadcast]\n") + result.WriteString(c.LegacyRedisBroadcast.ToToml()) + result.WriteString("[nats_broadcast]\n") + result.WriteString(c.NATSBroadcast.ToToml()) + + result.WriteString("# Pub/sub adapters configuration\n[redis_pubsub]\n") + result.WriteString(c.RedisPubSub.ToToml()) + result.WriteString("[nats_pubsub]\n") + result.WriteString(c.NATSPubSub.ToToml()) result.WriteString("# Metrics configuration\n[metrics]\n") result.WriteString(c.Metrics.ToToml()) diff --git a/nats/config.go b/nats/config.go index e278596b..47367d9d 100644 --- a/nats/config.go +++ b/nats/config.go @@ -9,19 +9,14 @@ import ( type NATSConfig struct { Servers string `toml:"servers"` - Channel string `toml:"channel"` DontRandomizeServers bool `toml:"dont_randomize_servers"` MaxReconnectAttempts int `toml:"max_reconnect_attempts"` - // Internal channel name for node-to-node broadcasting - InternalChannel string `toml:"internal_channel"` } func NewNATSConfig() NATSConfig { return NATSConfig{ Servers: natsgo.DefaultURL, - Channel: "__anycable__", MaxReconnectAttempts: 5, - InternalChannel: "__anycable_internal__", } } @@ -31,9 +26,6 @@ func (c NATSConfig) ToToml() string { result.WriteString("# NATS server URLs (comma-separated)\n") result.WriteString(fmt.Sprintf("servers = \"%s\"\n", c.Servers)) - result.WriteString("# Channel name for legacy broadasting\n") - result.WriteString(fmt.Sprintf("channel = \"%s\"\n", c.Channel)) - result.WriteString("# Don't randomize servers during connection\n") if c.DontRandomizeServers { result.WriteString("dont_randomize_servers = true\n") @@ -44,9 +36,6 @@ func (c NATSConfig) ToToml() string { result.WriteString("# Max number of reconnect attempts\n") result.WriteString(fmt.Sprintf("max_reconnect_attempts = %d\n", c.MaxReconnectAttempts)) - result.WriteString("# Channel name for pub/sub (node-to-node)\n") - result.WriteString(fmt.Sprintf("internal_channel = \"%s\"\n", c.InternalChannel)) - result.WriteString("\n") return result.String() diff --git a/nats/config_test.go b/nats/config_test.go index 58926188..b62401b0 100644 --- a/nats/config_test.go +++ b/nats/config_test.go @@ -11,18 +11,14 @@ import ( func TestNATSConfig_ToToml(t *testing.T) { conf := NewNATSConfig() conf.Servers = "nats://localhost:4222" - conf.Channel = "test_channel" conf.DontRandomizeServers = true conf.MaxReconnectAttempts = 10 - conf.InternalChannel = "test_internal_channel" tomlStr := conf.ToToml() assert.Contains(t, tomlStr, "servers = \"nats://localhost:4222\"") - assert.Contains(t, tomlStr, "channel = \"test_channel\"") assert.Contains(t, tomlStr, "dont_randomize_servers = true") assert.Contains(t, tomlStr, "max_reconnect_attempts = 10") - assert.Contains(t, tomlStr, "internal_channel = \"test_internal_channel\"") // Round-trip test conf2 := NewNATSConfig() diff --git a/pubsub/nats.go b/pubsub/nats.go index 9af3e331..6cd17827 100644 --- a/pubsub/nats.go +++ b/pubsub/nats.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log/slog" + "strings" "sync" "github.com/anycable/anycable-go/common" @@ -14,9 +15,29 @@ import ( "github.com/nats-io/nats.go" ) +type NATSConfig struct { + Channel string `toml:"channel"` + NATS *nconfig.NATSConfig +} + +func NewNATSConfig() NATSConfig { + return NATSConfig{ + Channel: "__anycable_internal__", + } +} + +func (c NATSConfig) ToToml() string { + var result strings.Builder + result.WriteString(fmt.Sprintf("channel = \"%s\"\n", c.Channel)) + + result.WriteString("\n") + + return result.String() +} + type NATSSubscriber struct { node Handler - config *nconfig.NATSConfig + config *NATSConfig conn *nats.Conn @@ -29,7 +50,7 @@ type NATSSubscriber struct { var _ Subscriber = (*NATSSubscriber)(nil) // NewNATSSubscriber creates a NATS subscriber using pub/sub -func NewNATSSubscriber(node Handler, config *nconfig.NATSConfig, l *slog.Logger) (*NATSSubscriber, error) { +func NewNATSSubscriber(node Handler, config *NATSConfig, l *slog.Logger) (*NATSSubscriber, error) { return &NATSSubscriber{ node: node, config: config, @@ -41,7 +62,7 @@ func NewNATSSubscriber(node Handler, config *nconfig.NATSConfig, l *slog.Logger) func (s *NATSSubscriber) Start(done chan (error)) error { connectOptions := []nats.Option{ nats.RetryOnFailedConnect(true), - nats.MaxReconnects(s.config.MaxReconnectAttempts), + nats.MaxReconnects(s.config.NATS.MaxReconnectAttempts), nats.DisconnectErrHandler(func(nc *nats.Conn, err error) { if err != nil { s.log.Warn("connection failed", "error", err) @@ -52,21 +73,21 @@ func (s *NATSSubscriber) Start(done chan (error)) error { }), } - if s.config.DontRandomizeServers { + if s.config.NATS.DontRandomizeServers { connectOptions = append(connectOptions, nats.DontRandomize()) } - nc, err := nats.Connect(s.config.Servers, connectOptions...) + nc, err := nats.Connect(s.config.NATS.Servers, connectOptions...) if err != nil { return err } - s.log.Info(fmt.Sprintf("Starting NATS pub/sub: %s", s.config.Servers)) + s.log.Info(fmt.Sprintf("Starting NATS pub/sub: %s", s.config.NATS.Servers)) s.conn = nc - s.Subscribe(s.config.InternalChannel) + s.Subscribe(s.config.Channel) return nil } @@ -120,7 +141,7 @@ func (s *NATSSubscriber) Broadcast(msg *common.StreamMessage) { } func (s *NATSSubscriber) BroadcastCommand(cmd *common.RemoteCommandMessage) { - s.Publish(s.config.InternalChannel, cmd) + s.Publish(s.config.Channel, cmd) } func (s *NATSSubscriber) Publish(stream string, msg interface{}) { diff --git a/pubsub/nats_test.go b/pubsub/nats_test.go index 68d455df..d3d6b489 100644 --- a/pubsub/nats_test.go +++ b/pubsub/nats_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/BurntSushi/toml" "github.com/anycable/anycable-go/common" "github.com/anycable/anycable-go/enats" "github.com/anycable/anycable-go/nats" @@ -18,13 +19,32 @@ import ( nats_server "github.com/nats-io/nats.go" ) +func TestNATSConfig__ToToml(t *testing.T) { + conf := NewNATSConfig() + conf.Channel = "_test_" + + tomlStr := conf.ToToml() + + assert.Contains(t, tomlStr, "channel = \"_test_\"") + + // Round-trip test + conf2 := NewNATSConfig() + + _, err := toml.Decode(tomlStr, &conf2) + require.NoError(t, err) + + assert.Equal(t, conf, conf2) +} + func TestNATSCommon(t *testing.T) { server := buildNATSServer() err := server.Start() require.NoError(t, err) defer server.Shutdown(context.Background()) // nolint:errcheck - config := nats.NewNATSConfig() + nconfig := nats.NewNATSConfig() + config := NewNATSConfig() + config.NATS = &nconfig SharedSubscriberTests(t, func(handler *TestHandler) Subscriber { sub, err := NewNATSSubscriber(handler, &config, slog.Default()) @@ -44,7 +64,9 @@ func TestNATSReconnect(t *testing.T) { defer server.Shutdown(context.Background()) // nolint:errcheck handler := NewTestHandler() - config := nats.NewNATSConfig() + nconfig := nats.NewNATSConfig() + config := NewNATSConfig() + config.NATS = &nconfig subscriber, err := NewNATSSubscriber(handler, &config, slog.Default()) require.NoError(t, err) @@ -93,7 +115,7 @@ func waitNATSSubscription(subscriber Subscriber, stream string) error { } if stream == "internal" { - stream = s.config.InternalChannel + stream = s.config.Channel } unsubscribing := false diff --git a/pubsub/redis.go b/pubsub/redis.go index fac02fae..7d09f129 100644 --- a/pubsub/redis.go +++ b/pubsub/redis.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log/slog" + "strings" "sync" "time" @@ -32,9 +33,29 @@ type subscriptionEntry struct { id string } +type RedisConfig struct { + Channel string `toml:"channel"` + Redis *rconfig.RedisConfig +} + +func NewRedisConfig() RedisConfig { + return RedisConfig{ + Channel: "__anycable_internal__", + } +} + +func (c RedisConfig) ToToml() string { + var result strings.Builder + result.WriteString(fmt.Sprintf("channel = \"%s\"\n", c.Channel)) + + result.WriteString("\n") + + return result.String() +} + type RedisSubscriber struct { node Handler - config *rconfig.RedisConfig + config *RedisConfig client rueidis.Client clientOptions *rueidis.ClientOption @@ -59,8 +80,8 @@ type RedisSubscriber struct { var _ Subscriber = (*RedisSubscriber)(nil) // NewRedisSubscriber creates a Redis subscriber using pub/sub -func NewRedisSubscriber(node Handler, config *rconfig.RedisConfig, l *slog.Logger) (*RedisSubscriber, error) { - options, err := config.ToRueidisOptions() +func NewRedisSubscriber(node Handler, config *RedisConfig, l *slog.Logger) (*RedisSubscriber, error) { + options, err := config.Redis.ToRueidisOptions() if err != nil { return nil, err @@ -80,17 +101,17 @@ func NewRedisSubscriber(node Handler, config *rconfig.RedisConfig, l *slog.Logge } func (s *RedisSubscriber) Start(done chan (error)) error { - if s.config.IsSentinel() { //nolint:gocritic - s.log.Info(fmt.Sprintf("Starting Redis pub/sub (sentinels): %v", s.config.Hostnames())) - } else if s.config.IsCluster() { - s.log.Info(fmt.Sprintf("Starting Redis pub/sub (cluster): %v", s.config.Hostnames())) + if s.config.Redis.IsSentinel() { //nolint:gocritic + s.log.Info(fmt.Sprintf("Starting Redis pub/sub (sentinels): %v", s.config.Redis.Hostnames())) + } else if s.config.Redis.IsCluster() { + s.log.Info(fmt.Sprintf("Starting Redis pub/sub (cluster): %v", s.config.Redis.Hostnames())) } else { - s.log.Info(fmt.Sprintf("Starting Redis pub/sub: %s", s.config.Hostname())) + s.log.Info(fmt.Sprintf("Starting Redis pub/sub: %s", s.config.Redis.Hostname())) } // Add internal channel to subscriptions s.subMu.Lock() - s.subscriptions[s.config.InternalChannel] = &subscriptionEntry{id: s.config.InternalChannel} + s.subscriptions[s.config.Channel] = &subscriptionEntry{id: s.config.Channel} s.subMu.Unlock() go s.runPubSub(done) @@ -146,7 +167,7 @@ func (s *RedisSubscriber) Broadcast(msg *common.StreamMessage) { } func (s *RedisSubscriber) BroadcastCommand(cmd *common.RemoteCommandMessage) { - s.Publish(s.config.InternalChannel, cmd) + s.Publish(s.config.Channel, cmd) } func (s *RedisSubscriber) Publish(stream string, msg interface{}) { @@ -202,7 +223,7 @@ func (s *RedisSubscriber) runPubSub(done chan (error)) { wait := client.SetPubSubHooks(rueidis.PubSubHooks{ OnSubscription: func(m rueidis.PubSubSubscription) { - if m.Kind == "subscribe" && m.Channel == s.config.InternalChannel { + if m.Kind == "subscribe" && m.Channel == s.config.Channel { if s.reconnectAttempt > 0 { s.log.Info("reconnected") } else { @@ -264,7 +285,7 @@ func (s *RedisSubscriber) runPubSub(done chan (error)) { } func (s *RedisSubscriber) maybeReconnect(done chan (error)) { - if s.reconnectAttempt >= s.config.MaxReconnectAttempts { + if s.reconnectAttempt >= s.config.Redis.MaxReconnectAttempts { done <- errors.New("failed to reconnect to Redis: attempts exceeded") //nolint:stylecheck return } diff --git a/pubsub/redis_test.go b/pubsub/redis_test.go index 6768ce00..628c947b 100644 --- a/pubsub/redis_test.go +++ b/pubsub/redis_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/BurntSushi/toml" "github.com/anycable/anycable-go/common" rconfig "github.com/anycable/anycable-go/redis" "github.com/redis/rueidis" @@ -24,10 +25,12 @@ var ( // Check if Redis is available and skip tests otherwise func init() { - config := rconfig.NewRedisConfig() + rconfig := rconfig.NewRedisConfig() + config := NewRedisConfig() + config.Redis = &rconfig if redisURL != "" { - config.URL = redisURL + rconfig.URL = redisURL } subscriber, err := NewRedisSubscriber(nil, &config, slog.Default()) @@ -45,7 +48,7 @@ func init() { err = subscriber.initClient() if err != nil { - fmt.Printf("No Redis detected at %s: %v", config.URL, err) + fmt.Printf("No Redis detected at %s: %v", rconfig.URL, err) return } @@ -58,16 +61,35 @@ func init() { redisAvailable = err == nil } +func TestRedisConfig__ToToml(t *testing.T) { + conf := NewRedisConfig() + conf.Channel = "_test_" + + tomlStr := conf.ToToml() + + assert.Contains(t, tomlStr, "channel = \"_test_\"") + + // Round-trip test + conf2 := NewRedisConfig() + + _, err := toml.Decode(tomlStr, &conf2) + require.NoError(t, err) + + assert.Equal(t, conf, conf2) +} + func TestRedisCommon(t *testing.T) { if !redisAvailable { t.Skip("Skipping Redis tests: no Redis available") return } - config := rconfig.NewRedisConfig() + rconfig := rconfig.NewRedisConfig() + config := NewRedisConfig() + config.Redis = &rconfig if redisURL != "" { - config.URL = redisURL + rconfig.URL = redisURL } SharedSubscriberTests(t, func(handler *TestHandler) Subscriber { @@ -89,10 +111,12 @@ func TestRedisReconnect(t *testing.T) { } handler := NewTestHandler() - config := rconfig.NewRedisConfig() + rconfig := rconfig.NewRedisConfig() + config := NewRedisConfig() + config.Redis = &rconfig if redisURL != "" { - config.URL = redisURL + rconfig.URL = redisURL } subscriber, err := NewRedisSubscriber(handler, &config, slog.Default()) @@ -135,7 +159,7 @@ func waitRedisSubscription(subscriber Subscriber, stream string) error { s := subscriber.(*RedisSubscriber) if stream == "internal" { - stream = s.config.InternalChannel + stream = s.config.Channel } unsubscribing := false diff --git a/redis/config.go b/redis/config.go index 1bf4e73c..89deed9e 100644 --- a/redis/config.go +++ b/redis/config.go @@ -14,12 +14,6 @@ type RedisConfig struct { // Redis instance URL or master name in case of sentinels usage // or list of URLs if cluster usage URL string `toml:"url"` - // Redis channel to subscribe to (legacy pub/sub) - Channel string `toml:"channel"` - // Redis stream consumer group name - Group string `toml:"group"` - // Redis stream read wait time in milliseconds - StreamReadBlockMilliseconds int64 `toml:"stream_read_block_milliseconds"` // Internal channel name for node-to-node broadcasting InternalChannel string `toml:"internal_channel"` // List of Redis Sentinel addresses @@ -44,16 +38,13 @@ type RedisConfig struct { // NewRedisConfig builds a new config for Redis pubsub func NewRedisConfig() RedisConfig { return RedisConfig{ - KeepalivePingInterval: 30, - URL: "redis://localhost:6379", - Channel: "__anycable__", - Group: "bx", - StreamReadBlockMilliseconds: 2000, - InternalChannel: "__anycable_internal__", - SentinelDiscoveryInterval: 30, - TLSVerify: false, - MaxReconnectAttempts: 5, - DisableCache: false, + KeepalivePingInterval: 30, + URL: "redis://localhost:6379", + InternalChannel: "__anycable_internal__", + SentinelDiscoveryInterval: 30, + TLSVerify: false, + MaxReconnectAttempts: 5, + DisableCache: false, } } @@ -129,15 +120,6 @@ func (config RedisConfig) ToToml() string { result.WriteString("# or list of URLs if cluster usage\n") result.WriteString(fmt.Sprintf("url = \"%s\"\n", config.URL)) - result.WriteString("# Channel name for legacy broadcasting\n") - result.WriteString(fmt.Sprintf("channel = \"%s\"\n", config.Channel)) - - result.WriteString("# Stream consumer group name for RedisX broadcasting\n") - result.WriteString(fmt.Sprintf("group = \"%s\"\n", config.Group)) - - result.WriteString("# Streams read wait time in milliseconds\n") - result.WriteString(fmt.Sprintf("stream_read_block_milliseconds = %d\n", config.StreamReadBlockMilliseconds)) - result.WriteString("# Channel name for pub/sub (node-to-node)\n") result.WriteString(fmt.Sprintf("internal_channel = \"%s\"\n", config.InternalChannel)) diff --git a/redis/config_test.go b/redis/config_test.go index 31b90119..9d106990 100644 --- a/redis/config_test.go +++ b/redis/config_test.go @@ -170,9 +170,6 @@ func TestInvalidURL(t *testing.T) { func TestRedisConfig__ToToml(t *testing.T) { config := NewRedisConfig() config.URL = "redis://example.com:6379" - config.Channel = "test_channel" - config.Group = "test_group" - config.StreamReadBlockMilliseconds = 3000 config.InternalChannel = "test_internal" config.Sentinels = "sentinel1:26379,sentinel2:26379" config.SentinelDiscoveryInterval = 60 @@ -184,9 +181,6 @@ func TestRedisConfig__ToToml(t *testing.T) { tomlStr := config.ToToml() assert.Contains(t, tomlStr, "url = \"redis://example.com:6379\"") - assert.Contains(t, tomlStr, "channel = \"test_channel\"") - assert.Contains(t, tomlStr, "group = \"test_group\"") - assert.Contains(t, tomlStr, "stream_read_block_milliseconds = 3000") assert.Contains(t, tomlStr, "internal_channel = \"test_internal\"") assert.Contains(t, tomlStr, "sentinels = \"sentinel1:26379,sentinel2:26379\"") assert.Contains(t, tomlStr, "sentinel_discovery_interval = 60")