Skip to content

Commit

Permalink
feat: allow specifying CORS headers for broadcast endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
palkan committed Mar 14, 2024
1 parent ea2e04d commit 2f7d108
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 7 deletions.
46 changes: 39 additions & 7 deletions broadcast/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"log/slog"
"net/http"
"strconv"
"strings"

"github.com/anycable/anycable-go/server"
"github.com/anycable/anycable-go/utils"
Expand All @@ -28,6 +29,11 @@ type HTTPConfig struct {
Secret string
// SecretBase is a secret used to generate a token if none provided
SecretBase string
// AddCORSHeaders enables adding CORS headers (so you can perform broadcast requests from the browser)
// (We mostly need it for Stackblitz)
AddCORSHeaders bool
// CORSHosts contains a list of hostnames for CORS (comma-separated)
CORSHosts string
}

// NewHTTPConfig builds a new config for HTTP pub/sub
Expand All @@ -43,13 +49,15 @@ func (c *HTTPConfig) IsSecured() bool {

// HTTPBroadcaster represents HTTP broadcaster
type HTTPBroadcaster struct {
port int
path string
conf *HTTPConfig
authHeader string
server *server.HTTPServer
node Handler
log *slog.Logger
port int
path string
conf *HTTPConfig
authHeader string
enableCORS bool
allowedHosts []string
server *server.HTTPServer
node Handler
log *slog.Logger
}

var _ Broadcaster = (*HTTPBroadcaster)(nil)
Expand Down Expand Up @@ -92,6 +100,15 @@ func (s *HTTPBroadcaster) Prepare() error {

s.authHeader = authHeader

if s.conf.AddCORSHeaders {
s.enableCORS = true
if s.conf.CORSHosts != "" {
s.allowedHosts = strings.Split(s.conf.CORSHosts, ",")
} else {
s.allowedHosts = []string{}
}
}

return nil
}

Expand Down Expand Up @@ -119,6 +136,10 @@ func (s *HTTPBroadcaster) Start(done chan (error)) error {
verifiedVia = "no authorization"
}

if s.enableCORS {
verifiedVia = verifiedVia + ", CORS enabled"
}

s.log.Info(fmt.Sprintf("Accept broadcast requests at %s%s (%s)", s.server.Address(), s.path, verifiedVia))

go func() {
Expand All @@ -143,6 +164,17 @@ func (s *HTTPBroadcaster) Shutdown(ctx context.Context) error {

// Handler processes HTTP requests
func (s *HTTPBroadcaster) Handler(w http.ResponseWriter, r *http.Request) {
if s.enableCORS {
// Write CORS headers
server.WriteCORSHeaders(w, r, s.allowedHosts)

// Respond to preflight requests
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
return
}
}

if r.Method != "POST" {
s.log.Debug("invalid request method", "method", r.Method)
w.WriteHeader(422)
Expand Down
7 changes: 7 additions & 0 deletions cli/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ Use shutdown_timeout instead.`)
}

c.SSE.AllowedOrigins = c.WS.AllowedOrigins
c.HTTPBroadcast.CORSHosts = c.WS.AllowedOrigins

if turboRailsKey != "" {
fmt.Println(`DEPRECATION WARNING: turbo_rails_key option is deprecated
Expand Down Expand Up @@ -602,6 +603,12 @@ func httpBroadcastCLIFlags(c *config.Config) []cli.Flag {
Destination: &c.HTTPBroadcast.Secret,
Hidden: true,
},

&cli.BoolFlag{
Name: "http_broadcast_cors",
Destination: &c.HTTPBroadcast.AddCORSHeaders,
Hidden: true,
},
})
}

Expand Down

0 comments on commit 2f7d108

Please sign in to comment.