diff --git a/CHANGELOG.md b/CHANGELOG.md index 973d00b0..29070cad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased] +## [v0.3.8] +- Update code to abide by latest API spec in the main repo readme. [#71](https://github.com/xmidt-org/argus/pull/71) + ## [v0.3.7] ### Changed - Changes the PUT creation route to a POST. [#68](https://github.com/xmidt-org/argus/pull/68) @@ -82,7 +85,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [v0.1.0] Tue May 07 2020 Jack Murdock - 0.1.0 - initial creation -[Unreleased]: https://github.com/xmidt-org/argus/compare/v0.3.7...HEAD +[Unreleased]: https://github.com/xmidt-org/argus/compare/v0.3.8...HEAD +[v0.3.8]: https://github.com/xmidt-org/argus/compare/v0.3.7...v0.3.8 [v0.3.7]: https://github.com/xmidt-org/argus/compare/v0.3.6...v0.3.7 [v0.3.6]: https://github.com/xmidt-org/argus/compare/v0.3.5...v0.3.6 [v0.3.5]: https://github.com/xmidt-org/argus/compare/v0.3.4...v0.3.5 diff --git a/argus.yaml b/argus.yaml index 8a904ddd..93f15cff 100644 --- a/argus.yaml +++ b/argus.yaml @@ -84,7 +84,7 @@ servers: dynamo: # endpoint is used to set a custom aws endpoint. # (Optional) - endpoint: "http://localhost:8042" + endpoint: "http://localhost:8000" # table is the name of the table that is already configured with bucket and id as the key. table: "gifnoc" @@ -104,11 +104,6 @@ dynamo: # itemTTL configures the default time based ttls for each item. itemTTL: - # defaultTTL is used if not ttl is provided via the api. - # refer to https://golang.org/pkg/time/#ParseDuration for valid strings. - # (Optional) default: 5m - defaultTTL: "5m" - # maxTTL is limit the maxTTL provided via the api. # refer to https://golang.org/pkg/time/#ParseDuration for valid strings. # (Optional) default: 1y @@ -122,6 +117,18 @@ itemTTL: # WARNING! Be sure to remove this from your production config authHeader: ["dXNlcjpwYXNz"] +# request is a config section related to operation authorization +# and request validation. +request: + authorization: + # adminToken serves as a master key which allows performing operations on any + # item regardless of their ownership status. + adminToken: "Hzu1WpIe7S8G" + + validation: + # maxTTL specifies the cap for the TTL of items when values are specified. + maxTTL: "24h" + # jwtValidator provides Bearer auth configuration jwtValidator: keys: @@ -129,6 +136,7 @@ jwtValidator: uri: "http://sample-jwt-validator-uri/{keyId}" purpose: 0 updateInterval: 604800000000000 + # capabilityCheck provides the details needed for checking an incoming JWT's # capabilities. If the type of check isn't provided, no checking is done. The # type can be "monitor" or "enforce". If it is empty or a different value, no diff --git a/chrysom/client.go b/chrysom/client.go index b4276ce6..898cc409 100644 --- a/chrysom/client.go +++ b/chrysom/client.go @@ -31,19 +31,30 @@ import ( "github.com/go-kit/kit/log/level" "github.com/go-kit/kit/metrics/provider" "github.com/xmidt-org/argus/model" + "github.com/xmidt-org/argus/store" "github.com/xmidt-org/bascule/acquire" ) +// PushResult is a simple type to indicate the result type for the +// PushItem operation. +type PushResult string + +// Types of pushItem successful results. +const ( + CreatedPushResult PushResult = "created" + UpdatedPushResult PushResult = "ok" +) + type ClientConfig struct { HTTPClient *http.Client Bucket string PullInterval time.Duration Address string Auth Auth - DefaultTTL int64 MetricsProvider provider.Provider Logger log.Logger Listener Listener + AdminToken string } type Auth struct { @@ -58,15 +69,15 @@ type loggerGroup struct { } type Client struct { - client *http.Client - ticker *time.Ticker - auth acquire.Acquirer - metrics *measures - listener Listener - bucketName string - remoteStoreAddress string - defaultStoreItemTTL int64 - loggers loggerGroup + client *http.Client + ticker *time.Ticker + auth acquire.Acquirer + metrics *measures + listener Listener + bucketName string + remoteStoreAddress string + loggers loggerGroup + adminToken string } func initLoggers(logger log.Logger) loggerGroup { @@ -93,15 +104,15 @@ func CreateClient(config ClientConfig) (*Client, error) { return nil, err } clientStore := &Client{ - client: config.HTTPClient, - ticker: time.NewTicker(config.PullInterval), - auth: auth, - metrics: initMetrics(config.MetricsProvider), - loggers: initLoggers(config.Logger), - listener: config.Listener, - remoteStoreAddress: config.Address, - defaultStoreItemTTL: config.DefaultTTL, - bucketName: config.Bucket, + client: config.HTTPClient, + ticker: time.NewTicker(config.PullInterval), + auth: auth, + metrics: initMetrics(config.MetricsProvider), + loggers: initLoggers(config.Logger), + listener: config.Listener, + remoteStoreAddress: config.Address, + bucketName: config.Bucket, + adminToken: config.AdminToken, } if config.PullInterval > 0 { @@ -121,18 +132,20 @@ func validateConfig(config *ClientConfig) error { if config.Bucket == "" { config.Bucket = "testing" } - if config.DefaultTTL < 1 { - config.DefaultTTL = 300 - } if config.MetricsProvider == nil { return errors.New("a metrics provider is required") } + if config.PullInterval == 0 { + config.PullInterval = time.Second * 5 + } + if config.Logger == nil { config.Logger = log.NewNopLogger() } return nil } + func determineTokenAcquirer(config ClientConfig) (acquire.Acquirer, error) { defaultAcquirer := &acquire.DefaultAcquirer{} if config.Auth.JWT.AuthURL != "" && config.Auth.JWT.Buffer != 0 && config.Auth.JWT.Timeout != 0 { @@ -146,38 +159,46 @@ func determineTokenAcquirer(config ClientConfig) (acquire.Acquirer, error) { return defaultAcquirer, nil } -func (c *Client) GetItems(owner string) ([]model.Item, error) { +func (c *Client) GetItems(owner string, adminMode bool) ([]model.Item, error) { request, err := http.NewRequest("GET", fmt.Sprintf("%s/api/v1/store/%s", c.remoteStoreAddress, c.bucketName), nil) if err != nil { - return []model.Item{}, err + return nil, err } err = acquire.AddAuth(request, c.auth) if err != nil { - return []model.Item{}, err + return nil, err } if owner != "" { - request.Header.Add("X-Midt-Owner", owner) + request.Header.Set(store.ItemOwnerHeaderKey, owner) + } + + if adminMode { + if c.adminToken == "" { + return nil, errors.New("adminToken needed to run as admin") + } + request.Header.Set(store.AdminTokenHeaderKey, c.adminToken) } + response, err := c.client.Do(request) if err != nil { - return []model.Item{}, err + return nil, err } if response.StatusCode == 404 { return []model.Item{}, nil } if response.StatusCode != 200 { c.loggers.Error.Log("msg", "DB responded with non-200 response for request to get items", "code", response.StatusCode) - return []model.Item{}, errors.New("failed to get items, non 200 statuscode") + return nil, errors.New("failed to get items, non 200 statuscode") } data, err := ioutil.ReadAll(response.Body) if err != nil { - return []model.Item{}, err + return nil, err } body := map[string]model.Item{} err = json.Unmarshal(data, &body) if err != nil { - return []model.Item{}, err + return nil, err } responseData := make([]model.Item, len(body)) @@ -189,57 +210,24 @@ func (c *Client) GetItems(owner string) ([]model.Item, error) { return responseData, nil } -func (c *Client) Push(item model.Item, owner string) (string, error) { +func (c *Client) Push(item model.Item, owner string, adminMode bool) (PushResult, error) { if item.Identifier == "" { return "", errors.New("identifier can't be empty") } - if item.TTL < 1 { - item.TTL = c.defaultStoreItemTTL - } - data, err := json.Marshal(&item) - if err != nil { - return "", err - } - request, err := http.NewRequest("POST", fmt.Sprintf("%s/api/v1/store/%s", c.remoteStoreAddress, c.bucketName), bytes.NewReader(data)) - if err != nil { - return "", err - } - err = acquire.AddAuth(request, c.auth) - if err != nil { - return "", err - } - if owner != "" { - request.Header.Add("X-Midt-Owner", owner) - } - response, err := c.client.Do(request) - if err != nil { - return "", err - } - if response.StatusCode != 200 { - c.loggers.Error.Log("msg", "DB responded with non-200 response for request to add/update an item", "code", response.StatusCode) - return "", errors.New("Failed to put item as DB responded with non-200 statuscode") - } - responsePayload, _ := ioutil.ReadAll(response.Body) - key := model.Key{} - err = json.Unmarshal(responsePayload, &key) - if err != nil { - return "", err - } - return key.ID, nil -} -func (c *Client) Update(item model.Item, id string, owner string) (string, error) { - if item.Identifier == "" { - return "", errors.New("identifier can't be empty") + if item.UUID == "" { + return "", errors.New("uuid can't be empty") } - if item.TTL < 1 { - item.TTL = c.defaultStoreItemTTL + + if item.TTL != nil && *item.TTL < 1 { + return "", errors.New("when provided, TTL must be > 0") } + data, err := json.Marshal(&item) if err != nil { return "", err } - request, err := http.NewRequest("PUT", fmt.Sprintf("%s/api/v1/store/%s/%s", c.remoteStoreAddress, c.bucketName, id), bytes.NewReader(data)) + request, err := http.NewRequest("PUT", fmt.Sprintf("%s/api/v1/store/%s/%s", c.remoteStoreAddress, c.bucketName, item.UUID), bytes.NewReader(data)) if err != nil { return "", err } @@ -247,29 +235,36 @@ func (c *Client) Update(item model.Item, id string, owner string) (string, error if err != nil { return "", err } - if owner != "" { - request.Header.Add("X-Midt-Owner", owner) + request.Header.Add(store.ItemOwnerHeaderKey, owner) + + if adminMode { + if c.adminToken == "" { + return "", errors.New("adminToken needed to run as admin") + } + request.Header.Set(store.AdminTokenHeaderKey, c.adminToken) } + response, err := c.client.Do(request) if err != nil { return "", err } - if response.StatusCode != 200 { - c.loggers.Error.Log("msg", "DB responded with non-200 response for request to update an item", "code", response.StatusCode) - return "", errors.New("Failed to put item as DB responded with non-200 statuscode") - } - responsePayload, _ := ioutil.ReadAll(response.Body) - key := model.Key{} - err = json.Unmarshal(responsePayload, &key) - if err != nil { - return "", err + + switch response.StatusCode { + case http.StatusCreated: + return CreatedPushResult, nil + case http.StatusOK: + return UpdatedPushResult, nil } - return key.ID, nil + c.loggers.Error.Log("msg", "DB responded with non-successful response for request to update an item", "code", response.StatusCode) + return "", errors.New("Failed to set item as DB responded with non-success statuscode") } -func (c *Client) Remove(id string, owner string) (model.Item, error) { - request, err := http.NewRequest("DELETE", fmt.Sprintf("%s/api/v1/store/%s/%s", c.remoteStoreAddress, c.bucketName, id), nil) +func (c *Client) Remove(uuid string, owner string, adminMode bool) (model.Item, error) { + if uuid == "" { + return model.Item{}, errors.New("uuid can't be empty") + } + request, err := http.NewRequest("DELETE", fmt.Sprintf("%s/api/v1/store/%s/%s", c.remoteStoreAddress, c.bucketName, uuid), nil) if err != nil { return model.Item{}, err } @@ -277,9 +272,16 @@ func (c *Client) Remove(id string, owner string) (model.Item, error) { if err != nil { return model.Item{}, err } - if owner != "" { - request.Header.Add("X-Midt-Owner", owner) + + request.Header.Add(store.ItemOwnerHeaderKey, owner) + + if adminMode { + if c.adminToken == "" { + return model.Item{}, errors.New("adminToken needed to run as admin") + } + request.Header.Set(store.AdminTokenHeaderKey, c.adminToken) } + response, err := c.client.Do(request) if err != nil { return model.Item{}, err @@ -309,7 +311,7 @@ func (c *Client) Start(ctx context.Context) error { go func() { for range c.ticker.C { outcome := SuccessOutcome - items, err := c.GetItems("") + items, err := c.GetItems("", true) if err == nil { c.listener.Update(items) } else { diff --git a/chrysom/store.go b/chrysom/store.go index b22fa32c..fc3cd83d 100644 --- a/chrysom/store.go +++ b/chrysom/store.go @@ -32,10 +32,10 @@ type PushReader interface { type Pusher interface { // Push applies user configurable for registering an item returning the id // i.e. updated the storage with said item. - Push(item model.Item, owner string) (string, error) + Push(item model.Item, owner string, adminMode bool) (PushResult, error) // Remove will remove the item from the store - Remove(id string, owner string) (model.Item, error) + Remove(uuid string, owner string, adminMode bool) (model.Item, error) } type Listener interface { @@ -54,7 +54,7 @@ func (listener ListenerFunc) Update(items []model.Item) { type Reader interface { // GeItems will return all the current items or an error. - GetItems(owner string) ([]model.Item, error) + GetItems(owner string, adminMode bool) ([]model.Item, error) Start(ctx context.Context) error diff --git a/go.mod b/go.mod index 9c122b3e..26a7b717 100644 --- a/go.mod +++ b/go.mod @@ -9,9 +9,9 @@ require ( github.com/go-kit/kit v0.9.0 github.com/go-playground/validator/v10 v10.3.0 github.com/gocql/gocql v0.0.0-20200505093417-effcbd8bcf0e - github.com/google/uuid v1.1.2 github.com/gorilla/mux v1.7.3 github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed + github.com/influxdata/influxdb1-client v0.0.0-20200827194710-b269163b24ab // indirect github.com/justinas/alice v1.2.0 github.com/prometheus/client_golang v1.4.1 github.com/spf13/pflag v1.0.5 diff --git a/go.sum b/go.sum index c53a25a4..4409ae1d 100644 --- a/go.sum +++ b/go.sum @@ -103,8 +103,6 @@ github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= -github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= -github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/goph/emperror v0.17.1/go.mod h1:+ZbQ+fUNO/6FNiUo0ujtMjhgad9Xa6fQL9KhH4LNHic= github.com/goph/emperror v0.17.3-0.20190703203600-60a8d9faa17b h1:3/cwc6wu5QADzKEW2HP7+kZpKgm7OHysQ3ULVVQzQhs= github.com/goph/emperror v0.17.3-0.20190703203600-60a8d9faa17b/go.mod h1:+ZbQ+fUNO/6FNiUo0ujtMjhgad9Xa6fQL9KhH4LNHic= @@ -139,6 +137,8 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO github.com/influxdata/influxdb v1.5.1-0.20180921190457-8d679cf0c36e/go.mod h1:qZna6X/4elxqT3yI9iZYdZrWWdeFOOprn86kgg4+IzY= github.com/influxdata/influxdb1-client v0.0.0-20200515024757-02f0bf5dbca3 h1:k3/6a1Shi7GGCp9QpyYuXsMM6ncTOjCzOE9Fd6CDA+Q= github.com/influxdata/influxdb1-client v0.0.0-20200515024757-02f0bf5dbca3/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= +github.com/influxdata/influxdb1-client v0.0.0-20200827194710-b269163b24ab h1:HqW4xhhynfjrtEiiSGcQUd6vrK23iMam1FO8rI7mwig= +github.com/influxdata/influxdb1-client v0.0.0-20200827194710-b269163b24ab/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmespath/go-jmespath v0.3.0 h1:OS12ieG61fsCg5+qLJ+SsW9NicxNkg3b25OyT2yCeUc= github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= diff --git a/model/model.go b/model/model.go index 7858d66f..6974da8a 100644 --- a/model/model.go +++ b/model/model.go @@ -22,18 +22,25 @@ type Key struct { // Bucket is a collection of items. Bucket string `json:"bucket"` - // ID is the unique ID for an item in a bucket - ID string `json:"id"` + // UUID is the unique ID for an item in a bucket. + UUID string `json:"uuid"` } // Item defines the abstract item to be stored. type Item struct { - // Identifier is how the client refers to the object. + // UUID is the unique ID identifying this item. It is recommended this value is the resulting + // value of a SHA256 calculation, using the unique attributes of the object being represented + // (e.g. SHA256()). This will be used by argus to determine uniqueness of objects being stored or updated. + UUID string `json:"uuid"` + + // Identifier is the common name of the provided resource. There is no enforcement of uniqueness + // across resource of this type. Identifier string `json:"identifier"` - // Data is an abstract json object + // Data is the JSON object to be stored. Opaque to argus. Data map[string]interface{} `json:"data"` - // TTL is the time to live in storage. If not provided and if the storage requires it the default configuration will be used. - TTL int64 `json:"ttl,omitempty"` + // TTL is the time to live in storage, specified in seconds. + // Optional. When not set, items don't expire. + TTL *int64 `json:"ttl,omitempty"` } diff --git a/routes.go b/routes.go index 49ae2b5f..3547a7ef 100644 --- a/routes.go +++ b/routes.go @@ -92,8 +92,7 @@ type PrimaryRouter struct { type PrimaryRoutes struct { fx.In - Push store.Handler `name:"pushHandler"` - Update store.Handler `name:"updateHandler"` + Set store.Handler `name:"setHandler"` Delete store.Handler `name:"deleteHandler"` Get store.Handler `name:"getHandler"` GetAll store.Handler `name:"getAllHandler"` @@ -101,20 +100,17 @@ type PrimaryRoutes struct { func BuildPrimaryRoutes(router PrimaryRouter, routes PrimaryRoutes) { router.Router.Use(router.AuthChain.Then) - if routes.Push != nil { - router.Router.Handle(fmt.Sprintf("/%s/store/{bucket}", apiBase), routes.Push).Methods(http.MethodPost) - } - if routes.Update != nil { - router.Router.Handle(fmt.Sprintf("/%s/store/{bucket}/{id}", apiBase), routes.Update).Methods(http.MethodPut) + if routes.Set != nil { + router.Router.Handle(fmt.Sprintf("/%s/store/{bucket}/{uuid}", apiBase), routes.Set).Methods(http.MethodPut) } if routes.Get != nil { - router.Router.Handle(fmt.Sprintf("/%s/store/{bucket}/{id}", apiBase), routes.Get).Methods(http.MethodGet) + router.Router.Handle(fmt.Sprintf("/%s/store/{bucket}/{uuid}", apiBase), routes.Get).Methods(http.MethodGet) } if routes.GetAll != nil { router.Router.Handle(fmt.Sprintf("/%s/store/{bucket}", apiBase), routes.GetAll).Methods(http.MethodGet) } if routes.Delete != nil { - router.Router.Handle(fmt.Sprintf("/%s/store/{bucket}/{id}", apiBase), routes.Delete).Methods(http.MethodDelete) + router.Router.Handle(fmt.Sprintf("/%s/store/{bucket}/{uuid}", apiBase), routes.Delete).Methods(http.MethodDelete) } } diff --git a/store/cassandra/executer.go b/store/cassandra/executer.go index 992b0fa7..433470a7 100644 --- a/store/cassandra/executer.go +++ b/store/cassandra/executer.go @@ -20,6 +20,7 @@ package cassandra import ( "encoding/json" "errors" + "github.com/go-kit/kit/log" "github.com/gocql/gocql" "github.com/hailocab/go-hostpool" @@ -60,7 +61,7 @@ func (s *cassandraExecutor) Push(key model.Key, item store.OwnableItem) error { return err } - return s.session.Query("INSERT INTO gifnoc (bucket, id, data) VALUES (?,?,?) USING TTL ?", key.Bucket, key.ID, data, item.TTL).Exec() + return s.session.Query("INSERT INTO gifnoc (bucket, id, data) VALUES (?,?,?) USING TTL ?", key.Bucket, key.UUID, data, item.TTL).Exec() } func (s *cassandraExecutor) Get(key model.Key) (store.OwnableItem, error) { @@ -68,17 +69,17 @@ func (s *cassandraExecutor) Get(key model.Key) (store.OwnableItem, error) { data []byte ttl int64 ) - iter := s.session.Query("SELECT data, ttl(data) from gifnoc WHERE bucket = ? AND id = ?", key.Bucket, key.ID).Iter() + iter := s.session.Query("SELECT data, ttl(data) from gifnoc WHERE bucket = ? AND id = ?", key.Bucket, key.UUID).Iter() defer func() { err := iter.Close() if err != nil { - logging.Error(s.logger).Log(logging.MessageKey(), "failed to close iter ", "bucket", key.Bucket, "id", key.ID) + logging.Error(s.logger).Log(logging.MessageKey(), "failed to close iter ", "bucket", key.Bucket, "id", key.UUID) } }() for iter.Scan(&data, &ttl) { item := store.OwnableItem{} err := json.Unmarshal(data, &item) - item.TTL = ttl + item.TTL = &ttl return item, err } return store.OwnableItem{}, noDataResponse @@ -89,7 +90,7 @@ func (s *cassandraExecutor) Delete(key model.Key) (store.OwnableItem, error) { if err != nil { return item, err } - err = s.session.Query("DELETE from gifnoc WHERE bucket = ? AND id = ?", key.Bucket, key.ID).Exec() + err = s.session.Query("DELETE from gifnoc WHERE bucket = ? AND id = ?", key.Bucket, key.UUID).Exec() return item, err } @@ -110,7 +111,7 @@ func (s *cassandraExecutor) GetAll(bucket string) (map[string]store.OwnableItem, key = "" continue } - item.TTL = ttl + item.TTL = &ttl result[key] = item } err := iter.Close() diff --git a/store/dynamodb/service.go b/store/dynamodb/service.go index d68935f2..714ccf1f 100644 --- a/store/dynamodb/service.go +++ b/store/dynamodb/service.go @@ -58,9 +58,9 @@ type executor struct { tableName string } -type element struct { +type storableItem struct { store.OwnableItem - Expires int64 `json:"expires"` + Expires *int64 `json:"expires,omitempty"` model.Key } @@ -71,8 +71,10 @@ var retryableAWSCodes = map[string]bool{ // Dynamo DB attribute keys const ( - bucketAttributeKey = "bucket" - idAttributeKey = "id" + bucketAttributeKey = "bucket" + uuidAttributeKey = "uuid" + identifierAttributeKey = "identifier" + expirationAttributeKey = "expires" ) func handleClientError(err error) error { @@ -87,12 +89,17 @@ func handleClientError(err error) error { } func (d *executor) Push(key model.Key, item store.OwnableItem) (*dynamodb.ConsumedCapacity, error) { - expirableItem := element{ + storingItem := storableItem{ OwnableItem: item, - Expires: time.Now().Unix() + item.TTL, Key: key, } - av, err := dynamodbattribute.MarshalMap(expirableItem) + + if item.TTL != nil { + unixExpSeconds := time.Now().Unix() + *item.TTL + storingItem.Expires = &unixExpSeconds + } + + av, err := dynamodbattribute.MarshalMap(storingItem) if err != nil { return nil, err } @@ -113,66 +120,80 @@ func (d *executor) Push(key model.Key, item store.OwnableItem) (*dynamodb.Consum } return consumedCapacity, nil } - -func (d *executor) Get(key model.Key) (store.OwnableItem, *dynamodb.ConsumedCapacity, error) { - result, err := d.c.GetItem(&dynamodb.GetItemInput{ +func (d *executor) executeGetOrDelete(key model.Key, delete bool) (*dynamodb.ConsumedCapacity, map[string]*dynamodb.AttributeValue, error) { + if delete { + deleteInput := &dynamodb.DeleteItemInput{ + TableName: aws.String(d.tableName), + Key: map[string]*dynamodb.AttributeValue{ + bucketAttributeKey: { + S: aws.String(key.Bucket), + }, + uuidAttributeKey: { + S: aws.String(key.UUID), + }, + }, + ReturnConsumedCapacity: aws.String(dynamodb.ReturnConsumedCapacityTotal), + ReturnValues: aws.String(dynamodb.ReturnValueAllOld), + } + deleteOutput, err := d.c.DeleteItem(deleteInput) + if err != nil { + return nil, nil, err + } + return deleteOutput.ConsumedCapacity, deleteOutput.Attributes, nil + } + getInput := &dynamodb.GetItemInput{ TableName: aws.String(d.tableName), Key: map[string]*dynamodb.AttributeValue{ bucketAttributeKey: { S: aws.String(key.Bucket), }, - idAttributeKey: { - S: aws.String(key.ID), + uuidAttributeKey: { + S: aws.String(key.UUID), }, }, ReturnConsumedCapacity: aws.String(dynamodb.ReturnConsumedCapacityTotal), - }) - - var consumedCapacity *dynamodb.ConsumedCapacity - if result != nil { - consumedCapacity = result.ConsumedCapacity } + getOutput, err := d.c.GetItem(getInput) if err != nil { - return store.OwnableItem{}, consumedCapacity, handleClientError(err) + return nil, nil, err } - var expirableItem element - err = dynamodbattribute.UnmarshalMap(result.Item, &expirableItem) - expirableItem.OwnableItem.TTL = int64(time.Unix(expirableItem.Expires, 0).Sub(time.Now()).Seconds()) - if expirableItem.Key.Bucket == "" || expirableItem.Key.ID == "" { - return expirableItem.OwnableItem, consumedCapacity, store.KeyNotFoundError{Key: key} - } - return expirableItem.OwnableItem, consumedCapacity, err + return getOutput.ConsumedCapacity, getOutput.Item, nil } - -func (d *executor) Delete(key model.Key) (store.OwnableItem, *dynamodb.ConsumedCapacity, error) { - result, err := d.c.DeleteItem(&dynamodb.DeleteItemInput{ - Key: map[string]*dynamodb.AttributeValue{ - bucketAttributeKey: { - S: aws.String(key.Bucket), - }, - idAttributeKey: { - S: aws.String(key.ID), - }, - }, - ReturnValues: aws.String(dynamodb.ReturnValueAllOld), - TableName: aws.String(d.tableName), - ReturnConsumedCapacity: aws.String(dynamodb.ReturnConsumedCapacityTotal), - }) - var consumedCapacity *dynamodb.ConsumedCapacity - if result != nil { - consumedCapacity = result.ConsumedCapacity - } +func (d *executor) getOrDelete(key model.Key, delete bool) (store.OwnableItem, *dynamodb.ConsumedCapacity, error) { + consumedCapacity, attributes, err := d.executeGetOrDelete(key, delete) if err != nil { return store.OwnableItem{}, consumedCapacity, handleClientError(err) } + item := new(storableItem) + err = dynamodbattribute.UnmarshalMap(attributes, item) + if err != nil { + return store.OwnableItem{}, consumedCapacity, err + } - var expirableItem element - err = dynamodbattribute.UnmarshalMap(result.Attributes, &expirableItem) - expirableItem.OwnableItem.TTL = int64(time.Unix(expirableItem.Expires, 0).Sub(time.Now()).Seconds()) - if expirableItem.Key.Bucket == "" || expirableItem.Key.ID == "" { - return expirableItem.OwnableItem, consumedCapacity, store.KeyNotFoundError{Key: key} + if itemNotFound(item) { + return item.OwnableItem, consumedCapacity, store.KeyNotFoundError{Key: key} } - return expirableItem.OwnableItem, consumedCapacity, err + + if item.Expires != nil { + remainingTTLSeconds := int64(time.Unix(*item.Expires, 0).Sub(time.Now()).Seconds()) + if remainingTTLSeconds < 1 { + return item.OwnableItem, consumedCapacity, store.KeyNotFoundError{Key: key} + } + item.TTL = &remainingTTLSeconds + } + + item.OwnableItem.UUID = key.UUID + + return item.OwnableItem, consumedCapacity, err + +} + +func (d *executor) Get(key model.Key) (store.OwnableItem, *dynamodb.ConsumedCapacity, error) { + return d.getOrDelete(key, false) +} + +func (d *executor) Delete(key model.Key) (store.OwnableItem, *dynamodb.ConsumedCapacity, error) { + return d.getOrDelete(key, true) } //TODO: For data >= 1MB, we'll need to handle pagination @@ -202,19 +223,34 @@ func (d *executor) GetAll(bucket string) (map[string]store.OwnableItem, *dynamod } for _, i := range queryResult.Items { - var expirableItem element - err = dynamodbattribute.UnmarshalMap(i, &expirableItem) + item := new(storableItem) + err = dynamodbattribute.UnmarshalMap(i, item) if err != nil { //logging.Error(d.logger).Log(logging.MessageKey(), "failed to unmarshal item", logging.ErrorKey(), err) continue } - expirableItem.OwnableItem.TTL = int64(time.Unix(expirableItem.Expires, 0).Sub(time.Now()).Seconds()) + if itemNotFound(item) { + continue + } + + if item.Expires != nil { + remainingTTLSeconds := int64(time.Unix(*item.Expires, 0).Sub(time.Now()).Seconds()) + if remainingTTLSeconds < 1 { + continue + } + item.TTL = &remainingTTLSeconds + } + item.OwnableItem.UUID = item.Key.UUID - result[expirableItem.Key.ID] = expirableItem.OwnableItem + result[item.Key.UUID] = item.OwnableItem } return result, consumedCapacity, nil } +func itemNotFound(item *storableItem) bool { + return item.Key.Bucket == "" || item.Key.UUID == "" +} + func newService(config aws.Config, awsProfile string, tableName string, logger log.Logger) (service, error) { sess, err := session.NewSessionWithOptions(session.Options{ Config: config, diff --git a/store/dynamodb/service_test.go b/store/dynamodb/service_test.go index 17e6a764..bf3e2d0c 100644 --- a/store/dynamodb/service_test.go +++ b/store/dynamodb/service_test.go @@ -2,6 +2,7 @@ package dynamodb import ( "errors" + "strconv" "testing" "time" @@ -18,7 +19,7 @@ import ( const ( testTableName = "table01" testBucketName = "bucket01" - testIDName = "ID01" + testUUID = "NaYFGE961cS_3dpzJcoP3QTL4kBYcw9ua3Q6Hy5E4nI" ) var ( @@ -31,15 +32,15 @@ var ( var ( item = store.OwnableItem{ Item: model.Item{ - Identifier: testIDName, + Identifier: testUUID, Data: map[string]interface{}{"dataKey": "dataValue"}, - TTL: int64(time.Second * 300), + TTL: aws.Int64(int64((time.Second * 300).Seconds())), }, Owner: "xmidt", } key = model.Key{ Bucket: testBucketName, - ID: testIDName, + UUID: testUUID, } ) @@ -114,12 +115,6 @@ func TestClientErrors(t *testing.T) { suite.Run(t, new(ClientErrorTestSuite)) } -func TestGetItem(t *testing.T) { - initGlobalInputs() - t.Run("Success", testGetItem) - t.Run("NotFound", testGetItemNotFound) -} - func TestPushItem(t *testing.T) { initGlobalInputs() @@ -141,22 +136,15 @@ func TestPushItem(t *testing.T) { assert.Equal(expectedConsumedCapacity, actualConsumedCapacity) } -func TestDeleteItem(t *testing.T) { - initGlobalInputs() - - t.Run("Success", testDelete) - t.Run("NotFound", testDeleteNotFound) -} - func TestGetAllItems(t *testing.T) { initGlobalInputs() - - t.Run("Success", testGetAll) -} - -func testGetAll(t *testing.T) { assert := assert.New(t) m := new(mockClient) + now := time.Now().Unix() + secondsInHour := int64(time.Hour.Seconds()) + pastExpiration := strconv.Itoa(int(now - secondsInHour)) + futureExpiration := strconv.Itoa(int(now + secondsInHour)) + expectedConsumedCapacity := &dynamodb.ConsumedCapacity{ CapacityUnits: aws.Float64(67), } @@ -167,16 +155,49 @@ func testGetAll(t *testing.T) { bucketAttributeKey: { S: aws.String(testBucketName), }, - idAttributeKey: { - S: aws.String("id01"), + uuidAttributeKey: { + S: aws.String("-mTqHoLhIG-CirKgKRfH6SrMuY47lYgaG0rVK5FLZuM"), + }, + identifierAttributeKey: { + S: aws.String("expired"), + }, + expirationAttributeKey: { + N: aws.String(pastExpiration), + }, + }, + { + bucketAttributeKey: { + S: aws.String(testBucketName), + }, + uuidAttributeKey: { + S: aws.String("1wzI3cbHlIHD9TUi9LgOz1Vt1cZIOloD4PvlB5uFT4E"), + }, + identifierAttributeKey: { + S: aws.String("notYetExpired"), + }, + expirationAttributeKey: { + N: aws.String(futureExpiration), }, }, + { bucketAttributeKey: { S: aws.String(testBucketName), }, - idAttributeKey: { - S: aws.String("id02"), + uuidAttributeKey: { + S: aws.String("dbtIlYXQsAoAmexD6zGV8ZfVImEjsFGHcMJdhCZ-1L4"), + }, + identifierAttributeKey: { + S: aws.String("neverExpires"), + }, + }, + + { + bucketAttributeKey: { + S: aws.String(testBucketName), + }, + identifierAttributeKey: { + S: aws.String("db goes cuckoo"), }, }, }, @@ -191,9 +212,17 @@ func testGetAll(t *testing.T) { assert.Nil(err) assert.Len(ownableItems, 2) assert.Equal(expectedConsumedCapacity, actualConsumedCapacity) + + for _, item := range ownableItems { + assert.NotEmpty(item.UUID) + assert.NotEmpty(item.Identifier) + if item.TTL != nil { + assert.NotZero(*item.TTL) + } + } } -func testDelete(t *testing.T) { +func TestDeleteItem(t *testing.T) { assert := assert.New(t) m := new(mockClient) expectedConsumedCapacity := &dynamodb.ConsumedCapacity{ @@ -205,8 +234,8 @@ func testDelete(t *testing.T) { bucketAttributeKey: { S: aws.String(testBucketName), }, - idAttributeKey: { - S: aws.String(testIDName), + uuidAttributeKey: { + S: aws.String(testUUID), }, "data": { M: map[string]*dynamodb.AttributeValue{ @@ -240,91 +269,171 @@ func testDelete(t *testing.T) { assert.Equal(expectedConsumedCapacity, actualConsumedCapacity) } -func testDeleteNotFound(t *testing.T) { - assert := assert.New(t) - m := new(mockClient) - expectedConsumedCapacity := &dynamodb.ConsumedCapacity{ - CapacityUnits: aws.Float64(67), - } - deleteItemOutput := &dynamodb.DeleteItemOutput{ - ConsumedCapacity: expectedConsumedCapacity, - } - m.On("DeleteItem", deleteItemInput).Return(deleteItemOutput, error(nil)) - service := &executor{ - tableName: testTableName, - c: m, - } - ownableItem, actualConsumedCapacity, err := service.Delete(key) - assert.NotNil(ownableItem) - assert.Equal(store.KeyNotFoundError{Key: key}, err) - assert.Equal(expectedConsumedCapacity, actualConsumedCapacity) -} - -func testGetItemNotFound(t *testing.T) { - assert := assert.New(t) - m := new(mockClient) - expectedConsumedCapacity := &dynamodb.ConsumedCapacity{ - CapacityUnits: aws.Float64(67), - } - getItemOutput := &dynamodb.GetItemOutput{ - ConsumedCapacity: expectedConsumedCapacity, - } - m.On("GetItem", getItemInput).Return(getItemOutput, error(nil)) - service := &executor{ - tableName: testTableName, - c: m, - } - ownableItem, actualConsumedCapacity, err := service.Get(key) - assert.NotNil(ownableItem) - assert.Equal(store.KeyNotFoundError{Key: key}, err) - assert.Equal(expectedConsumedCapacity, actualConsumedCapacity) -} +func TestGetItem(t *testing.T) { + initGlobalInputs() + now := time.Now().Unix() + secondsInHour := int64(time.Hour.Seconds()) + pastExpiration := strconv.Itoa(int(now - secondsInHour)) + futureExpiration := strconv.Itoa(int(now + secondsInHour)) + + testCases := []struct { + Name string + GetItemOutput *dynamodb.GetItemOutput + GetItemOutputErr error + ItemExpires bool + ExpectedResponse store.OwnableItem + ExpectedResponseErr error + }{ + { + Name: "Item does not expire", + GetItemOutput: &dynamodb.GetItemOutput{ + Item: map[string]*dynamodb.AttributeValue{ + bucketAttributeKey: { + S: aws.String(testBucketName), + }, + uuidAttributeKey: { + S: aws.String(testUUID), + }, + "data": { + M: map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String("stringVal"), + }, + }, + }, + "owner": { + S: aws.String("xmidt"), + }, -func testGetItem(t *testing.T) { - assert := assert.New(t) - m := new(mockClient) - expectedConsumedCapacity := &dynamodb.ConsumedCapacity{ - CapacityUnits: aws.Float64(67), - } - getItemOutput := &dynamodb.GetItemOutput{ - ConsumedCapacity: expectedConsumedCapacity, - Item: map[string]*dynamodb.AttributeValue{ - bucketAttributeKey: { - S: aws.String(testBucketName), + "identifier": { + S: aws.String("id01"), + }, + }, }, - idAttributeKey: { - S: aws.String(testIDName), + ExpectedResponse: store.OwnableItem{ + Owner: "xmidt", + Item: model.Item{ + UUID: testUUID, + Identifier: "id01", + Data: map[string]interface{}{ + "key": "stringVal", + }, + }, }, - "data": { - M: map[string]*dynamodb.AttributeValue{ - "key": { - S: aws.String("stringVal"), + }, + + { + Name: "Expired item", + ItemExpires: true, + GetItemOutput: &dynamodb.GetItemOutput{ + Item: map[string]*dynamodb.AttributeValue{ + "expires": { + N: aws.String(pastExpiration), + }, + bucketAttributeKey: { + S: aws.String(testBucketName), + }, + uuidAttributeKey: { + S: aws.String(testUUID), + }, + "data": { + M: map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String("stringVal"), + }, + }, + }, + "owner": { + S: aws.String("xmidt"), + }, + + "identifier": { + S: aws.String("id01"), }, }, }, - "owner": { - S: aws.String("xmidt"), + ExpectedResponseErr: store.KeyNotFoundError{Key: model.Key{ + UUID: testUUID, + Bucket: testBucketName, + }}, + }, + + { + Name: "Item not yet expired", + ItemExpires: true, + GetItemOutput: &dynamodb.GetItemOutput{ + Item: map[string]*dynamodb.AttributeValue{ + "expires": { + N: aws.String(futureExpiration), + }, + bucketAttributeKey: { + S: aws.String(testBucketName), + }, + uuidAttributeKey: { + S: aws.String(testUUID), + }, + "data": { + M: map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String("stringVal"), + }, + }, + }, + "owner": { + S: aws.String("xmidt"), + }, + + "identifier": { + S: aws.String("id01"), + }, + }, + }, + ExpectedResponse: store.OwnableItem{ + Owner: "xmidt", + Item: model.Item{ + UUID: testUUID, + Identifier: "id01", + Data: map[string]interface{}{ + "key": "stringVal", + }, + }, }, + }, - "identifier": { - S: aws.String("id01"), + { + Name: "Item not found", + GetItemOutput: &dynamodb.GetItemOutput{ + Item: map[string]*dynamodb.AttributeValue{}, }, + ExpectedResponseErr: store.KeyNotFoundError{Key: key}, }, } - expectedData := map[string]interface{}{ - "key": "stringVal", - } - m.On("GetItem", getItemInput).Return(getItemOutput, error(nil)) - service := &executor{ - tableName: testTableName, - c: m, + for _, testCase := range testCases { + t.Run(testCase.Name, func(t *testing.T) { + assert := assert.New(t) + m := new(mockClient) + m.On("GetItem", getItemInput).Return(testCase.GetItemOutput, error(nil)) + service := &executor{ + tableName: testTableName, + c: m, + } + ownableItem, actualConsumedCapacity, err := service.Get(key) + if testCase.ExpectedResponseErr == nil { + assert.Nil(err) + assert.Equal(testCase.GetItemOutput.ConsumedCapacity, actualConsumedCapacity) + assert.Equal(testCase.ExpectedResponse.Owner, ownableItem.Owner) + assert.Equal(testCase.ExpectedResponse.Data, ownableItem.Data) + assert.Equal(testCase.ExpectedResponse.Identifier, ownableItem.Identifier) + assert.Equal(testCase.ExpectedResponse.UUID, ownableItem.UUID) + + if testCase.ItemExpires { + assert.NotZero(*ownableItem.TTL) + } + } else { + assert.Equal(testCase.ExpectedResponseErr, err) + } + }) } - ownableItem, actualConsumedCapacity, err := service.Get(key) - assert.Nil(err) - assert.Equal("xmidt", ownableItem.Owner) - assert.Equal("id01", ownableItem.Identifier) - assert.Equal(expectedData, ownableItem.Data) - assert.Equal(expectedConsumedCapacity, actualConsumedCapacity) } func initGlobalInputs() { @@ -334,8 +443,8 @@ func initGlobalInputs() { bucketAttributeKey: { S: aws.String(key.Bucket), }, - idAttributeKey: { - S: aws.String(key.ID), + uuidAttributeKey: { + S: aws.String(key.UUID), }, }, ReturnConsumedCapacity: aws.String(dynamodb.ReturnConsumedCapacityTotal), @@ -346,8 +455,8 @@ func initGlobalInputs() { bucketAttributeKey: { S: aws.String(key.Bucket), }, - idAttributeKey: { - S: aws.String(key.ID), + uuidAttributeKey: { + S: aws.String(key.UUID), }, }, ReturnValues: aws.String(dynamodb.ReturnValueAllOld), @@ -355,9 +464,9 @@ func initGlobalInputs() { ReturnConsumedCapacity: aws.String(dynamodb.ReturnConsumedCapacityTotal), } - expirableItem := element{ + expirableItem := storableItem{ OwnableItem: item, - Expires: time.Now().Unix() + item.TTL, + Expires: aws.Int64(time.Now().Unix() + *item.TTL), Key: key, } encodedItem, err := dynamodbattribute.MarshalMap(expirableItem) diff --git a/store/endpoint.go b/store/endpoint.go index 1f4aa27e..9c69285e 100644 --- a/store/endpoint.go +++ b/store/endpoint.go @@ -23,6 +23,8 @@ import ( "github.com/go-kit/kit/endpoint" ) +var accessDeniedErr = &ForbiddenRequestErr{Message: "resource owner mismatch"} + func newGetItemEndpoint(s S) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { itemRequest := request.(*getOrDeleteItemRequest) @@ -30,11 +32,11 @@ func newGetItemEndpoint(s S) endpoint.Endpoint { if err != nil { return nil, err } - if userOwnsItem(itemRequest.owner, itemResponse.Owner) { + if authorized(itemRequest.adminMode, itemResponse.Owner, itemRequest.owner) { return &itemResponse, nil } - return nil, &KeyNotFoundError{Key: itemRequest.key} + return nil, accessDeniedErr } } @@ -45,15 +47,16 @@ func newDeleteItemEndpoint(s S) endpoint.Endpoint { if err != nil { return nil, err } - if userOwnsItem(itemRequest.owner, itemResponse.Owner) { - deleteItemResp, deleteItemRespErr := s.Delete(itemRequest.key) - if deleteItemRespErr != nil { - return nil, deleteItemRespErr - } - return &deleteItemResp, nil + + if !authorized(itemRequest.adminMode, itemResponse.Owner, itemRequest.owner) { + return nil, accessDeniedErr } - return nil, &KeyNotFoundError{Key: itemRequest.key} + deleteItemResp, deleteItemRespErr := s.Delete(itemRequest.key) + if deleteItemRespErr != nil { + return nil, deleteItemRespErr + } + return &deleteItemResp, nil } } @@ -64,42 +67,47 @@ func newGetAllItemsEndpoint(s S) endpoint.Endpoint { if err != nil { return nil, err } - + if itemsRequest.adminMode { + return items, nil + } return FilterOwner(items, itemsRequest.owner), nil } } -func newPushItemEndpoint(s S) endpoint.Endpoint { +func newSetItemEndpoint(s S) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { - pushItemRequest := request.(*pushItemRequest) - err := s.Push(pushItemRequest.key, pushItemRequest.item) - if err != nil { - return nil, err - } - return &pushItemRequest.key, nil - } -} + setItemRequest := request.(*setItemRequest) + itemResponse, err := s.Get(setItemRequest.key) -func newUpdateItemEndpoint(s S) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - updateItemRequest := request.(*pushItemRequest) - itemResponse, err := s.Get(updateItemRequest.key) if err != nil { - return nil, err - } + switch err.(type) { + case KeyNotFoundError: + err = s.Push(setItemRequest.key, setItemRequest.item) + if err != nil { + return nil, err + } + return &setItemResponse{}, nil - if userOwnsItem(updateItemRequest.item.Owner, itemResponse.Owner) { - err := s.Push(updateItemRequest.key, updateItemRequest.item) - if err != nil { + default: return nil, err } - return &updateItemRequest.key, nil } - return nil, &KeyNotFoundError{Key: updateItemRequest.key} + if !authorized(setItemRequest.adminMode, itemResponse.Owner, setItemRequest.item.Owner) { + return nil, accessDeniedErr + } + + err = s.Push(setItemRequest.key, setItemRequest.item) + if err != nil { + return nil, err + } + + return &setItemResponse{ + existingResource: true, + }, nil } } -func userOwnsItem(requestItemOwner, actualItemOwner string) bool { - return requestItemOwner == "" || requestItemOwner == actualItemOwner +func authorized(adminMode bool, resourceOwner, requestOwner string) bool { + return adminMode || resourceOwner == requestOwner } diff --git a/store/endpoint_test.go b/store/endpoint_test.go index f88de736..0efac36c 100644 --- a/store/endpoint_test.go +++ b/store/endpoint_test.go @@ -31,15 +31,31 @@ func TestGetItemEndpoint(t *testing.T) { ItemRequest: &getOrDeleteItemRequest{ owner: "Kirby", key: model.Key{ - ID: "hammer", + UUID: "hammer", }, }, DAOResponse: OwnableItem{ Owner: "Yoshi", }, - ExpectedErr: &KeyNotFoundError{Key: model.Key{ - ID: "hammer", - }}, + ExpectedErr: accessDeniedErr, + }, + + { + Name: "Wrong owner but admin mode", + ItemRequest: &getOrDeleteItemRequest{ + owner: "Kirby", + key: model.Key{ + UUID: "hammer", + }, + adminMode: true, + }, + DAOResponse: OwnableItem{ + Owner: "Yoshi", + }, + + ExpectedResponse: &OwnableItem{ + Owner: "Yoshi", + }, }, { Name: "Success", @@ -102,8 +118,28 @@ func TestDeleteItemEndpoint(t *testing.T) { GetDAOResponse: OwnableItem{ Owner: "fiber", }, - ExpectedErr: &KeyNotFoundError{}, + ExpectedErr: accessDeniedErr, }, + + { + Name: "Wrong owner but admin mode", + ItemRequest: &getOrDeleteItemRequest{ + owner: "cable", + adminMode: true, + }, + GetDAOResponse: OwnableItem{ + Owner: "fiber", + }, + + DeleteDAOResponse: OwnableItem{ + Owner: "fiber", + }, + + ExpectedResponse: &OwnableItem{ + Owner: "fiber", + }, + }, + { Name: "Deletion fails", ItemRequest: &getOrDeleteItemRequest{ @@ -140,7 +176,7 @@ func TestDeleteItemEndpoint(t *testing.T) { m.On("Get", testCase.ItemRequest.key).Return(testCase.GetDAOResponse, error(testCase.GetDAOResponseErr)).Once() // verify item is not deleted by user who doesn't own it - allowDelete := testCase.ItemRequest.owner == "" || testCase.ItemRequest.owner == testCase.GetDAOResponse.Owner + allowDelete := testCase.ItemRequest.adminMode || testCase.ItemRequest.owner == testCase.GetDAOResponse.Owner if testCase.GetDAOResponseErr != nil || !allowDelete { m.AssertNotCalled(t, "Delete", testCase.ItemRequest) @@ -162,23 +198,23 @@ func TestDeleteItemEndpoint(t *testing.T) { } } -func TestUpdateItemEndpoint(t *testing.T) { +func TestSetItemEndpoint(t *testing.T) { testCases := []struct { - Name string - ItemRequest *pushItemRequest - UpdateDAOResponse OwnableItem - UpdateDAOResponseErr error - GetDAOResponse OwnableItem - GetDAOResponseErr error - ExpectedResponse *model.Key - ExpectedErr error + Name string + ItemRequest *setItemRequest + PushDAOResponse OwnableItem + PushDAOResponseErr error + GetDAOResponse OwnableItem + GetDAOResponseErr error + ExpectedResponse *setItemResponse + ExpectedErr error }{ { - Name: "Update DAO failure", - ItemRequest: &pushItemRequest{ + Name: "Push DAO failure", + ItemRequest: &setItemRequest{ key: model.Key{ Bucket: "fruits", - ID: "random-UUID", + UUID: "XnN_iR2xF1RCo5_ec-UdeBpUVQbXHJVHem3rWYi9f5o", }, item: OwnableItem{ Item: model.Item{ @@ -190,15 +226,15 @@ func TestUpdateItemEndpoint(t *testing.T) { GetDAOResponse: OwnableItem{ Owner: "Bob", }, - UpdateDAOResponseErr: errors.New("DB failed"), - ExpectedErr: errors.New("DB failed"), + PushDAOResponseErr: errors.New("DB failed"), + ExpectedErr: errors.New("DB failed"), }, { Name: "Get DAO failure", - ItemRequest: &pushItemRequest{ + ItemRequest: &setItemRequest{ key: model.Key{ Bucket: "fruits", - ID: "random-UUID", + UUID: "XnN_iR2xF1RCo5_ec-UdeBpUVQbXHJVHem3rWYi9f5o", }, item: OwnableItem{ Item: model.Item{ @@ -212,10 +248,10 @@ func TestUpdateItemEndpoint(t *testing.T) { }, { Name: "Wrong owner", - ItemRequest: &pushItemRequest{ + ItemRequest: &setItemRequest{ key: model.Key{ Bucket: "fruits", - ID: "random-UUID", + UUID: "XnN_iR2xF1RCo5_ec-UdeBpUVQbXHJVHem3rWYi9f5o", }, item: OwnableItem{ Item: model.Item{}, @@ -225,35 +261,52 @@ func TestUpdateItemEndpoint(t *testing.T) { GetDAOResponse: OwnableItem{ Owner: "fiber", }, - ExpectedErr: &KeyNotFoundError{ - Key: model.Key{ - Bucket: "fruits", - ID: "random-UUID", - }, - }, + ExpectedErr: accessDeniedErr, }, { - Name: "Successful Update", - ItemRequest: &pushItemRequest{ + Name: "Successful Update. Wrong owner but admin mode", + ItemRequest: &setItemRequest{ key: model.Key{ Bucket: "fruits", - ID: "random-UUID", + UUID: "XnN_iR2xF1RCo5_ec-UdeBpUVQbXHJVHem3rWYi9f5o", }, item: OwnableItem{ Item: model.Item{}, Owner: "cable", }, + adminMode: true, }, GetDAOResponse: OwnableItem{ Owner: "cable", }, - UpdateDAOResponse: OwnableItem{ + PushDAOResponse: OwnableItem{ + Owner: "cable", + Item: model.Item{}, + }, + ExpectedResponse: &setItemResponse{ + existingResource: true, + }, + }, + + { + Name: "Successful Creation", + ItemRequest: &setItemRequest{ + key: model.Key{ + Bucket: "fruits", + UUID: "XnN_iR2xF1RCo5_ec-UdeBpUVQbXHJVHem3rWYi9f5o", + }, + item: OwnableItem{ + Item: model.Item{}, + Owner: "cable", + }, + }, + GetDAOResponseErr: KeyNotFoundError{}, + PushDAOResponse: OwnableItem{ Owner: "cable", Item: model.Item{}, }, - ExpectedResponse: &model.Key{ - Bucket: "fruits", - ID: "random-UUID", + ExpectedResponse: &setItemResponse{ + existingResource: false, }, }, } @@ -263,130 +316,150 @@ func TestUpdateItemEndpoint(t *testing.T) { assert := assert.New(t) m := new(MockDAO) - if testCase.UpdateDAOResponseErr == nil { + if testCase.PushDAOResponseErr == nil { m.On("Push", testCase.ItemRequest.key, testCase.ItemRequest.item).Return(nil).Once() } else { - m.On("Push", testCase.ItemRequest.key, testCase.ItemRequest.item).Return(testCase.UpdateDAOResponseErr).Once() + m.On("Push", testCase.ItemRequest.key, testCase.ItemRequest.item).Return(testCase.PushDAOResponseErr).Once() } m.On("Get", testCase.ItemRequest.key).Return(testCase.GetDAOResponse, error(testCase.GetDAOResponseErr)).Once() - endpoint := newUpdateItemEndpoint(m) + endpoint := newSetItemEndpoint(m) resp, err := endpoint(context.Background(), testCase.ItemRequest) if testCase.ExpectedErr == nil { assert.Nil(err) - assert.Equal(&testCase.ItemRequest.key, resp) + assert.Equal(testCase.ExpectedResponse, resp) m.AssertExpectations(t) } else { assert.Equal(testCase.ExpectedErr, err) } - }) } } func TestGetAllItemsEndpoint(t *testing.T) { - t.Run("DAOFails", testGetAllItemsEndpointDAOFails) - t.Run("FilteredItems", testGetAllItemsEndpointFiltered) -} + testCases := []struct { + Name string + ItemRequest *getAllItemsRequest + GetAllDAOResponse map[string]OwnableItem + GetAllDAOResponseErr error + ExpectedResponse map[string]OwnableItem + ExpectedErr error + }{ + { + Name: "DAO failure", + ItemRequest: &getAllItemsRequest{ + bucket: "sports-cars", + owner: "alfa-romeo", + }, + GetAllDAOResponseErr: errors.New("DB failed"), + ExpectedErr: errors.New("DB failed"), + }, + { + Name: "Filtered results", + ItemRequest: &getAllItemsRequest{ + bucket: "sports-cars", + owner: "alfa-romeo", + }, + GetAllDAOResponse: map[string]OwnableItem{ + "mustang": OwnableItem{ + Owner: "ford", + }, + "4c-spider": OwnableItem{ + Owner: "alfa-romeo", + }, + "gtr": OwnableItem{ + Owner: "nissan", + }, + "giulia": OwnableItem{ + Owner: "alfa-romeo", + }, + }, -func testGetAllItemsEndpointDAOFails(t *testing.T) { - assert := assert.New(t) - m := new(MockDAO) - itemsRequest := &getAllItemsRequest{ - bucket: "sports-cars", - owner: "alfa-romeo", - } - mockedErr := errors.New("sports cars api is down") - m.On("GetAll", "sports-cars").Return(map[string]OwnableItem{}, mockedErr).Once() + ExpectedResponse: map[string]OwnableItem{ + "4c-spider": OwnableItem{ + Owner: "alfa-romeo", + }, - endpoint := newGetAllItemsEndpoint(m) - resp, err := endpoint(context.Background(), itemsRequest) + "giulia": OwnableItem{ + Owner: "alfa-romeo", + }, + }, + }, - assert.Nil(resp) - assert.Equal(mockedErr, err) - m.AssertExpectations(t) -} + { + Name: "Admin mode", + ItemRequest: &getAllItemsRequest{ + bucket: "sports-cars", + owner: "alfa-romeo", + adminMode: true, + }, + GetAllDAOResponse: map[string]OwnableItem{ + "mustang": OwnableItem{ + Owner: "ford", + }, + "4c-spider": OwnableItem{ + Owner: "alfa-romeo", + }, + "gtr": OwnableItem{ + Owner: "nissan", + }, + "giulia": OwnableItem{ + Owner: "alfa-romeo", + }, + }, -func testGetAllItemsEndpointFiltered(t *testing.T) { - assert := assert.New(t) - m := new(MockDAO) - itemsRequest := &getAllItemsRequest{ - bucket: "sports-cars", - owner: "alfa-romeo", - } - mockedItems := map[string]OwnableItem{ - "mustang": OwnableItem{ - Owner: "ford", - }, - "4c-spider": OwnableItem{ - Owner: "alfa-romeo", - }, - "gtr": OwnableItem{ - Owner: "nissan", - }, - "giulia": OwnableItem{ - Owner: "alfa-romeo", + ExpectedResponse: map[string]OwnableItem{ + "mustang": OwnableItem{ + Owner: "ford", + }, + "4c-spider": OwnableItem{ + Owner: "alfa-romeo", + }, + "gtr": OwnableItem{ + Owner: "nissan", + }, + "giulia": OwnableItem{ + Owner: "alfa-romeo", + }, + }, }, - } - m.On("GetAll", "sports-cars").Return(mockedItems, error(nil)).Once() - endpoint := newGetAllItemsEndpoint(m) - resp, err := endpoint(context.Background(), itemsRequest) + { + Name: "Empty results", + ItemRequest: &getAllItemsRequest{ + bucket: "sports-cars", + owner: "volkswagen", + }, + GetAllDAOResponse: map[string]OwnableItem{ + "mustang": OwnableItem{ + Owner: "ford", + }, + "giulia": OwnableItem{ + Owner: "alfa-romeo", + }, + }, - expectedItems := map[string]OwnableItem{ - "4c-spider": OwnableItem{ - Owner: "alfa-romeo", - }, - "giulia": OwnableItem{ - Owner: "alfa-romeo", + ExpectedResponse: map[string]OwnableItem{}, }, } - assert.Equal(expectedItems, resp) - assert.Nil(err) - m.AssertExpectations(t) -} - -func TestPushItemEndpoint(t *testing.T) { - t.Run("DAOFails", testPushItemEndpointDAOFails) - t.Run("Happy Path", testPushItemEndpointHappyPath) -} + for _, testCase := range testCases { + t.Run(testCase.Name, func(t *testing.T) { + assert := assert.New(t) + m := new(MockDAO) -func testPushItemEndpointHappyPath(t *testing.T) { - assert := assert.New(t) - m := new(MockDAO) - key := model.Key{ - Bucket: "fruits", - ID: "strawberry", - } + m.On("GetAll", testCase.ItemRequest.bucket).Return(testCase.GetAllDAOResponse, error(testCase.GetAllDAOResponseErr)) - item := OwnableItem{ - Item: model.Item{ - Identifier: "strawberry", - }, - Owner: "Bob", + endpoint := newGetAllItemsEndpoint(m) + resp, err := endpoint(context.Background(), testCase.ItemRequest) + if testCase.ExpectedErr == nil { + assert.Nil(err) + assert.Equal(testCase.ExpectedResponse, resp) + } else { + assert.Equal(testCase.ExpectedErr, err) + } + }) } - - m.On("Push", key, item).Return(nil).Once() - endpoint := newPushItemEndpoint(m) - resp, err := endpoint(context.Background(), &pushItemRequest{ - key: key, - item: item, - }) - assert.Nil(err) - assert.Equal(&key, resp) - m.AssertExpectations(t) -} - -func testPushItemEndpointDAOFails(t *testing.T) { - assert := assert.New(t) - m := new(MockDAO) - m.On("Push", model.Key{}, OwnableItem{}).Return(errors.New("DB failed")).Once() - endpoint := newPushItemEndpoint(m) - resp, err := endpoint(context.Background(), &pushItemRequest{}) - assert.Nil(resp) - assert.Equal(errors.New("DB failed"), err) - m.AssertExpectations(t) } diff --git a/store/errors.go b/store/errors.go index eddfba21..4b8b4dde 100644 --- a/store/errors.go +++ b/store/errors.go @@ -19,18 +19,30 @@ func (bre BadRequestErr) StatusCode() int { return http.StatusBadRequest } +type ForbiddenRequestErr struct { + Message string +} + +func (f ForbiddenRequestErr) Error() string { + return f.Message +} + +func (f ForbiddenRequestErr) StatusCode() int { + return http.StatusForbidden +} + type KeyNotFoundError struct { Key model.Key } func (knfe KeyNotFoundError) Error() string { - if knfe.Key.ID == "" && knfe.Key.Bucket == "" { + if knfe.Key.UUID == "" && knfe.Key.Bucket == "" { return fmt.Sprint("parameters for key not set") - } else if knfe.Key.ID == "" && knfe.Key.Bucket != "" { + } else if knfe.Key.UUID == "" && knfe.Key.Bucket != "" { return fmt.Sprintf("no value exists for bucket %s", knfe.Key.Bucket) } - return fmt.Sprintf("no value exists with bucket: %s, id: %s", knfe.Key.Bucket, knfe.Key.ID) + return fmt.Sprintf("no value exists with bucket: %s, uuid: %s", knfe.Key.Bucket, knfe.Key.UUID) } func (knfe KeyNotFoundError) StatusCode() int { diff --git a/store/handler.go b/store/handler.go index 0676de02..932ee5d8 100644 --- a/store/handler.go +++ b/store/handler.go @@ -19,7 +19,6 @@ package store import ( "net/http" - "time" kithttp "github.com/go-kit/kit/transport/http" "github.com/xmidt-org/argus/model" @@ -33,52 +32,38 @@ type KeyItemPairRequest struct { Method string } -type ItemTTL struct { - DefaultTTL time.Duration - MaxTTL time.Duration -} - -func newGetItemHandler(s S) Handler { +func newGetItemHandler(config *requestConfig, s S) Handler { return kithttp.NewServer( newGetItemEndpoint(s), - decodeGetOrDeleteItemRequest, + getOrDeleteItemRequestDecoder(config), encodeGetOrDeleteItemResponse, kithttp.ServerErrorEncoder(encodeError), ) } -func newDeleteItemHandler(s S) Handler { +func newDeleteItemHandler(config *requestConfig, s S) Handler { return kithttp.NewServer( newDeleteItemEndpoint(s), - decodeGetOrDeleteItemRequest, + getOrDeleteItemRequestDecoder(config), encodeGetOrDeleteItemResponse, kithttp.ServerErrorEncoder(encodeError), ) } -func newGetAllItemsHandler(s S) Handler { +func newGetAllItemsHandler(config *requestConfig, s S) Handler { return kithttp.NewServer( newGetAllItemsEndpoint(s), - decodeGetAllItemsRequest, + getAllItemsRequestDecoder(config), encodeGetAllItemsResponse, kithttp.ServerErrorEncoder(encodeError), ) } -func newPushItemHandler(itemTTLInfo ItemTTL, s S) Handler { - return kithttp.NewServer( - newPushItemEndpoint(s), - pushItemRequestDecoder(itemTTLInfo, false), - encodePushItemResponse, - kithttp.ServerErrorEncoder(encodeError), - ) -} - -func newUpdateItemHandler(itemTTLInfo ItemTTL, s S) Handler { +func newSetItemHandler(config *requestConfig, s S) Handler { return kithttp.NewServer( - newUpdateItemEndpoint(s), - pushItemRequestDecoder(itemTTLInfo, true), - encodePushItemResponse, + newSetItemEndpoint(s), + setItemRequestDecoder(config), + encodeSetItemResponse, kithttp.ServerErrorEncoder(encodeError), ) } diff --git a/store/inmem/InMem.go b/store/inmem/InMem.go index 2cd57b69..3bf62e24 100644 --- a/store/inmem/InMem.go +++ b/store/inmem/InMem.go @@ -18,9 +18,10 @@ package inmem import ( + "sync" + "github.com/xmidt-org/argus/model" "github.com/xmidt-org/argus/store" - "sync" ) type InMem struct { @@ -38,10 +39,10 @@ func (i *InMem) Push(key model.Key, item store.OwnableItem) error { i.lock.Lock() if _, ok := i.data[key.Bucket]; !ok { i.data[key.Bucket] = map[string]store.OwnableItem{ - key.ID: item, + key.UUID: item, } } else { - i.data[key.Bucket][key.ID] = item + i.data[key.Bucket][key.UUID] = item } i.lock.Unlock() return nil @@ -56,7 +57,7 @@ func (i *InMem) Get(key model.Key) (store.OwnableItem, error) { if _, ok := i.data[key.Bucket]; !ok { err = store.KeyNotFoundError{Key: key} } else { - if value, ok := i.data[key.Bucket][key.ID]; !ok { + if value, ok := i.data[key.Bucket][key.UUID]; !ok { err = store.KeyNotFoundError{Key: key} } else { item = value @@ -78,7 +79,7 @@ func (i *InMem) GetAll(bucket string) (map[string]store.OwnableItem, error) { } else { err = store.KeyNotFoundError{Key: model.Key{ Bucket: bucket, - ID: "", + UUID: "", }} } i.lock.RUnlock() @@ -94,11 +95,11 @@ func (i *InMem) Delete(key model.Key) (store.OwnableItem, error) { if _, ok := i.data[key.Bucket]; !ok { err = store.KeyNotFoundError{Key: key} } else { - if value, ok := i.data[key.Bucket][key.ID]; !ok { + if value, ok := i.data[key.Bucket][key.UUID]; !ok { err = store.KeyNotFoundError{Key: key} } else { item = value - delete(i.data[key.Bucket], key.ID) + delete(i.data[key.Bucket], key.UUID) } } i.lock.Unlock() diff --git a/store/provide.go b/store/provide.go index 36f9f104..a0fd52a2 100644 --- a/store/provide.go +++ b/store/provide.go @@ -35,11 +35,8 @@ type StoreIn struct { type StoreOut struct { fx.Out - // PushItemHandler is the http.Handler to add a new item to the store. - PushItemHandler Handler `name:"pushHandler"` - // SetItemHandler is the http.Handler to update an item in the store. - UpdateItemHandler Handler `name:"updateHandler"` + SetItemHandler Handler `name:"setHandler"` // SetKeyHandler is the http.Handler to fetch an individual item from the store. GetItemHandler Handler `name:"getHandler"` @@ -53,27 +50,20 @@ type StoreOut struct { // Provide is an uber/fx style provider for this package's components func Provide(unmarshaller config.Unmarshaller, in StoreIn) StoreOut { - itemTTL := ItemTTL{ - DefaultTTL: DefaultTTL, - MaxTTL: YearTTL, - } - unmarshaller.UnmarshalKey("itemTTL", itemTTL) - validateItemTTLConfig(&itemTTL) + cfg := new(requestConfig) + unmarshaller.UnmarshalKey("request", cfg) + validateRequestConfig(cfg) return StoreOut{ - PushItemHandler: newPushItemHandler(itemTTL, in.Store), - UpdateItemHandler: newUpdateItemHandler(itemTTL, in.Store), - GetItemHandler: newGetItemHandler(in.Store), - GetAllItemsHandler: newGetAllItemsHandler(in.Store), - DeleteKeyHandler: newDeleteItemHandler(in.Store), + SetItemHandler: newSetItemHandler(cfg, in.Store), + GetItemHandler: newGetItemHandler(cfg, in.Store), + GetAllItemsHandler: newGetAllItemsHandler(cfg, in.Store), + DeleteKeyHandler: newDeleteItemHandler(cfg, in.Store), } } -func validateItemTTLConfig(ttl *ItemTTL) { - if ttl.MaxTTL <= time.Second { - ttl.MaxTTL = YearTTL * time.Second - } - if ttl.DefaultTTL <= time.Millisecond { - ttl.DefaultTTL = DefaultTTL * time.Second +func validateRequestConfig(cfg *requestConfig) { + if cfg.Validation.MaxTTL <= time.Second { + cfg.Validation.MaxTTL = YearTTL * time.Second } } diff --git a/store/store.go b/store/store.go index 1722aff3..e3fe47a0 100644 --- a/store/store.go +++ b/store/store.go @@ -51,10 +51,6 @@ type OwnableItem struct { } func FilterOwner(value map[string]OwnableItem, owner string) map[string]OwnableItem { - if owner == "" { - return value - } - filteredResults := map[string]OwnableItem{} for k, v := range value { if v.Owner == owner { diff --git a/store/test/storetest.go b/store/test/storetest.go index 30019bfd..acdaa979 100644 --- a/store/test/storetest.go +++ b/store/test/storetest.go @@ -18,17 +18,19 @@ package test import ( + "testing" + "time" + "github.com/stretchr/testify/assert" "github.com/xmidt-org/argus/model" "github.com/xmidt-org/argus/store" - "testing" - "time" ) +var testInt int64 = 3 var GenericTestKeyPair = store.KeyItemPairRequest{ Key: model.Key{ Bucket: "world", - ID: "earth", + UUID: "earth", }, OwnableItem: store.OwnableItem{ Item: model.Item{ @@ -37,7 +39,7 @@ var GenericTestKeyPair = store.KeyItemPairRequest{ "year": float64(1967), "words": []interface{}{"What", "a", "Wonderful", "World"}, }, - TTL: 3, + TTL: &testInt, }, Owner: "Louis Armstrong", }, diff --git a/store/transport.go b/store/transport.go index 70cdca65..969ca5ce 100644 --- a/store/transport.go +++ b/store/transport.go @@ -6,9 +6,10 @@ import ( "errors" "io/ioutil" "net/http" + "sort" + "time" kithttp "github.com/go-kit/kit/transport/http" - "github.com/google/uuid" "github.com/gorilla/mux" "github.com/xmidt-org/argus/model" ) @@ -16,17 +17,18 @@ import ( // request URL path keys const ( bucketVarKey = "bucket" - idVarKey = "id" + uuidVarKey = "uuid" ) const ( bucketVarMissingMsg = "{bucket} URL path parameter missing" - idVarMissingMsg = "{id} URL path parameter missing" + uuidVarMissingMsg = "{uuid} URL path parameter missing" ) // Request and Response Headers const ( ItemOwnerHeaderKey = "X-Midt-Owner" + AdminTokenHeaderKey = "X-Midt-Admin-Token" XmidtErrorHeaderKey = "X-Midt-Error" ) @@ -34,50 +36,57 @@ const ( // encoders. var ErrCasting = errors.New("casting error due to middleware wiring mistake") +type requestConfig struct { + Validation validationConfig + Authorization authorizationConfig +} + +type validationConfig struct { + MaxTTL time.Duration +} + +type authorizationConfig struct { + AdminToken string +} + type getOrDeleteItemRequest struct { - key model.Key - owner string + key model.Key + owner string + adminMode bool } type getAllItemsRequest struct { - bucket string - owner string + bucket string + owner string + adminMode bool } -type pushItemRequest struct { - key model.Key - item OwnableItem +type setItemRequest struct { + key model.Key + item OwnableItem + adminMode bool } -func decodeGetAllItemsRequest(ctx context.Context, r *http.Request) (interface{}, error) { - vars := mux.Vars(r) - bucket, ok := vars[bucketVarKey] - if !ok { - return nil, &BadRequestErr{Message: bucketVarMissingMsg} - } - return &getAllItemsRequest{ - bucket: bucket, - owner: r.Header.Get(ItemOwnerHeaderKey), - }, nil +type setItemResponse struct { + key model.Key + existingResource bool } -func pushItemRequestDecoder(itemTTLInfo ItemTTL, update bool) kithttp.DecodeRequestFunc { +func getAllItemsRequestDecoder(config *requestConfig) kithttp.DecodeRequestFunc { return func(ctx context.Context, r *http.Request) (interface{}, error) { - vars := mux.Vars(r) - bucket, ok := vars[bucketVarKey] - if !ok { - return nil, &BadRequestErr{Message: bucketVarMissingMsg} - } - - var id string - - if update { - id, ok = vars[idVarKey] + return &getAllItemsRequest{ + bucket: mux.Vars(r)[bucketVarKey], + owner: r.Header.Get(ItemOwnerHeaderKey), + adminMode: config.Authorization.AdminToken == r.Header.Get(AdminTokenHeaderKey), + }, nil + } +} - if !ok { - return nil, &BadRequestErr{Message: idVarMissingMsg} - } - } +func setItemRequestDecoder(config *requestConfig) kithttp.DecodeRequestFunc { + return func(ctx context.Context, r *http.Request) (interface{}, error) { + URLVars := mux.Vars(r) + bucket := URLVars[bucketVarKey] + uuid := URLVars[uuidVarKey] data, err := ioutil.ReadAll(r.Body) if err != nil { @@ -94,99 +103,75 @@ func pushItemRequestDecoder(itemTTLInfo ItemTTL, update bool) kithttp.DecodeRequ return nil, &BadRequestErr{Message: "data field must be set"} } - if !update { - id = generateID() - } + validateItemTTL(&item, config.Validation.MaxTTL) - validateItemTTL(&item, itemTTLInfo) + if item.UUID != uuid { + return nil, &BadRequestErr{Message: "UUIDs must match between URL and payload"} + } - return &pushItemRequest{ + return &setItemRequest{ item: OwnableItem{ Item: item, Owner: r.Header.Get(ItemOwnerHeaderKey), }, key: model.Key{ Bucket: bucket, - ID: id, + UUID: uuid, }, + adminMode: config.Authorization.AdminToken == r.Header.Get(AdminTokenHeaderKey), }, nil } } -func validateItemTTL(item *model.Item, itemTTLInfo ItemTTL) { - if item.TTL > int64(itemTTLInfo.MaxTTL.Seconds()) { - item.TTL = int64(itemTTLInfo.MaxTTL.Seconds()) - } +func getOrDeleteItemRequestDecoder(config *requestConfig) kithttp.DecodeRequestFunc { + return func(ctx context.Context, r *http.Request) (interface{}, error) { + URLVars := mux.Vars(r) - if item.TTL < 1 { - item.TTL = int64(itemTTLInfo.DefaultTTL.Seconds()) + return &getOrDeleteItemRequest{ + key: model.Key{ + Bucket: URLVars[bucketVarKey], + UUID: URLVars[uuidVarKey], + }, + adminMode: config.Authorization.AdminToken == r.Header.Get(AdminTokenHeaderKey), + owner: r.Header.Get(ItemOwnerHeaderKey), + }, nil } } -func generateID() string { - return uuid.New().String() -} - -func encodePushItemResponse(ctx context.Context, rw http.ResponseWriter, response interface{}) error { - pushItemResponse := response.(*model.Key) - data, err := json.Marshal(&pushItemResponse) - if err != nil { - return err +func encodeSetItemResponse(ctx context.Context, rw http.ResponseWriter, response interface{}) error { + r := response.(*setItemResponse) + if r.existingResource { + rw.WriteHeader(http.StatusOK) + } else { + rw.WriteHeader(http.StatusCreated) } - rw.Header().Add("Content-Type", "application/json") - rw.Write(data) return nil } + +// TODO: I noticed order of result elements get shuffled around on multiple fetches +// This is because of dynamodb. To make tests easier, results are sorted by lexicographical non-decreasing +// order of the UUIDs. func encodeGetAllItemsResponse(ctx context.Context, rw http.ResponseWriter, response interface{}) error { items := response.(map[string]OwnableItem) - payload := map[string]model.Item{} - for k, value := range items { - if value.TTL <= 0 { - continue - } - payload[k] = value.Item + list := []model.Item{} + for _, value := range items { + list = append(list, value.Item) } - data, err := json.Marshal(&payload) + data, err := json.Marshal(&list) if err != nil { return err } + + sort.SliceStable(list, func(i, j int) bool { + return list[i].UUID < list[j].UUID + }) rw.Header().Add("Content-Type", "application/json") rw.Write(data) return nil } -func decodeGetOrDeleteItemRequest(ctx context.Context, r *http.Request) (interface{}, error) { - vars := mux.Vars(r) - bucket, ok := vars[bucketVarKey] - if !ok { - return nil, &BadRequestErr{Message: bucketVarMissingMsg} - } - - id, ok := vars[idVarKey] - - if !ok { - return nil, &BadRequestErr{Message: idVarMissingMsg} - } - - return &getOrDeleteItemRequest{ - key: model.Key{ - Bucket: bucket, - ID: id, - }, - owner: r.Header.Get(ItemOwnerHeaderKey), - }, nil -} - func encodeGetOrDeleteItemResponse(ctx context.Context, rw http.ResponseWriter, response interface{}) error { - item, ok := response.(*OwnableItem) - if !ok { - return ErrCasting - } - - if item.TTL <= 0 { - rw.WriteHeader(http.StatusNotFound) - return nil - } + item := response.(*OwnableItem) data, err := json.Marshal(&item.Item) if err != nil { @@ -213,3 +198,12 @@ func encodeError(ctx context.Context, err error, w http.ResponseWriter) { } w.WriteHeader(code) } + +func validateItemTTL(item *model.Item, maxTTL time.Duration) { + if item.TTL != nil { + ttlCapSeconds := int64(maxTTL.Seconds()) + if *item.TTL > ttlCapSeconds { + item.TTL = &ttlCapSeconds + } + } +} diff --git a/store/transport_test.go b/store/transport_test.go index f3633e92..05bc2422 100644 --- a/store/transport_test.go +++ b/store/transport_test.go @@ -8,12 +8,13 @@ import ( "testing" "time" + "github.com/aws/aws-sdk-go/aws" "github.com/gorilla/mux" "github.com/stretchr/testify/assert" "github.com/xmidt-org/argus/model" ) -func TestDecodeGetOrDeleteItemRequest(t *testing.T) { +func TestGetOrDeleteItemRequestDecoder(t *testing.T) { testCases := []struct { Name string URLVars map[string]string @@ -22,48 +23,37 @@ func TestDecodeGetOrDeleteItemRequest(t *testing.T) { ExpectedErr error }{ { - Name: "Missing id", + Name: "Happy path - No owner - Normal mode", URLVars: map[string]string{ "bucket": "california", - }, - ExpectedErr: &BadRequestErr{Message: idVarMissingMsg}, - }, - { - Name: "Missing bucket", - URLVars: map[string]string{ - "id": "san francisco", - }, - ExpectedErr: &BadRequestErr{Message: bucketVarMissingMsg}, - }, - { - Name: "Happy path - No owner", - URLVars: map[string]string{ - "bucket": "california", - "id": "san francisco", + "uuid": "san francisco", }, ExpectedDecodedRequest: &getOrDeleteItemRequest{ key: model.Key{ Bucket: "california", - ID: "san francisco", + UUID: "san francisco", }, }, }, { - Name: "Happy path", + Name: "Happy path - Owner - Admin mode", URLVars: map[string]string{ "bucket": "california", - "id": "san francisco", + "uuid": "san francisco", + }, + + Headers: map[string][]string{ + ItemOwnerHeaderKey: []string{"SF Giants"}, + AdminTokenHeaderKey: []string{"secretAdminToken"}, }, ExpectedDecodedRequest: &getOrDeleteItemRequest{ key: model.Key{ Bucket: "california", - ID: "san francisco", + UUID: "san francisco", }, - owner: "SF Giants", - }, - Headers: map[string][]string{ - ItemOwnerHeaderKey: []string{"SF Giants"}, + owner: "SF Giants", + adminMode: true, }, }, } @@ -75,7 +65,13 @@ func TestDecodeGetOrDeleteItemRequest(t *testing.T) { transferHeaders(testCase.Headers, r) r = mux.SetURLVars(r, testCase.URLVars) - decodedRequest, err := decodeGetOrDeleteItemRequest(context.Background(), r) + config := &requestConfig{ + Authorization: authorizationConfig{ + AdminToken: "secretAdminToken", + }, + } + decoder := getOrDeleteItemRequestDecoder(config) + decodedRequest, err := decoder(context.Background(), r) assert.Equal(testCase.ExpectedDecodedRequest, decodedRequest) assert.Equal(testCase.ExpectedErr, err) @@ -92,38 +88,20 @@ func TestEncodeGetOrDeleteItemResponse(t *testing.T) { ExpectedBody string ExpectedErr error }{ - { - Name: "Unexpected casting error", - ItemResponse: nil, - ExpectedHeaders: make(http.Header), - ExpectedErr: ErrCasting, - // used due to limitations in httptest. In reality, any non-nil error promises nothing - // would be written to the response writer - ExpectedCode: 200, - }, - { - Name: "Expired item", - ItemResponse: &OwnableItem{ - Item: model.Item{ - TTL: 0, - }, - }, - ExpectedCode: http.StatusNotFound, - ExpectedHeaders: make(http.Header), - }, { Name: "Happy path", ItemResponse: &OwnableItem{ Owner: "xmidt", Item: model.Item{ - TTL: 20, + UUID: "NaYFGE961cS_3dpzJcoP3QTL4kBYcw9ua3Q6Hy5E4nI", + TTL: aws.Int64(20), Identifier: "id01", Data: map[string]interface{}{ "key": 10, }, }, }, - ExpectedBody: `{"identifier":"id01","data":{"key":10},"ttl":20}`, + ExpectedBody: `{"uuid":"NaYFGE961cS_3dpzJcoP3QTL4kBYcw9ua3Q6Hy5E4nI","identifier":"id01","data":{"key":10},"ttl":20}`, ExpectedCode: 200, ExpectedHeaders: http.Header{ "Content-Type": []string{"application/json"}, @@ -144,62 +122,84 @@ func TestEncodeGetOrDeleteItemResponse(t *testing.T) { } } -func TestDecodeGetAllItemsRequest(t *testing.T) { - t.Run("Bucket Missing", testDecodeGetAllItemsRequestBucketMissing) - t.Run("Success", testDecodeGetAllItemsRequestSuccessful) -} - -func testDecodeGetAllItemsRequestBucketMissing(t *testing.T) { - assert := assert.New(t) - r := httptest.NewRequest(http.MethodGet, "http://localhost:9030", nil) +func TestgetAllItemsRequestDecoder(t *testing.T) { + testCases := []struct { + Name string + URLVars map[string]string + Headers map[string][]string + ExpectedDecodedRequest interface{} + ExpectedErr error + }{ + { + Name: "Happy path - No owner - Normal mode", + URLVars: map[string]string{ + "bucket": "california", + }, + ExpectedDecodedRequest: &getAllItemsRequest{ + bucket: "california", + }, + }, + { + Name: "Happy path - Owner - Admin mode", + URLVars: map[string]string{ + "bucket": "california", + "uuid": "san francisco", + }, - decodedRequest, err := decodeGetAllItemsRequest(context.Background(), r) - assert.Nil(decodedRequest) - assert.Equal(&BadRequestErr{Message: bucketVarMissingMsg}, err) -} + Headers: map[string][]string{ + ItemOwnerHeaderKey: []string{"SF Giants"}, + AdminTokenHeaderKey: []string{"secretAdminToken"}, + }, -func testDecodeGetAllItemsRequestSuccessful(t *testing.T) { - assert := assert.New(t) - r := httptest.NewRequest(http.MethodGet, "http://localhost:9030", nil) - r.Header.Set(ItemOwnerHeaderKey, "bob-ross") - r = mux.SetURLVars(r, map[string]string{bucketVarKey: "happy-little-accidents"}) - expectedDecodedRequest := &getAllItemsRequest{ - bucket: "happy-little-accidents", - owner: "bob-ross", + ExpectedDecodedRequest: &getAllItemsRequest{ + owner: "SF Giants", + adminMode: true, + }, + }, } - decodedRequest, err := decodeGetAllItemsRequest(context.Background(), r) - assert.Nil(err) - assert.Equal(expectedDecodedRequest, decodedRequest) + for _, testCase := range testCases { + t.Run(testCase.Name, func(t *testing.T) { + assert := assert.New(t) + r := httptest.NewRequest(http.MethodGet, "http://localhost/test", nil) + transferHeaders(testCase.Headers, r) + + r = mux.SetURLVars(r, testCase.URLVars) + config := &requestConfig{ + Authorization: authorizationConfig{ + AdminToken: "secretAdminToken", + }, + } + decoder := getAllItemsRequestDecoder(config) + decodedRequest, err := decoder(context.Background(), r) + + assert.Equal(testCase.ExpectedDecodedRequest, decodedRequest) + assert.Equal(testCase.ExpectedErr, err) + }) + } } func TestEncodeGetAllItemsResponse(t *testing.T) { assert := assert.New(t) response := map[string]OwnableItem{ - "fix-you": OwnableItem{ + "E-VG": OwnableItem{ Item: model.Item{ - Identifier: "coldplay-04", - TTL: 1, + UUID: "E-VG", + Identifier: "fix-you", Data: map[string]interface{}{}, + TTL: aws.Int64(1), }, }, - "bohemian-rhapsody": OwnableItem{ + "Y9G": OwnableItem{ Item: model.Item{ - Identifier: "queen-03", - TTL: 0, - Data: map[string]interface{}{}, - }, - }, - "don't-stop-me-know": OwnableItem{ - Item: model.Item{ - Identifier: "queen-02", - TTL: 0, + UUID: "Y9G", + Identifier: "this-is-it", Data: map[string]interface{}{}, }, }, } recorder := httptest.NewRecorder() - expectedResponseBody := `{"fix-you":{"identifier":"coldplay-04","data":{},"ttl":1}}` + expectedResponseBody := `[{"uuid":"E-VG","identifier":"fix-you","data":{},"ttl":1},{"uuid":"Y9G","identifier":"this-is-it","data":{}}]` err := encodeGetAllItemsResponse(context.Background(), recorder, response) assert.Nil(err) assert.Equal(expectedResponseBody, recorder.Body.String()) @@ -213,70 +213,63 @@ func transferHeaders(headers map[string][]string, r *http.Request) { } } -func TestPushItemRequestDecoder(t *testing.T) { +func TestsetItemRequestDecoder(t *testing.T) { testCases := []struct { Name string - Bucket string - Owner string - ID string + URLVars map[string]string + Headers map[string][]string RequestBody string ExpectedErr error - UpdateRequest bool - ExpectedRequest *pushItemRequest + ExpectedRequest *setItemRequest }{ - { - Name: "Missing bucket", - ExpectedErr: &BadRequestErr{ - Message: bucketVarMissingMsg, - }, - }, { Name: "Bad JSON data", + URLVars: map[string]string{bucketVarKey: "bucketVal", uuidVarKey: "rWPSg7pI0jj8mMG9tmscdQMOGKeRAquySfkObTasRBc"}, RequestBody: `{"validJSON": false,}`, - Bucket: "invalid", - ExpectedErr: &BadRequestErr{ + ExpectedErr: BadRequestErr{ Message: "failed to unmarshal json", }, }, { Name: "Missing data item field", - RequestBody: `{"identifier": "xyz"}`, - Bucket: "no-data", - ExpectedErr: &BadRequestErr{ + URLVars: map[string]string{bucketVarKey: "letters", uuidVarKey: "ypeBEsobvcr6wjGzmiPcTaeG7_gUfE5yuYB3ha_uSLs"}, + RequestBody: `{"uuid": "ypeBEsobvcr6wjGzmiPcTaeG7_gUfE5yuYB3ha_uSLs","identifier": "a"}`, + ExpectedErr: BadRequestErr{ Message: "data field must be set", }, }, { Name: "Capped TTL", - RequestBody: `{"identifier": "xyz", "data": {"x": 0, "y": 1, "z": 2}, "ttl": 3900}`, - Bucket: "variables", - Owner: "math", - ExpectedRequest: &pushItemRequest{ + URLVars: map[string]string{bucketVarKey: "variables", uuidVarKey: "evCz5Hw1gg-r72nMVCOSvS0PbjfDSYUXKPDGgwE1Y84"}, + Headers: map[string][]string{ItemOwnerHeaderKey: []string{"math"}}, + RequestBody: `{"uuid":"evCz5Hw1gg-r72nMVCOSvS0PbjfDSYUXKPDGgwE1Y84", "identifier": "xyz", "data": {"x": 0, "y": 1, "z": 2}, "ttl": 3900}`, + ExpectedRequest: &setItemRequest{ item: OwnableItem{ Item: model.Item{ + UUID: "evCz5Hw1gg-r72nMVCOSvS0PbjfDSYUXKPDGgwE1Y84", Identifier: "xyz", Data: map[string]interface{}{ "x": float64(0), "y": float64(1), "z": float64(2), }, - TTL: int64(time.Hour.Seconds()), + TTL: aws.Int64(int64((time.Minute * 5).Seconds())), }, Owner: "math", }, key: model.Key{ Bucket: "variables", + UUID: "evCz5Hw1gg-r72nMVCOSvS0PbjfDSYUXKPDGgwE1Y84", }, }, }, { - Name: "Defaulted TTL", - RequestBody: `{"identifier": "xyz", "data": {"x": 0, "y": 1, "z": 2}}`, - Bucket: "variables", - Owner: "math", - ExpectedRequest: &pushItemRequest{ + Name: "UUID mismatch TTL", + URLVars: map[string]string{bucketVarKey: "variables", uuidVarKey: "evCz5Hw1gg-r72nMVCOSvS0PbjfDSYUXKPDGgwE1Y84"}, + RequestBody: `{"uuid":"iBCtWB5Z8rw5KLJhcHpxMI9-E56wSCA2bcTVwY2YAiU", "identifier": "xyz", "data": {"x": 0, "y": 1, "z": 2}, "ttl": 3900}`, + ExpectedRequest: &setItemRequest{ item: OwnableItem{ Item: model.Item{ Identifier: "xyz", @@ -285,9 +278,8 @@ func TestPushItemRequestDecoder(t *testing.T) { "y": float64(1), "z": float64(2), }, - TTL: 60, + TTL: aws.Int64(60), }, - Owner: "math", }, key: model.Key{ Bucket: "variables", @@ -296,102 +288,54 @@ func TestPushItemRequestDecoder(t *testing.T) { }, { - Name: "Happy path", - RequestBody: `{"identifier": "xyz", "data": {"x": 0, "y": 1, "z": 2}, "ttl": 120}`, - Bucket: "variables", - Owner: "math", - ExpectedRequest: &pushItemRequest{ + Name: "Happy Path - Admin mode", + URLVars: map[string]string{bucketVarKey: "variables", uuidVarKey: "evCz5Hw1gg-r72nMVCOSvS0PbjfDSYUXKPDGgwE1Y84"}, + Headers: map[string][]string{ItemOwnerHeaderKey: []string{"math"}, AdminTokenHeaderKey: []string{"secretAdminPassKey"}}, + RequestBody: `{"uuid":"evCz5Hw1gg-r72nMVCOSvS0PbjfDSYUXKPDGgwE1Y84", "identifier": "xyz", "data": {"x": 0, "y": 1, "z": 2}, "ttl": 39}`, + ExpectedRequest: &setItemRequest{ item: OwnableItem{ Item: model.Item{ + UUID: "evCz5Hw1gg-r72nMVCOSvS0PbjfDSYUXKPDGgwE1Y84", Identifier: "xyz", Data: map[string]interface{}{ "x": float64(0), "y": float64(1), "z": float64(2), }, - TTL: 120, + TTL: aws.Int64(39), }, Owner: "math", }, key: model.Key{ Bucket: "variables", + UUID: "evCz5Hw1gg-r72nMVCOSvS0PbjfDSYUXKPDGgwE1Y84", }, + adminMode: true, }, }, - { - Name: "Update Request", - RequestBody: `{"identifier": "xyz", "data": {"x": 0, "y": 1, "z": 2}, "ttl": 120}`, - Bucket: "variables", - Owner: "math", - ID: "id-that-should-stay-the-same", - UpdateRequest: true, - ExpectedRequest: &pushItemRequest{ - item: OwnableItem{ - Item: model.Item{ - Identifier: "xyz", - Data: map[string]interface{}{ - "x": float64(0), - "y": float64(1), - "z": float64(2), - }, - TTL: 120, - }, - Owner: "math", - }, - key: model.Key{ - Bucket: "variables", - ID: "id-that-should-stay-the-same", - }, - }, - }, - { - Name: "Update Request-No ID", - RequestBody: `{"identifier": "xyz", "data": {"x": 0, "y": 1, "z": 2}, "ttl": 120}`, - Bucket: "variables", - Owner: "math", - UpdateRequest: true, - ExpectedErr: &BadRequestErr{Message: idVarMissingMsg}, - }, } for _, testCase := range testCases { t.Run(testCase.Name, func(t *testing.T) { assert := assert.New(t) r := httptest.NewRequest(http.MethodGet, "http://localhost", bytes.NewBufferString(testCase.RequestBody)) - if len(testCase.Bucket) > 0 || len(testCase.ID) > 0 { - - pathVars := make(map[string]string) - if len(testCase.Bucket) > 0 { - pathVars[bucketVarKey] = testCase.Bucket - } - - if len(testCase.ID) > 0 { - pathVars[idVarKey] = testCase.ID - } - - r = mux.SetURLVars(r, pathVars) - } + r = mux.SetURLVars(r, testCase.URLVars) + transferHeaders(testCase.Headers, r) - if len(testCase.Owner) > 0 { - r.Header.Set(ItemOwnerHeaderKey, testCase.Owner) + config := &requestConfig{ + Authorization: authorizationConfig{ + AdminToken: "secretAdminPassKey", + }, + Validation: validationConfig{ + MaxTTL: time.Minute * 5, + }, } - decoder := pushItemRequestDecoder(ItemTTL{ - DefaultTTL: time.Minute, - MaxTTL: time.Hour, - }, testCase.UpdateRequest) + decoder := setItemRequestDecoder(config) decodedRequest, err := decoder(context.Background(), r) if testCase.ExpectedRequest == nil { assert.Nil(decodedRequest) } else { - decodedRequestCast, ok := decodedRequest.(*pushItemRequest) - assert.True(ok) - - // id should only be generated if it is not an update request - if !testCase.UpdateRequest { - testCase.ExpectedRequest.key.ID = decodedRequestCast.key.ID - } - assert.Equal(testCase.ExpectedRequest, decodedRequest) } assert.Equal(testCase.ExpectedErr, err) @@ -399,13 +343,19 @@ func TestPushItemRequestDecoder(t *testing.T) { } } -func TestEncodePushItemResponse(t *testing.T) { +func TestEncodeSetItemResponse(t *testing.T) { assert := assert.New(t) - recorder := httptest.NewRecorder() - err := encodePushItemResponse(context.Background(), recorder, &model.Key{ - Bucket: "north-america", - ID: "usa", + createdRecorder := httptest.NewRecorder() + err := encodeSetItemResponse(context.Background(), createdRecorder, &setItemResponse{ + existingResource: false, + }) + assert.Nil(err) + assert.Equal(http.StatusCreated, createdRecorder.Code) + + updatedRecorder := httptest.NewRecorder() + err = encodeSetItemResponse(context.Background(), updatedRecorder, &setItemResponse{ + existingResource: true, }) assert.Nil(err) - assert.Equal(`{"bucket":"north-america","id":"usa"}`, recorder.Body.String()) + assert.Equal(http.StatusOK, updatedRecorder.Code) }