Skip to content
This repository was archived by the owner on Dec 28, 2024. It is now read-only.

Commit

Permalink
refactor: move AllowedOrigins to server.Config
Browse files Browse the repository at this point in the history
  • Loading branch information
palkan committed Oct 9, 2024
1 parent ef7b054 commit 8ea4e9e
Show file tree
Hide file tree
Showing 10 changed files with 31 additions and 38 deletions.
11 changes: 1 addition & 10 deletions broadcast/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,11 @@ func (c HTTPConfig) ToToml() string {
result.WriteString("# secret = \"\"\n")
}

result.WriteString("# Enable CORS headers\n")
result.WriteString("# Enable CORS headers (allowed origins are used as allowed hosts)\n")
if c.AddCORSHeaders {
result.WriteString("cors_headers = true\n")

result.WriteString("# Allowed hosts for CORS (comma-separated)\n")
if c.CORSHosts != "" {
result.WriteString(fmt.Sprintf("cors_hosts = \"%s\"\n", c.CORSHosts))
} else {
result.WriteString("# cors_hosts = \"\"\n")
}
} else {
result.WriteString("# cors_headers = false\n")
result.WriteString("# Allowed hosts for CORS (comma-separated)\n")
result.WriteString("# cors_hosts = \"\"\n")
}

result.WriteString("\n")
Expand Down
2 changes: 0 additions & 2 deletions broadcast/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,13 @@ func TestHTTPConfig__ToToml(t *testing.T) {
conf.Path = "/broadcast"
conf.Secret = ""
conf.AddCORSHeaders = true
conf.CORSHosts = "example.com,test.com"

tomlStr := conf.ToToml()

assert.Contains(t, tomlStr, "port = 8080")
assert.Contains(t, tomlStr, "path = \"/broadcast\"")
assert.Contains(t, tomlStr, "# secret = \"\"")
assert.Contains(t, tomlStr, "cors_headers = true")
assert.Contains(t, tomlStr, "cors_hosts = \"example.com,test.com\"")

// Round-trip test
conf2 := NewHTTPConfig()
Expand Down
18 changes: 10 additions & 8 deletions cli/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,10 @@ and will be removed in the next major release of anycable-go.
Use shutdown_timeout instead.`)
}

c.SSE.AllowedOrigins = c.WS.AllowedOrigins
c.HTTPBroadcast.CORSHosts = c.WS.AllowedOrigins
// Propagate allowed origins to all the components
c.WS.AllowedOrigins = c.Server.AllowedOrigins
c.SSE.AllowedOrigins = c.Server.AllowedOrigins
c.HTTPBroadcast.CORSHosts = c.Server.AllowedOrigins

if turboRailsKey != "" {
fmt.Println(`DEPRECATION WARNING: turbo_rails_key option is deprecated
Expand Down Expand Up @@ -446,6 +448,12 @@ func serverCLIFlags(c *config.Config, path *string) []cli.Flag {
Destination: &c.Server.Port,
},

&cli.StringFlag{
Name: "allowed_origins",
Usage: `Accept requests only from specified origins, e.g., "www.example.com,*example.io". No check is performed if empty`,
Destination: &c.Server.AllowedOrigins,
},

&cli.StringFlag{
Name: "secret",
Usage: "A common secret key used by all features by default",
Expand Down Expand Up @@ -1071,12 +1079,6 @@ func wsCLIFlags(c *config.Config) []cli.Flag {
Destination: &c.WS.EnableCompression,
Hidden: true,
},

&cli.StringFlag{
Name: "allowed_origins",
Usage: `Accept requests only from specified origins, e.g., "www.example.com,*example.io". No check is performed if empty`,
Destination: &c.WS.AllowedOrigins,
},
})
}

Expand Down
6 changes: 3 additions & 3 deletions node/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type Config struct {
// Define when to invoke Disconnect callback
DisconnectMode string `toml:"disconnect_mode"`
// The number of goroutines to use for disconnect calls on shutdown
ShutdownDisconnectPoolSize int `toml:"shutdown_disconnect_pool_size"`
ShutdownDisconnectPoolSize int `toml:"shutdown_disconnect_gopool_size"`
// How often server should send Action Cable ping messages (seconds)
PingInterval int `toml:"ping_interval"`
// How ofter to refresh node stats (seconds)
Expand Down Expand Up @@ -74,8 +74,8 @@ func (c Config) ToToml() string {
result.WriteString("# The number of Go routines to use for broadcasting (server-to-client fan-out)\n")
result.WriteString(fmt.Sprintf("broadcast_gopool_size = %d\n", c.HubGopoolSize))

result.WriteString("# The number of goroutines to use for Disconnect RPC calls on shutdown\n")
result.WriteString(fmt.Sprintf("shutdown_disconnect_pool_size = %d\n", c.ShutdownDisconnectPoolSize))
result.WriteString("# The number of Go routines to use for Disconnect RPC calls on shutdown\n")
result.WriteString(fmt.Sprintf("shutdown_disconnect_gopool_size = %d\n", c.ShutdownDisconnectPoolSize))

result.WriteString("\n")

Expand Down
2 changes: 2 additions & 0 deletions node/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ func TestConfig_ToToml(t *testing.T) {
conf.DisconnectMode = "always"
conf.HubGopoolSize = 100
conf.PingTimestampPrecision = "ns"
conf.ShutdownDisconnectPoolSize = 1024

tomlStr := conf.ToToml()

assert.Contains(t, tomlStr, "disconnect_mode = \"always\"")
assert.Contains(t, tomlStr, "broadcast_gopool_size = 100")
assert.Contains(t, tomlStr, "ping_timestamp_precision = \"ns\"")
assert.Contains(t, tomlStr, "# pong_timeout = 6")
assert.Contains(t, tomlStr, "shutdown_disconnect_gopool_size = 1024")

// Round-trip test
conf2 := NewConfig()
Expand Down
15 changes: 10 additions & 5 deletions server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ import (
)

type Config struct {
Host string `toml:"host"`
Port int `toml:"port"`
MaxConn int `toml:"max_conn"`
HealthPath string `toml:"health_path"`
SSL SSLConfig `toml:"ssl"`
Host string `toml:"host"`
Port int `toml:"port"`
AllowedOrigins string `toml:"allowed_origins"`
MaxConn int `toml:"max_conn"`
HealthPath string `toml:"health_path"`
SSL SSLConfig `toml:"ssl"`
}

func NewConfig() Config {
Expand All @@ -29,6 +30,10 @@ func (c Config) ToToml() string {
result.WriteString(fmt.Sprintf("host = %q\n", c.Host))
result.WriteString("# Port to listen on\n")
result.WriteString(fmt.Sprintf("port = %d\n", c.Port))

result.WriteString("# Allowed origins (a comma-separated list)\n")
result.WriteString(fmt.Sprintf("allowed_origins = \"%s\"\n", c.AllowedOrigins))

result.WriteString("# Maximum number of allowed concurrent connections\n")
if c.MaxConn == 0 {
result.WriteString("# max_conn = 1000\n")
Expand Down
2 changes: 2 additions & 0 deletions server/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@ func TestConfig__ToToml(t *testing.T) {
conf.HealthPath = "/healthz"
conf.SSL.CertPath = "/path/to/cert"
conf.SSL.KeyPath = "/path/to/key"
conf.AllowedOrigins = "http://example.com"

tomlStr := conf.ToToml()

assert.Contains(t, tomlStr, "host = \"local.test\"")
assert.Contains(t, tomlStr, "port = 8082")
assert.Contains(t, tomlStr, "# max_conn = 1000")
assert.Contains(t, tomlStr, "health_path = \"/healthz\"")
assert.Contains(t, tomlStr, "allowed_origins = \"http://example.com\"")
assert.Contains(t, tomlStr, "ssl.cert_path = \"/path/to/cert\"")
assert.Contains(t, tomlStr, "ssl.key_path = \"/path/to/key\"")

Expand Down
6 changes: 2 additions & 4 deletions sse/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@ const (
type Config struct {
Enabled bool `toml:"enabled"`
// Path is the URL path to handle SSE requests
Path string `toml:"path"`
// List of allowed origins for CORS requests
// We inherit it from the ws.Config
AllowedOrigins string
Path string `toml:"path"`
AllowedOrigins string `toml:"-"`
}

// NewConfig creates a new Config with default values.
Expand Down
5 changes: 1 addition & 4 deletions ws/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ type Config struct {
WriteBufferSize int `toml:"write_buffer_size"`
MaxMessageSize int64 `toml:"max_message_size"`
EnableCompression bool `toml:"enable_compression"`
AllowedOrigins string `toml:"allowed_origins"`
AllowedOrigins string `toml:"-"`
}

// NewConfig build a new Config struct
Expand All @@ -27,9 +27,6 @@ func (c Config) ToToml() string {
result.WriteString("# WebSocket endpoint paths\n")
result.WriteString(fmt.Sprintf("paths = [\"%s\"]\n", strings.Join(c.Paths, "\", \"")))

result.WriteString("# Allowed origins (a comma-separated list)\n")
result.WriteString(fmt.Sprintf("allowed_origins = \"%s\"\n", c.AllowedOrigins))

result.WriteString("# Read buffer size\n")
result.WriteString(fmt.Sprintf("read_buffer_size = %d\n", c.ReadBufferSize))

Expand Down
2 changes: 0 additions & 2 deletions ws/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ func TestConfig_ToToml(t *testing.T) {
conf.WriteBufferSize = 2048
conf.MaxMessageSize = 131072
conf.EnableCompression = true
conf.AllowedOrigins = "http://example.com"

tomlStr := conf.ToToml()

Expand All @@ -24,7 +23,6 @@ func TestConfig_ToToml(t *testing.T) {
assert.Contains(t, tomlStr, "write_buffer_size = 2048")
assert.Contains(t, tomlStr, "max_message_size = 131072")
assert.Contains(t, tomlStr, "enable_compression = true")
assert.Contains(t, tomlStr, "allowed_origins = \"http://example.com\"")

// Round-trip test
conf2 := Config{}
Expand Down

0 comments on commit 8ea4e9e

Please sign in to comment.