Skip to content

Commit

Permalink
Merge pull request #132 from githubsands/primaryHandler
Browse files Browse the repository at this point in the history
Create primaryHandler object inplace of the duplicate caduceusHandler that handles the primary routes - authorization and notification(notify) of the webhooks
  • Loading branch information
schmidtw authored Mar 12, 2019
2 parents 46bc394 + e957bfa commit ea35303
Show file tree
Hide file tree
Showing 5 changed files with 218 additions and 176 deletions.
87 changes: 8 additions & 79 deletions src/caduceus/caduceus.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,21 @@ package main
import (
"crypto/tls"
"fmt"
"github.com/Comcast/webpa-common/service/servicecfg"
"github.com/go-kit/kit/log/level"
"net/http"
_ "net/http/pprof"
"net/url"
"os"
"os/signal"
"time"

"github.com/Comcast/webpa-common/service/servicecfg"
"github.com/go-kit/kit/log/level"

"github.com/Comcast/webpa-common/concurrent"
"github.com/Comcast/webpa-common/logging"
"github.com/Comcast/webpa-common/secure"
"github.com/Comcast/webpa-common/secure/handler"
"github.com/Comcast/webpa-common/secure/key"
"github.com/Comcast/webpa-common/server"
"github.com/Comcast/webpa-common/webhook"
"github.com/Comcast/webpa-common/webhook/aws"
"github.com/SermoDigital/jose/jwt"
"github.com/gorilla/mux"
"github.com/justinas/alice"
"github.com/spf13/pflag"
"github.com/spf13/viper"
)
Expand All @@ -48,48 +43,6 @@ const (
DEFAULT_KEY_ID = "current"
)

// getValidator returns validator for JWT tokens
func getValidator(v *viper.Viper) (validator secure.Validator, err error) {
var jwtVals []JWTValidator

v.UnmarshalKey("jwtValidators", &jwtVals)

// if a JWTKeys section was supplied, configure a JWS validator
// and append it to the chain of validators
validators := make(secure.Validators, 0, len(jwtVals))

for _, validatorDescriptor := range jwtVals {
var keyResolver key.Resolver
keyResolver, err = validatorDescriptor.Keys.NewResolver()
if err != nil {
validator = validators
return
}

validators = append(
validators,
secure.JWSValidator{
DefaultKeyId: DEFAULT_KEY_ID,
Resolver: keyResolver,
JWTValidators: []*jwt.Validator{validatorDescriptor.Custom.New()},
},
)
}

// TODO: This should really be part of the unmarshalled validators somehow
basicAuth := v.GetStringSlice("authHeader")
for _, authValue := range basicAuth {
validators = append(
validators,
secure.ExactMatchValidator(authValue),
)
}

validator = validators

return
}

// caduceus is the driver function for Caduceus. It performs everything main() would do,
// except for obtaining the command-line arguments (which are passed to it).

Expand Down Expand Up @@ -163,25 +116,12 @@ func caduceus(arguments []string) int {
maxOutstanding: 0,
}

validator, err := getValidator(v)
primaryHandler, err := NewPrimaryHandler(logger, v, serverWrapper)
if err != nil {
fmt.Fprintf(os.Stderr, "Validator error: %v\n", err)
return 1
}

authHandler := handler.AuthorizationHandler{
HeaderName: "Authorization",
ForbiddenStatusCode: 403,
Validator: validator,
Logger: logger,
}

caduceusHandler := alice.New(authHandler.Decorate)

router := mux.NewRouter()

router = configServerRouter(router, caduceusHandler, serverWrapper)

webhookFactory, err := webhook.NewFactory(v)
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating new webhook factory: %s\n", err)
Expand All @@ -191,8 +131,8 @@ func caduceus(arguments []string) int {
webhookFactory.SetExternalUpdate(caduceusSenderWrapper.Update)

// register webhook end points for api
router.Handle("/hook", caduceusHandler.ThenFunc(webhookRegistry.UpdateRegistry))
router.Handle("/hooks", caduceusHandler.ThenFunc(webhookRegistry.GetRegistry))
primaryHandler.HandleFunc("/hook", webhookRegistry.UpdateRegistry)
primaryHandler.HandleFunc("/hooks", webhookRegistry.GetRegistry)

scheme := v.GetString("scheme")
if len(scheme) < 1 {
Expand All @@ -204,9 +144,9 @@ func caduceus(arguments []string) int {
Host: v.GetString("fqdn") + v.GetString("primary.address"),
}

webhookFactory.Initialize(router, selfURL, v.GetString("soa.provider"), webhookHandler, logger, metricsRegistry, nil)
webhookFactory.Initialize(primaryHandler, selfURL, v.GetString("soa.provider"), webhookHandler, logger, metricsRegistry, nil)

_, runnable, done := webPA.Prepare(logger, nil, metricsRegistry, router)
_, runnable, done := webPA.Prepare(logger, nil, metricsRegistry, primaryHandler)

waitGroup, shutdown, err := concurrent.Execute(runnable)
if err != nil {
Expand Down Expand Up @@ -287,17 +227,6 @@ func caduceus(arguments []string) int {
return 0
}

func configServerRouter(router *mux.Router, caduceusHandler alice.Chain, serverWrapper *ServerHandler) *mux.Router {
var singleContentType = func(r *http.Request, _ *mux.RouteMatch) bool {
return len(r.Header["Content-Type"]) == 1 //require single specification for Content-Type Header
}

router.Handle("/api/v3/notify", caduceusHandler.Then(serverWrapper)).Methods("POST").
HeadersRegexp("Content-Type", "application/msgpack").MatcherFunc(singleContentType)

return router
}

func main() {
os.Exit(caduceus(os.Args))
}
85 changes: 0 additions & 85 deletions src/caduceus/caduceus_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,95 +17,10 @@
package main

import (
"github.com/Comcast/webpa-common/logging"
"github.com/Comcast/webpa-common/secure"
"github.com/Comcast/webpa-common/secure/handler"
"github.com/gorilla/mux"
"github.com/justinas/alice"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"net/http"
"net/http/httptest"
"os"
"testing"
)

func TestMain(m *testing.M) {
os.Exit(m.Run())
}

/*
Simply tests that no bad requests make it to the caduceus listener.
*/

func TestMuxServerConfig(t *testing.T) {
assert := assert.New(t)

logger := logging.DefaultLogger()
fakeHandler := new(mockHandler)
fakeHandler.On("HandleRequest", mock.AnythingOfType("int"),
mock.AnythingOfType("*wrp.Message")).Return().Once()

fakeEmptyRequests := new(mockCounter)
fakeErrorRequests := new(mockCounter)
fakeInvalidCount := new(mockCounter)

fakeQueueDepth := new(mockGauge)
fakeQueueDepth.On("Add", mock.AnythingOfType("float64")).Return().Times(2)

serverWrapper := &ServerHandler{
Logger: logger,
caduceusHandler: fakeHandler,
errorRequests: fakeErrorRequests,
emptyRequests: fakeEmptyRequests,
incomingQueueDepthMetric: fakeQueueDepth,
invalidCount: fakeInvalidCount,
}

authHandler := handler.AuthorizationHandler{Validator: nil}
caduceusHandler := alice.New(authHandler.Decorate)
router := configServerRouter(mux.NewRouter(), caduceusHandler, serverWrapper)

t.Run("TestMuxResponseCorrectMSP", func(t *testing.T) {
req := exampleRequest("1234", "application/msgpack", "/api/v3/notify")

req.Header.Set("Content-Type", "application/msgpack")
w := httptest.NewRecorder()

router.ServeHTTP(w, req)
resp := w.Result()

assert.Equal(http.StatusAccepted, resp.StatusCode)
})
}

func TestGetValidator(t *testing.T) {
assert := assert.New(t)

fakeViper := viper.New()

t.Run("TestAuthHeaderNotSet", func(t *testing.T) {
validator, err := getValidator(fakeViper)

assert.Nil(err)

validators := validator.(secure.Validators)
assert.Equal(0, len(validators))
})

t.Run("TestAuthHeaderSet", func(t *testing.T) {
expectedAuthHeader := []string{"Basic xxxxxxx"}
fakeViper.Set("authHeader", expectedAuthHeader)

validator, err := getValidator(fakeViper)

assert.Nil(err)

validators := validator.(secure.Validators)
assert.Equal(1, len(validators))

exactMatchValidator := validators[0].(secure.ExactMatchValidator)
assert.Equal(expectedAuthHeader[0], string(exactMatchValidator))
})
}
14 changes: 2 additions & 12 deletions src/caduceus/caduceus_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@
package main

import (
"time"

"github.com/Comcast/webpa-common/logging"
"github.com/Comcast/webpa-common/secure"
"github.com/Comcast/webpa-common/secure/key"
"github.com/Comcast/webpa-common/wrp"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/metrics"
"time"
)

// Below is the struct we're using to contain the data from a provided config file
Expand All @@ -48,15 +47,6 @@ type SenderConfig struct {
DeliveryInterval time.Duration
}

type JWTValidator struct {
// JWTKeys is used to create the key.Resolver for JWT verification keys
Keys key.ResolverFactory

// Custom is an optional configuration section that defines
// custom rules for validation over and above the standard RFC rules.
Custom secure.JWTValidatorFactory
}

type CaduceusMetricsRegistry interface {
NewCounter(name string) metrics.Counter
NewGauge(name string) metrics.Gauge
Expand Down
102 changes: 102 additions & 0 deletions src/caduceus/primaryHandler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package main

import (
"fmt"
"net/http"

"github.com/Comcast/webpa-common/secure"
"github.com/Comcast/webpa-common/secure/handler"
"github.com/Comcast/webpa-common/secure/key"
"github.com/SermoDigital/jose/jwt"
"github.com/go-kit/kit/log"
"github.com/gorilla/mux"
"github.com/justinas/alice"
"github.com/spf13/viper"
)

const (
baseURI = "api"
version = "v3"
)

type JWTValidator struct {
// JWTKeys is used to create the key.Resolver for JWT verification keys
Keys key.ResolverFactory

// Custom is an optional configuration section that defines
// custom rules for validation over and above the standard RFC rules.
Custom secure.JWTValidatorFactory
}

func NewPrimaryHandler(l log.Logger, v *viper.Viper, sw *ServerHandler) (*mux.Router, error) {
var (
router = mux.NewRouter()
)

validator, err := getValidator(v)
if err != nil {
return nil, err
}

authHandler := handler.AuthorizationHandler{
HeaderName: "Authorization",
ForbiddenStatusCode: 403,
Validator: validator,
Logger: l,
}

authorizationDecorator := alice.New(authHandler.Decorate)

return configServerRouter(router, authorizationDecorator, sw), nil
}

func configServerRouter(router *mux.Router, primaryHandler alice.Chain, serverWrapper *ServerHandler) *mux.Router {
var singleContentType = func(r *http.Request, _ *mux.RouteMatch) bool {
return len(r.Header["Content-Type"]) == 1 //require single specification for Content-Type Header
}

router.Handle("/"+fmt.Sprintf("%s/%s", baseURI, version)+"/notify", primaryHandler.Then(serverWrapper)).Methods("POST").HeadersRegexp("Content-Type", "application/msgpack").MatcherFunc(singleContentType)

return router
}

func getValidator(v *viper.Viper) (validator secure.Validator, err error) {
var jwtVals []JWTValidator

v.UnmarshalKey("jwtValidators", &jwtVals)

// if a JWTKeys section was supplied, configure a JWS validator
// and append it to the chain of validators
validators := make(secure.Validators, 0, len(jwtVals))

for _, validatorDescriptor := range jwtVals {
var keyResolver key.Resolver
keyResolver, err = validatorDescriptor.Keys.NewResolver()
if err != nil {
validator = validators
return
}

validators = append(
validators,
secure.JWSValidator{
DefaultKeyId: DEFAULT_KEY_ID,
Resolver: keyResolver,
JWTValidators: []*jwt.Validator{validatorDescriptor.Custom.New()},
},
)
}

// TODO: This should really be part of the unmarshalled validators somehow
basicAuth := v.GetStringSlice("authHeader")
for _, authValue := range basicAuth {
validators = append(
validators,
secure.ExactMatchValidator(authValue),
)
}

validator = validators

return
}
Loading

0 comments on commit ea35303

Please sign in to comment.