diff --git a/broadcast/http.go b/broadcast/http.go index ecd5287b..5e68bdf3 100644 --- a/broadcast/http.go +++ b/broadcast/http.go @@ -7,6 +7,7 @@ import ( "log/slog" "net/http" "strconv" + "strings" "github.com/anycable/anycable-go/server" "github.com/anycable/anycable-go/utils" @@ -28,6 +29,11 @@ type HTTPConfig struct { Secret string // SecretBase is a secret used to generate a token if none provided SecretBase string + // AddCORSHeaders enables adding CORS headers (so you can perform broadcast requests from the browser) + // (We mostly need it for Stackblitz) + AddCORSHeaders bool + // CORSHosts contains a list of hostnames for CORS (comma-separated) + CORSHosts string } // NewHTTPConfig builds a new config for HTTP pub/sub @@ -43,13 +49,15 @@ func (c *HTTPConfig) IsSecured() bool { // HTTPBroadcaster represents HTTP broadcaster type HTTPBroadcaster struct { - port int - path string - conf *HTTPConfig - authHeader string - server *server.HTTPServer - node Handler - log *slog.Logger + port int + path string + conf *HTTPConfig + authHeader string + enableCORS bool + allowedHosts []string + server *server.HTTPServer + node Handler + log *slog.Logger } var _ Broadcaster = (*HTTPBroadcaster)(nil) @@ -92,6 +100,15 @@ func (s *HTTPBroadcaster) Prepare() error { s.authHeader = authHeader + if s.conf.AddCORSHeaders { + s.enableCORS = true + if s.conf.CORSHosts != "" { + s.allowedHosts = strings.Split(s.conf.CORSHosts, ",") + } else { + s.allowedHosts = []string{} + } + } + return nil } @@ -119,6 +136,10 @@ func (s *HTTPBroadcaster) Start(done chan (error)) error { verifiedVia = "no authorization" } + if s.enableCORS { + verifiedVia = verifiedVia + ", CORS enabled" + } + s.log.Info(fmt.Sprintf("Accept broadcast requests at %s%s (%s)", s.server.Address(), s.path, verifiedVia)) go func() { @@ -143,6 +164,17 @@ func (s *HTTPBroadcaster) Shutdown(ctx context.Context) error { // Handler processes HTTP requests func (s *HTTPBroadcaster) Handler(w http.ResponseWriter, r *http.Request) { + if s.enableCORS { + // Write CORS headers + server.WriteCORSHeaders(w, r, s.allowedHosts) + + // Respond to preflight requests + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusOK) + return + } + } + if r.Method != "POST" { s.log.Debug("invalid request method", "method", r.Method) w.WriteHeader(422) diff --git a/cli/options.go b/cli/options.go index 06b1672c..b02c2301 100644 --- a/cli/options.go +++ b/cli/options.go @@ -193,6 +193,7 @@ Use shutdown_timeout instead.`) } c.SSE.AllowedOrigins = c.WS.AllowedOrigins + c.HTTPBroadcast.CORSHosts = c.WS.AllowedOrigins if turboRailsKey != "" { fmt.Println(`DEPRECATION WARNING: turbo_rails_key option is deprecated @@ -602,6 +603,12 @@ func httpBroadcastCLIFlags(c *config.Config) []cli.Flag { Destination: &c.HTTPBroadcast.Secret, Hidden: true, }, + + &cli.BoolFlag{ + Name: "http_broadcast_cors", + Destination: &c.HTTPBroadcast.AddCORSHeaders, + Hidden: true, + }, }) }