From 481f435dfe2c568ec2b81b4e43551712d497151c Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 23 Aug 2024 09:45:45 -0600 Subject: [PATCH] treewide: replace gorilla/mux with http.ServeMux Signed-off-by: Sumner Evans --- appservice/appservice.go | 17 ++- appservice/http.go | 12 +-- bridgev2/matrix/connector.go | 5 +- bridgev2/matrix/provisioning.go | 101 ++++++++--------- bridgev2/matrix/publicmedia.go | 11 +- bridgev2/matrixinterface.go | 5 +- crypto/verificationhelper/mockserver_test.go | 53 +++++---- federation/keyserver.go | 47 ++++---- go.mod | 1 - go.sum | 4 +- mediaproxy/mediaproxy.go | 107 +++++++++---------- 11 files changed, 175 insertions(+), 188 deletions(-) diff --git a/appservice/appservice.go b/appservice/appservice.go index 518e1073..3c62be9a 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -19,7 +19,6 @@ import ( "syscall" "time" - "github.com/gorilla/mux" "github.com/gorilla/websocket" "github.com/rs/zerolog" "golang.org/x/net/publicsuffix" @@ -43,7 +42,7 @@ func Create() *AppService { intents: make(map[id.UserID]*IntentAPI), HTTPClient: &http.Client{Timeout: 180 * time.Second, Jar: jar}, StateStore: mautrix.NewMemoryStateStore().(StateStore), - Router: mux.NewRouter(), + Router: http.NewServeMux(), UserAgent: mautrix.DefaultUserAgent, txnIDC: NewTransactionIDCache(128), Live: true, @@ -61,12 +60,12 @@ func Create() *AppService { DefaultHTTPRetries: 4, } - as.Router.HandleFunc("/_matrix/app/v1/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut) - as.Router.HandleFunc("/_matrix/app/v1/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet) - as.Router.HandleFunc("/_matrix/app/v1/users/{userID}", as.GetUser).Methods(http.MethodGet) - as.Router.HandleFunc("/_matrix/app/v1/ping", as.PostPing).Methods(http.MethodPost) - as.Router.HandleFunc("/_matrix/mau/live", as.GetLive).Methods(http.MethodGet) - as.Router.HandleFunc("/_matrix/mau/ready", as.GetReady).Methods(http.MethodGet) + as.Router.HandleFunc("PUT /_matrix/app/v1/transactions/{txnID}", as.PutTransaction) + as.Router.HandleFunc("GET /_matrix/app/v1/rooms/{roomAlias}", as.GetRoom) + as.Router.HandleFunc("GET /_matrix/app/v1/users/{userID}", as.GetUser) + as.Router.HandleFunc("POST /_matrix/app/v1/ping", as.PostPing) + as.Router.HandleFunc("GET /_matrix/mau/live", as.GetLive) + as.Router.HandleFunc("GET /_matrix/mau/ready", as.GetReady) return as } @@ -160,7 +159,7 @@ type AppService struct { QueryHandler QueryHandler StateStore StateStore - Router *mux.Router + Router *http.ServeMux UserAgent string server *http.Server HTTPClient *http.Client diff --git a/appservice/http.go b/appservice/http.go index 661513b4..e724245e 100644 --- a/appservice/http.go +++ b/appservice/http.go @@ -17,7 +17,6 @@ import ( "syscall" "time" - "github.com/gorilla/mux" "github.com/rs/zerolog" "maunium.net/go/mautrix" @@ -101,8 +100,7 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { return } - vars := mux.Vars(r) - txnID := vars["txnID"] + txnID := r.PathValue("txnID") if len(txnID) == 0 { Error{ ErrorCode: ErrNoTransactionID, @@ -258,9 +256,7 @@ func (as *AppService) GetRoom(w http.ResponseWriter, r *http.Request) { return } - vars := mux.Vars(r) - roomAlias := vars["roomAlias"] - ok := as.QueryHandler.QueryAlias(roomAlias) + ok := as.QueryHandler.QueryAlias(r.PathValue("roomAlias")) if ok { WriteBlankOK(w) } else { @@ -277,9 +273,7 @@ func (as *AppService) GetUser(w http.ResponseWriter, r *http.Request) { return } - vars := mux.Vars(r) - userID := id.UserID(vars["userID"]) - ok := as.QueryHandler.QueryUser(userID) + ok := as.QueryHandler.QueryUser(id.UserID(r.PathValue("userID"))) if ok { WriteBlankOK(w) } else { diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 35ef4a08..7ec92d01 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -13,6 +13,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "net/url" "os" "regexp" @@ -21,7 +22,6 @@ import ( "time" "unsafe" - "github.com/gorilla/mux" _ "github.com/lib/pq" "github.com/rs/zerolog" "go.mau.fi/util/dbutil" @@ -222,7 +222,8 @@ func (br *Connector) GetPublicAddress() string { return br.Config.AppService.PublicAddress } -func (br *Connector) GetRouter() *mux.Router { +// TODO switch to http.ServeMux +func (br *Connector) GetRouter() *http.ServeMux { if br.GetPublicAddress() != "" { return br.AS.Router } diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 951e6df1..1725487d 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -17,10 +17,10 @@ import ( "sync" "time" - "github.com/gorilla/mux" "github.com/rs/xid" "github.com/rs/zerolog" "github.com/rs/zerolog/hlog" + "go.mau.fi/util/exhttp" "go.mau.fi/util/jsontime" "go.mau.fi/util/requestlog" @@ -38,7 +38,7 @@ type matrixAuthCacheEntry struct { } type ProvisioningAPI struct { - Router *mux.Router + Router *http.ServeMux br *Connector log zerolog.Logger @@ -82,12 +82,12 @@ func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User { return r.Context().Value(provisioningUserKey).(*bridgev2.User) } -func (prov *ProvisioningAPI) GetRouter() *mux.Router { +func (prov *ProvisioningAPI) GetRouter() *http.ServeMux { return prov.Router } type IProvisioningAPI interface { - GetRouter() *mux.Router + GetRouter() *http.ServeMux GetUser(r *http.Request) *bridgev2.User } @@ -106,50 +106,44 @@ func (prov *ProvisioningAPI) Init() { tp.Dialer.Timeout = 10 * time.Second tp.Transport.ResponseHeaderTimeout = 10 * time.Second tp.Transport.TLSHandshakeTimeout = 10 * time.Second - prov.Router = prov.br.AS.Router.PathPrefix(prov.br.Config.Provisioning.Prefix).Subrouter() - prov.Router.Use(hlog.NewHandler(prov.log)) - prov.Router.Use(hlog.RequestIDHandler("request_id", "Request-Id")) - prov.Router.Use(corsMiddleware) - prov.Router.Use(requestlog.AccessLogger(false)) - prov.Router.Use(prov.AuthMiddleware) - prov.Router.Path("/v3/whoami").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetWhoami) - prov.Router.Path("/v3/login/flows").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLoginFlows) - prov.Router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginStart) - prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginSubmitInput) - prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:display_and_wait}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginWait) - prov.Router.Path("/v3/logout/{loginID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLogout) - prov.Router.Path("/v3/logins").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLogins) - prov.Router.Path("/v3/contacts").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetContactList) - prov.Router.Path("/v3/search_users").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostSearchUsers) - prov.Router.Path("/v3/resolve_identifier/{identifier}").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetResolveIdentifier) - prov.Router.Path("/v3/create_dm/{identifier}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateDM) - prov.Router.Path("/v3/create_group").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateGroup) + + provRouter := http.NewServeMux() + + provRouter.HandleFunc("GET /v3/whoami", prov.GetWhoami) + provRouter.HandleFunc("GET /v3/whoami/flows", prov.GetLoginFlows) + + provRouter.HandleFunc("POST /v3/login/start/{flowID}", prov.PostLoginStart) + provRouter.HandleFunc("POST /v3/login/step/{loginProcessID}/{stepID}/{stepType}", prov.PostLogin) + provRouter.HandleFunc("POST /v3/logout/{loginID}", prov.PostLogout) + provRouter.HandleFunc("GET /v3/logins", prov.GetLogins) + provRouter.HandleFunc("GET /v3/contacts", prov.GetContactList) + provRouter.HandleFunc("POST /v3/search_users", prov.PostSearchUsers) + provRouter.HandleFunc("GET /v3/resolve_identifier/{identifier}", prov.GetResolveIdentifier) + provRouter.HandleFunc("POST /v3/create_dm/{identifier}", prov.PostCreateDM) + provRouter.HandleFunc("POST /v3/create_group", prov.PostCreateGroup) + + var provHandler http.Handler = prov.Router + provHandler = prov.AuthMiddleware(provHandler) + provHandler = requestlog.AccessLogger(false)(provHandler) + provHandler = exhttp.CORSMiddleware(provHandler) + provHandler = hlog.RequestIDHandler("request_id", "Request-Id")(provHandler) + provHandler = hlog.NewHandler(prov.log)(provHandler) + provHandler = http.StripPrefix(prov.br.Config.Provisioning.Prefix, provHandler) + prov.br.AS.Router.Handle(prov.br.Config.Provisioning.Prefix, provHandler) if prov.br.Config.Provisioning.DebugEndpoints { prov.log.Debug().Msg("Enabling debug API at /debug") - r := prov.br.AS.Router.PathPrefix("/debug").Subrouter() - r.Use(prov.DebugAuthMiddleware) - r.HandleFunc("/pprof/cmdline", pprof.Cmdline).Methods(http.MethodGet) - r.HandleFunc("/pprof/profile", pprof.Profile).Methods(http.MethodGet) - r.HandleFunc("/pprof/symbol", pprof.Symbol).Methods(http.MethodGet) - r.HandleFunc("/pprof/trace", pprof.Trace).Methods(http.MethodGet) - r.PathPrefix("/pprof/").HandlerFunc(pprof.Index) + debugRouter := http.NewServeMux() + // TODO do we need to strip prefix here? + debugRouter.HandleFunc("/debug/pprof", pprof.Index) + debugRouter.HandleFunc("GET /debug/pprof/trace", pprof.Trace) + debugRouter.HandleFunc("GET /debug/pprof/symbol", pprof.Symbol) + debugRouter.HandleFunc("GET /debug/pprof/profile", pprof.Profile) + debugRouter.HandleFunc("GET /debug/pprof/cmdline", pprof.Cmdline) + prov.br.AS.Router.Handle("/debug", prov.AuthMiddleware(debugRouter)) } } -func corsMiddleware(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization") - if r.Method == http.MethodOptions { - w.WriteHeader(http.StatusOK) - return - } - handler.ServeHTTP(w, r) - }) -} - func jsonResponse(w http.ResponseWriter, status int, response any) { w.Header().Add("Content-Type", "application/json") w.WriteHeader(status) @@ -270,7 +264,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { } ctx := context.WithValue(r.Context(), provisioningUserKey, user) - if loginID, ok := mux.Vars(r)["loginProcessID"]; ok { + if loginID := r.PathValue("loginProcessID"); loginID != "" { prov.loginsLock.RLock() login, ok := prov.logins[loginID] prov.loginsLock.RUnlock() @@ -285,7 +279,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { login.Lock.Lock() // This will only unlock after the handler runs defer login.Lock.Unlock() - stepID := mux.Vars(r)["stepID"] + stepID := r.PathValue("stepID") if login.NextStep.StepID != stepID { zerolog.Ctx(r.Context()).Warn(). Str("request_step_id", stepID). @@ -297,7 +291,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { }) return } - stepType := mux.Vars(r)["stepType"] + stepType := r.PathValue("stepType") if login.NextStep.Type != bridgev2.LoginStepType(stepType) { zerolog.Ctx(r.Context()).Warn(). Str("request_step_type", stepType). @@ -401,7 +395,7 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque login, err := prov.net.CreateLogin( r.Context(), prov.GetUser(r), - mux.Vars(r)["flowID"], + r.PathValue("flowID"), ) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create login process") @@ -440,6 +434,17 @@ func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *Prov }, bridgev2.DeleteOpts{LogoutRemote: true}) } +func (prov *ProvisioningAPI) PostLogin(w http.ResponseWriter, r *http.Request) { + switch r.PathValue("stepType") { + case "user_input", "cookies": + prov.PostLoginSubmitInput(w, r) + case "display_and_wait": + prov.PostLoginWait(w, r) + default: + panic("Impossible state") // checked by the AuthMiddleware + } +} + func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http.Request) { var params map[string]string err := json.NewDecoder(r.Body).Decode(¶ms) @@ -493,7 +498,7 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) { user := prov.GetUser(r) - userLoginID := networkid.UserLoginID(mux.Vars(r)["loginID"]) + userLoginID := networkid.UserLoginID(r.PathValue("loginID")) if userLoginID == "all" { for { login := user.GetDefaultLogin() @@ -596,7 +601,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. }) return } - resp, err := api.ResolveIdentifier(r.Context(), mux.Vars(r)["identifier"], createChat) + resp, err := api.ResolveIdentifier(r.Context(), r.PathValue("identifier"), createChat) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to resolve identifier") RespondWithError(w, err, "Internal error resolving identifier") diff --git a/bridgev2/matrix/publicmedia.go b/bridgev2/matrix/publicmedia.go index 9db5f442..0f3ae6f1 100644 --- a/bridgev2/matrix/publicmedia.go +++ b/bridgev2/matrix/publicmedia.go @@ -16,8 +16,6 @@ import ( "net/http" "time" - "github.com/gorilla/mux" - "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/id" ) @@ -35,7 +33,7 @@ func (br *Connector) initPublicMedia() error { return fmt.Errorf("public media hash length is negative") } br.pubMediaSigKey = []byte(br.Config.PublicMedia.SigningKey) - br.AS.Router.HandleFunc("/_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia).Methods(http.MethodGet) + br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia) return nil } @@ -76,16 +74,15 @@ var proxyHeadersToCopy = []string{ } func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) contentURI := id.ContentURI{ - Homeserver: vars["server"], - FileID: vars["mediaID"], + Homeserver: r.PathValue("server"), + FileID: r.PathValue("mediaID"), } if !contentURI.IsValid() { http.Error(w, "invalid content URI", http.StatusBadRequest) return } - checksum, err := base64.RawURLEncoding.DecodeString(vars["checksum"]) + checksum, err := base64.RawURLEncoding.DecodeString(r.PathValue("checksum")) if err != nil || !hmac.Equal(checksum, br.makePublicMediaChecksum(contentURI)) { http.Error(w, "invalid base64 in checksum", http.StatusBadRequest) return diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 66d39403..5e1addf1 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -10,11 +10,10 @@ import ( "context" "fmt" "io" + "net/http" "os" "time" - "github.com/gorilla/mux" - "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2/database" @@ -58,7 +57,7 @@ type MatrixConnector interface { type MatrixConnectorWithServer interface { GetPublicAddress() string - GetRouter() *mux.Router + GetRouter() *http.ServeMux } type MatrixConnectorWithPublicMedia interface { diff --git a/crypto/verificationhelper/mockserver_test.go b/crypto/verificationhelper/mockserver_test.go index e35f51b2..6fdf2c85 100644 --- a/crypto/verificationhelper/mockserver_test.go +++ b/crypto/verificationhelper/mockserver_test.go @@ -12,11 +12,9 @@ import ( "io" "net/http" "net/http/httptest" - "net/url" "strings" "testing" - "github.com/gorilla/mux" "github.com/rs/zerolog/log" "github.com/stretchr/testify/require" "go.mau.fi/util/random" @@ -42,19 +40,19 @@ type mockServer struct { UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys } -func DecodeVarsMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - var err error - for k, v := range vars { - vars[k], err = url.PathUnescape(v) - if err != nil { - panic(err) - } - } - next.ServeHTTP(w, r) - }) -} +// func DecodeVarsMiddleware(next http.Handler) http.Handler { +// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// vars := mux.Vars(r) +// var err error +// for k, v := range vars { +// vars[k], err = url.PathUnescape(v) +// if err != nil { +// panic(err) +// } +// } +// next.ServeHTTP(w, r) +// }) +// } func createMockServer(t *testing.T) *mockServer { t.Helper() @@ -69,15 +67,14 @@ func createMockServer(t *testing.T) *mockServer { UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, } - router := mux.NewRouter().SkipClean(true).StrictSlash(false).UseEncodedPath() - router.Use(DecodeVarsMiddleware) - router.HandleFunc("/_matrix/client/v3/login", server.postLogin).Methods(http.MethodPost) - router.HandleFunc("/_matrix/client/v3/keys/query", server.postKeysQuery).Methods(http.MethodPost) - router.HandleFunc("/_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice).Methods(http.MethodPut) - router.HandleFunc("/_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData).Methods(http.MethodPut) - router.HandleFunc("/_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload).Methods(http.MethodPost) - router.HandleFunc("/_matrix/client/v3/keys/signatures/upload", server.emptyResp).Methods(http.MethodPost) - router.HandleFunc("/_matrix/client/v3/keys/upload", server.postKeysUpload).Methods(http.MethodPost) + router := http.NewServeMux() + router.HandleFunc("POST /_matrix/client/v3/login", server.postLogin) + router.HandleFunc("POST /_matrix/client/v3/keys/query", server.postKeysQuery) + router.HandleFunc("PUT /_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice) + router.HandleFunc("PUT /_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData) + router.HandleFunc("POST /_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload) + router.HandleFunc("POST /_matrix/client/v3/keys/signatures/upload", server.emptyResp) + router.HandleFunc("POST /_matrix/client/v3/keys/upload", server.postKeysUpload) server.Server = httptest.NewServer(router) return &server @@ -118,10 +115,9 @@ func (s *mockServer) postLogin(w http.ResponseWriter, r *http.Request) { } func (s *mockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) var req mautrix.ReqSendToDevice json.NewDecoder(r.Body).Decode(&req) - evtType := event.Type{Type: vars["type"], Class: event.ToDeviceEventType} + evtType := event.Type{Type: r.PathValue("type"), Class: event.ToDeviceEventType} for user, devices := range req.Messages { for device, content := range devices { @@ -140,9 +136,8 @@ func (s *mockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) { } func (s *mockServer) putAccountData(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - userID := id.UserID(vars["userID"]) - eventType := event.Type{Type: vars["type"], Class: event.AccountDataEventType} + userID := id.UserID(r.PathValue("userID")) + eventType := event.Type{Type: r.PathValue("type"), Class: event.AccountDataEventType} jsonData, _ := io.ReadAll(r.Body) if _, ok := s.AccountData[userID]; !ok { diff --git a/federation/keyserver.go b/federation/keyserver.go index 3e74bfdf..7c7f18c7 100644 --- a/federation/keyserver.go +++ b/federation/keyserver.go @@ -13,7 +13,7 @@ import ( "strconv" "time" - "github.com/gorilla/mux" + "go.mau.fi/util/exhttp" "go.mau.fi/util/jsontime" "maunium.net/go/mautrix" @@ -50,25 +50,32 @@ type KeyServer struct { } // Register registers the key server endpoints to the given router. -func (ks *KeyServer) Register(r *mux.Router) { - r.HandleFunc("/.well-known/matrix/server", ks.GetWellKnown).Methods(http.MethodGet) - r.HandleFunc("/_matrix/federation/v1/version", ks.GetServerVersion).Methods(http.MethodGet) - keyRouter := r.PathPrefix("/_matrix/key").Subrouter() - keyRouter.HandleFunc("/v2/server", ks.GetServerKey).Methods(http.MethodGet) - keyRouter.HandleFunc("/v2/query/{serverName}", ks.GetQueryKeys).Methods(http.MethodGet) - keyRouter.HandleFunc("/v2/query", ks.PostQueryKeys).Methods(http.MethodPost) - keyRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - ErrCode: mautrix.MUnrecognized.ErrCode, - Err: "Unrecognized endpoint", - }) - }) - keyRouter.MethodNotAllowedHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusMethodNotAllowed, &mautrix.RespError{ - ErrCode: mautrix.MUnrecognized.ErrCode, - Err: "Invalid method for endpoint", - }) +func (ks *KeyServer) Register(r *http.ServeMux) { + r.HandleFunc("GET /.well-known/matrix/server", ks.GetWellKnown) + r.HandleFunc("GET /_matrix/federation/v1/version", ks.GetServerVersion) + + keyRouter := http.NewServeMux() + keyRouter.HandleFunc("GET /v2/server", ks.GetServerKey) + keyRouter.HandleFunc("GET /v2/query/{serverName}", ks.GetQueryKeys) + keyRouter.HandleFunc("POST /v2/query", ks.PostQueryKeys) + + keyHandler := exhttp.HandleErrors(keyRouter, exhttp.ErrorBodyGenerators{ + NotFound: func() (body []byte) { + body, _ = json.Marshal(&mautrix.RespError{ + ErrCode: mautrix.MUnrecognized.ErrCode, + Err: "Unrecognized endpoint", + }) + return + }, + MethodNotAllowed: func() (body []byte) { + body, _ = json.Marshal(&mautrix.RespError{ + ErrCode: mautrix.MUnrecognized.ErrCode, + Err: "Invalid method for endpoint", + }) + return + }, }) + r.Handle("/_matrix/key", http.StripPrefix("/_matrix/key", keyHandler)) } func jsonResponse(w http.ResponseWriter, code int, data any) { @@ -177,7 +184,7 @@ type GetQueryKeysResponse struct { // // https://spec.matrix.org/v1.9/server-server-api/#get_matrixkeyv2queryservername func (ks *KeyServer) GetQueryKeys(w http.ResponseWriter, r *http.Request) { - serverName := mux.Vars(r)["serverName"] + serverName := r.PathValue("serverName") minimumValidUntilTSString := r.URL.Query().Get("minimum_valid_until_ts") minimumValidUntilTS, err := strconv.ParseInt(minimumValidUntilTSString, 10, 64) if err != nil && minimumValidUntilTSString != "" { diff --git a/go.mod b/go.mod index 8ef08be8..1738e7cf 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,6 @@ toolchain go1.23.2 require ( filippo.io/edwards25519 v1.1.0 github.com/chzyer/readline v1.5.1 - github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.23 diff --git a/go.sum b/go.sum index ac8e03f6..967b2da8 100644 --- a/go.sum +++ b/go.sum @@ -13,8 +13,6 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= -github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= @@ -53,6 +51,8 @@ github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= go.mau.fi/util v0.8.1-0.20241003092848-3b49d3e0b9ee h1:/BGpUK7fzVyFgy5KBiyP7ktEDn20vzz/5FTngrXtIEE= go.mau.fi/util v0.8.1-0.20241003092848-3b49d3e0b9ee/go.mod h1:L9qnqEkhe4KpuYmILrdttKTXL79MwGLyJ4EOskWxO3I= +go.mau.fi/util v0.7.1-0.20240826142731-d642a8a8b6fb h1:5sx2bjPNqkKB/EJsIinnRhXXomMBP2+7nYRIptwDlp4= +go.mau.fi/util v0.7.1-0.20240826142731-d642a8a8b6fb/go.mod h1:WuAOOV0O/otkxGkFUvfv/XE2ztegaoyM15ovS6SYbf4= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index f2591428..e5cb0ead 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -21,8 +21,8 @@ import ( "strings" "time" - "github.com/gorilla/mux" "github.com/rs/zerolog" + "go.mau.fi/util/exhttp" "maunium.net/go/mautrix" "maunium.net/go/mautrix/federation" @@ -60,9 +60,9 @@ type MediaProxy struct { serverName string serverKey *federation.SigningKey - FederationRouter *mux.Router - LegacyMediaRouter *mux.Router - ClientMediaRouter *mux.Router + FederationRouter *http.ServeMux + LegacyMediaRouter *http.ServeMux + ClientMediaRouter *http.ServeMux } func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProxy, error) { @@ -70,7 +70,8 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx if err != nil { return nil, err } - return &MediaProxy{ + + mp := &MediaProxy{ serverName: serverName, serverKey: parsed, GetMedia: getMedia, @@ -93,7 +94,29 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx Version: strings.TrimPrefix(mautrix.VersionWithCommit, "v"), }, }, - }, nil + FederationRouter: http.NewServeMux(), + LegacyMediaRouter: http.NewServeMux(), + ClientMediaRouter: http.NewServeMux(), + } + + mp.FederationRouter.HandleFunc("GET /v1/media/download/{mediaID}", mp.DownloadMediaFederation) + mp.FederationRouter.HandleFunc("GET /v1/version", mp.KeyServer.GetServerVersion) + addClientRoutes := func(router *http.ServeMux, prefix string) { + router.HandleFunc("GET "+prefix+"/download/{serverName}/{mediaID}", mp.DownloadMedia) + router.HandleFunc("GET "+prefix+"/download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia) + router.HandleFunc("GET "+prefix+"/thumbnail/{serverName}/{mediaID}", mp.DownloadMedia) + router.HandleFunc("PUT "+prefix+"/upload/{serverName}/{mediaID}", mp.UploadNotSupported) + router.HandleFunc("POST "+prefix+"/upload", mp.UploadNotSupported) + router.HandleFunc("POST "+prefix+"/create", mp.UploadNotSupported) + router.HandleFunc("GET "+prefix+"/config", mp.UploadNotSupported) + router.HandleFunc("GET "+prefix+"/preview_url", mp.PreviewURLNotSupported) + } + addClientRoutes(mp.LegacyMediaRouter, "/v3") + addClientRoutes(mp.LegacyMediaRouter, "/r0") + addClientRoutes(mp.LegacyMediaRouter, "/v1") + addClientRoutes(mp.ClientMediaRouter, "") + + return mp, nil } type BasicConfig struct { @@ -123,7 +146,7 @@ type ServerConfig struct { } func (mp *MediaProxy) Listen(cfg ServerConfig) error { - router := mux.NewRouter() + router := http.NewServeMux() mp.RegisterRoutes(router) return http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router) } @@ -140,50 +163,17 @@ func (mp *MediaProxy) DisallowProxying() { mp.ProxyClient = nil } -func (mp *MediaProxy) RegisterRoutes(router *mux.Router) { - if mp.FederationRouter == nil { - mp.FederationRouter = router.PathPrefix("/_matrix/federation").Subrouter() - } - if mp.LegacyMediaRouter == nil { - mp.LegacyMediaRouter = router.PathPrefix("/_matrix/media").Subrouter() - } - if mp.ClientMediaRouter == nil { - mp.ClientMediaRouter = router.PathPrefix("/_matrix/client/v1/media").Subrouter() - } +func (mp *MediaProxy) RegisterRoutes(router *http.ServeMux) { + legacyMediaHandler := exhttp.HandleErrors(mp.LegacyMediaRouter, exhttp.ErrorBodyGenerators{NotFound: mp.UnknownEndpoint, MethodNotAllowed: mp.UnsupportedMethod}) + federationHandler := exhttp.HandleErrors(mp.FederationRouter, exhttp.ErrorBodyGenerators{NotFound: mp.UnknownEndpoint, MethodNotAllowed: mp.UnsupportedMethod}) + clientMediaHandler := exhttp.HandleErrors(mp.ClientMediaRouter, exhttp.ErrorBodyGenerators{NotFound: mp.UnknownEndpoint, MethodNotAllowed: mp.UnsupportedMethod}) - mp.FederationRouter.HandleFunc("/v1/media/download/{mediaID}", mp.DownloadMediaFederation).Methods(http.MethodGet) - mp.FederationRouter.HandleFunc("/v1/version", mp.KeyServer.GetServerVersion).Methods(http.MethodGet) - addClientRoutes := func(router *mux.Router, prefix string) { - router.HandleFunc(prefix+"/download/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet) - router.HandleFunc(prefix+"/download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia).Methods(http.MethodGet) - router.HandleFunc(prefix+"/thumbnail/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet) - router.HandleFunc(prefix+"/upload/{serverName}/{mediaID}", mp.UploadNotSupported).Methods(http.MethodPut) - router.HandleFunc(prefix+"/upload", mp.UploadNotSupported).Methods(http.MethodPost) - router.HandleFunc(prefix+"/create", mp.UploadNotSupported).Methods(http.MethodPost) - router.HandleFunc(prefix+"/config", mp.UploadNotSupported).Methods(http.MethodGet) - router.HandleFunc(prefix+"/preview_url", mp.PreviewURLNotSupported).Methods(http.MethodGet) - } - addClientRoutes(mp.LegacyMediaRouter, "/v3") - addClientRoutes(mp.LegacyMediaRouter, "/r0") - addClientRoutes(mp.LegacyMediaRouter, "/v1") - addClientRoutes(mp.ClientMediaRouter, "") - mp.LegacyMediaRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint) - mp.LegacyMediaRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod) - mp.FederationRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint) - mp.FederationRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod) - mp.ClientMediaRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint) - mp.ClientMediaRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod) - corsMiddleware := func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization") - w.Header().Set("Content-Security-Policy", "sandbox; default-src 'none'; script-src 'none'; plugin-types application/pdf; style-src 'unsafe-inline'; object-src 'self';") - next.ServeHTTP(w, r) - }) - } - mp.LegacyMediaRouter.Use(corsMiddleware) - mp.ClientMediaRouter.Use(corsMiddleware) + legacyMediaHandler = exhttp.CORSMiddleware(legacyMediaHandler) + clientMediaHandler = exhttp.CORSMiddleware(clientMediaHandler) + + router.Handle("/_matrix/federation", http.StripPrefix("/_matrix/federation", federationHandler)) + router.Handle("/_matrix/media", http.StripPrefix("/_matrix/media", legacyMediaHandler)) + router.Handle("/_matrix/client/v1/media", http.StripPrefix("/_matrix/client/v1/media", clientMediaHandler)) mp.KeyServer.Register(router) } @@ -260,7 +250,7 @@ func (err *ResponseError) Error() string { var ErrInvalidMediaIDSyntax = errors.New("invalid media ID syntax") func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaResponse { - mediaID := mux.Vars(r)["mediaID"] + mediaID := r.PathValue("mediaID") resp, err := mp.GetMedia(r.Context(), mediaID) if err != nil { var respError *ResponseError @@ -342,8 +332,7 @@ func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Req func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { ctx := r.Context() log := zerolog.Ctx(ctx) - vars := mux.Vars(r) - if vars["serverName"] != mp.serverName { + if r.PathValue("serverName") != mp.serverName { jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ ErrCode: mautrix.MNotFound.ErrCode, Err: fmt.Sprintf("This is a media proxy at %q, other media downloads are not available here", mp.serverName), @@ -360,7 +349,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { // In any other case, redirect to the URL. isFederated := strings.HasPrefix(r.Header.Get("Authorization"), "X-Matrix") if mp.ProxyClient != nil && (r.URL.Query().Get("allow_redirect") != "true" || (mp.ForceProxyLegacyFederation && isFederated)) { - mp.proxyDownload(ctx, w, urlResp.URL, vars["fileName"]) + mp.proxyDownload(ctx, w, urlResp.URL, r.PathValue("fileName")) return } w.Header().Set("Location", urlResp.URL) @@ -409,16 +398,18 @@ func (mp *MediaProxy) PreviewURLNotSupported(w http.ResponseWriter, r *http.Requ }) } -func (mp *MediaProxy) UnknownEndpoint(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ +func (mp *MediaProxy) UnknownEndpoint() (body []byte) { + body, _ = json.Marshal(&mautrix.RespError{ ErrCode: mautrix.MUnrecognized.ErrCode, Err: "Unrecognized endpoint", }) + return } -func (mp *MediaProxy) UnsupportedMethod(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusMethodNotAllowed, &mautrix.RespError{ +func (mp *MediaProxy) UnsupportedMethod() (body []byte) { + body, _ = json.Marshal(&mautrix.RespError{ ErrCode: mautrix.MUnrecognized.ErrCode, Err: "Invalid method for endpoint", }) + return }