diff --git a/.envrc b/.envrc index c5e72d09ad2..e2aacb52680 100644 --- a/.envrc +++ b/.envrc @@ -232,19 +232,22 @@ export TZ="UTC" # AWS development access # -# To use S3/SES for local builds, you'll need to uncomment the following. +# To use S3/SES or SNS & SQS for local builds, you'll need to uncomment the following. # Do not commit the change: # # export STORAGE_BACKEND=s3 # export EMAIL_BACKEND=ses +# export RECEIVER_BACKEND=sns_sqs # # Instructions for using S3 storage backend here: https://dp3.atlassian.net/wiki/spaces/MT/pages/1470955567/How+to+test+storing+data+in+S3+locally # Instructions for using SES email backend here: https://dp3.atlassian.net/wiki/spaces/MT/pages/1467973894/How+to+test+sending+email+locally +# Instructions for using SNS&SQS backend here: https://dp3.atlassian.net/wiki/spaces/MT/pages/2793242625/How+to+test+notifications+receiver+locally # # The default and equivalent to not being set is: # # export STORAGE_BACKEND=local # export EMAIL_BACKEND=local +# export RECEIVER_BACKEND=local # # Setting region and profile conditionally while we migrate from com to govcloud. if [ "$STORAGE_BACKEND" == "s3" ]; then @@ -258,6 +261,13 @@ export AWS_S3_KEY_NAMESPACE=$USER export AWS_SES_DOMAIN="devlocal.dp3.us" export AWS_SES_REGION="us-gov-west-1" +if [ "$RECEIVER_BACKEND" == "sns_sqs" ]; then + export SNS_TAGS_UPDATED_TOPIC="app_s3_tag_events" + export SNS_REGION="us-gov-west-1" +# cleanup flag false by default, only used at server startup to wipe receiver artifacts from previous runs +# export RECEIVER_CLEANUP_ON_START=false +fi + # To use s3 links aws-bucketname/xx/user/ for local builds, # you'll need to add the following to your .envrc.local: # @@ -415,7 +425,7 @@ if [ ! -r .nix-disable ] && has nix-env; then # add the NIX_PROFILE bin path so that everything we just installed # is available on the path - PATH_add ${NIX_PROFILE}/bin + PATH_add "${NIX_PROFILE}"/bin # Add the node binaries to our path PATH_add ./node_modules/.bin # nix is immutable, so we need to specify a path for local changes, e.g. @@ -444,4 +454,4 @@ then fi # Check that all required environment variables are set -check_required_variables \ No newline at end of file +check_required_variables diff --git a/cmd/milmove/serve.go b/cmd/milmove/serve.go index 505936d3868..4f05b86beaa 100644 --- a/cmd/milmove/serve.go +++ b/cmd/milmove/serve.go @@ -478,6 +478,13 @@ func buildRoutingConfig(appCtx appcontext.AppContext, v *viper.Viper, redisPool appCtx.Logger().Fatal("notification sender sending not enabled", zap.Error(err)) } + // Notification Receiver + runReceiverCleanup := v.GetBool(cli.ReceiverCleanupOnStartFlag) // Cleanup aws artifacts left over from previous runs + notificationReceiver, err := notifications.InitReceiver(v, appCtx.Logger(), runReceiverCleanup) + if err != nil { + appCtx.Logger().Fatal("notification receiver not enabled", zap.Error(err)) + } + routingConfig.BuildRoot = v.GetString(cli.BuildRootFlag) sendProductionInvoice := v.GetBool(cli.GEXSendProdInvoiceFlag) @@ -567,6 +574,7 @@ func buildRoutingConfig(appCtx appcontext.AppContext, v *viper.Viper, redisPool dtodRoutePlanner, fileStorer, notificationSender, + notificationReceiver, iwsPersonLookup, sendProductionInvoice, gexSender, diff --git a/go.mod b/go.mod index 6e0a3a4356e..3838a18537b 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,8 @@ require ( github.com/aws/aws-sdk-go-v2/service/rds v1.78.2 github.com/aws/aws-sdk-go-v2/service/s3 v1.59.0 github.com/aws/aws-sdk-go-v2/service/ses v1.25.3 + github.com/aws/aws-sdk-go-v2/service/sns v1.31.8 + github.com/aws/aws-sdk-go-v2/service/sqs v1.34.6 github.com/aws/aws-sdk-go-v2/service/ssm v1.52.8 github.com/aws/aws-sdk-go-v2/service/sts v1.30.7 github.com/aws/smithy-go v1.20.4 diff --git a/go.sum b/go.sum index 8dbbb90bbe0..8bd8f6a235e 100644 --- a/go.sum +++ b/go.sum @@ -82,6 +82,10 @@ github.com/aws/aws-sdk-go-v2/service/s3 v1.59.0 h1:Cso4Ev/XauMVsbwdhYEoxg8rxZWw4 github.com/aws/aws-sdk-go-v2/service/s3 v1.59.0/go.mod h1:BSPI0EfnYUuNHPS0uqIo5VrRwzie+Fp+YhQOUs16sKI= github.com/aws/aws-sdk-go-v2/service/ses v1.25.3 h1:wcfUsE2nqsXhEj68gxr7MnGXNPcBPKx0RW2DzBVgVlM= github.com/aws/aws-sdk-go-v2/service/ses v1.25.3/go.mod h1:6Ul/Ir8oOCsI3dFN0prULK9fvpxP+WTYmlHDkFzaAVA= +github.com/aws/aws-sdk-go-v2/service/sns v1.31.8 h1:vRSk062d1SmaEVbiqFePkvYuhCTnW2JnPkUdt19nqeY= +github.com/aws/aws-sdk-go-v2/service/sns v1.31.8/go.mod h1:wjhxA9hlVu75dCL/5Wcx8Cwmszvu6t0i8WEDypcB4+s= +github.com/aws/aws-sdk-go-v2/service/sqs v1.34.6 h1:DbjODDHumQBdJ3T+EO7AXVoFUeUhAsJYOdjStH5Ws4A= +github.com/aws/aws-sdk-go-v2/service/sqs v1.34.6/go.mod h1:7idt3XszF6sE9WPS1GqZRiDJOxw4oPtlRBXodWnCGjU= github.com/aws/aws-sdk-go-v2/service/ssm v1.52.8 h1:7cjN4Wp3U3cud17TsnUxSomTwKzKQGUWdq/N1aWqgMk= github.com/aws/aws-sdk-go-v2/service/ssm v1.52.8/go.mod h1:nUSNPaG8mv5rIu7EclHnFqZOjhreEUwRKENtKTtJ9aw= github.com/aws/aws-sdk-go-v2/service/sso v1.22.7 h1:pIaGg+08llrP7Q5aiz9ICWbY8cqhTkyy+0SHvfzQpTc= diff --git a/pkg/cli/receiver.go b/pkg/cli/receiver.go new file mode 100644 index 00000000000..ed71d45d209 --- /dev/null +++ b/pkg/cli/receiver.go @@ -0,0 +1,61 @@ +package cli + +import ( + "fmt" + + "github.com/spf13/pflag" + "github.com/spf13/viper" +) + +const ( + // ReceiverBackendFlag is the Receiver Backend Flag + ReceiverBackendFlag string = "receiver-backend" + // SNSTagsUpdatedTopicFlag is the SNS Tags Updated Topic Flag + SNSTagsUpdatedTopicFlag string = "sns-tags-updated-topic" + // SNSRegionFlag is the SNS Region flag + SNSRegionFlag string = "sns-region" + // SNSAccountId is the application's AWS account id + SNSAccountId string = "aws-account-id" + // ReceiverCleanupOnStartFlag is the Receiver Cleanup On Start Flag + ReceiverCleanupOnStartFlag string = "receiver-cleanup-on-start" +) + +// InitReceiverFlags initializes Storage command line flags +func InitReceiverFlags(flag *pflag.FlagSet) { + flag.String(ReceiverBackendFlag, "local", "Receiver backend to use, either local or sns_sqs.") + flag.String(SNSTagsUpdatedTopicFlag, "", "SNS Topic for receiving event messages") + flag.String(SNSRegionFlag, "", "Region used for SNS and SQS") + flag.String(SNSAccountId, "", "SNS account Id") + flag.Bool(ReceiverCleanupOnStartFlag, false, "Receiver will cleanup previous aws artifacts on start.") +} + +// CheckReceiver validates Storage command line flags +func CheckReceiver(v *viper.Viper) error { + + receiverBackend := v.GetString(ReceiverBackendFlag) + if !stringSliceContains([]string{"local", "sns_sqs"}, receiverBackend) { + return fmt.Errorf("invalid receiver_backend %s, expecting local or sns_sqs", receiverBackend) + } + + receiverCleanupOnStart := v.GetString(ReceiverCleanupOnStartFlag) + if !stringSliceContains([]string{"true", "false"}, receiverCleanupOnStart) { + return fmt.Errorf("invalid receiver_cleanup_on_start %s, expecting true or false", receiverCleanupOnStart) + } + + if receiverBackend == "sns_sqs" { + r := v.GetString(SNSRegionFlag) + if r == "" { + return fmt.Errorf("invalid value for %s: %s", SNSRegionFlag, r) + } + topic := v.GetString(SNSTagsUpdatedTopicFlag) + if topic == "" { + return fmt.Errorf("invalid value for %s: %s", SNSTagsUpdatedTopicFlag, topic) + } + accountId := v.GetString(SNSAccountId) + if topic == "" { + return fmt.Errorf("invalid value for %s: %s", SNSAccountId, accountId) + } + } + + return nil +} diff --git a/pkg/cli/receiver_test.go b/pkg/cli/receiver_test.go new file mode 100644 index 00000000000..7095a672f5f --- /dev/null +++ b/pkg/cli/receiver_test.go @@ -0,0 +1,6 @@ +package cli + +func (suite *cliTestSuite) TestConfigReceiver() { + suite.Setup(InitReceiverFlags, []string{}) + suite.NoError(CheckReceiver(suite.viper)) +} diff --git a/pkg/gen/ghcapi/configure_mymove.go b/pkg/gen/ghcapi/configure_mymove.go index 32eb5174c09..bb80917f608 100644 --- a/pkg/gen/ghcapi/configure_mymove.go +++ b/pkg/gen/ghcapi/configure_mymove.go @@ -4,6 +4,7 @@ package ghcapi import ( "crypto/tls" + "io" "net/http" "github.com/go-openapi/errors" @@ -64,6 +65,9 @@ func configureAPI(api *ghcoperations.MymoveAPI) http.Handler { api.BinProducer = runtime.ByteStreamProducer() api.JSONProducer = runtime.JSONProducer() + api.TextEventStreamProducer = runtime.ProducerFunc(func(w io.Writer, data interface{}) error { + return errors.NotImplemented("textEventStream producer has not yet been implemented") + }) // You may change here the memory limit for this multipart form parser. Below is the default (32 MB). // uploads.CreateUploadMaxParseMemory = 32 << 20 @@ -392,6 +396,11 @@ func configureAPI(api *ghcoperations.MymoveAPI) http.Handler { return middleware.NotImplemented("operation uploads.GetUpload has not yet been implemented") }) } + if api.UploadsGetUploadStatusHandler == nil { + api.UploadsGetUploadStatusHandler = uploads.GetUploadStatusHandlerFunc(func(params uploads.GetUploadStatusParams) middleware.Responder { + return middleware.NotImplemented("operation uploads.GetUploadStatus has not yet been implemented") + }) + } if api.CalendarIsDateWeekendHolidayHandler == nil { api.CalendarIsDateWeekendHolidayHandler = calendar.IsDateWeekendHolidayHandlerFunc(func(params calendar.IsDateWeekendHolidayParams) middleware.Responder { return middleware.NotImplemented("operation calendar.IsDateWeekendHoliday has not yet been implemented") diff --git a/pkg/gen/ghcapi/doc.go b/pkg/gen/ghcapi/doc.go index 24f788c8fb2..24ba756c211 100644 --- a/pkg/gen/ghcapi/doc.go +++ b/pkg/gen/ghcapi/doc.go @@ -21,6 +21,7 @@ // Produces: // - application/pdf // - application/json +// - text/event-stream // // swagger:meta package ghcapi diff --git a/pkg/gen/ghcapi/embedded_spec.go b/pkg/gen/ghcapi/embedded_spec.go index 7f8a82f24d8..f50c560d2ab 100644 --- a/pkg/gen/ghcapi/embedded_spec.go +++ b/pkg/gen/ghcapi/embedded_spec.go @@ -6413,6 +6413,58 @@ func init() { } } }, + "/uploads/{uploadID}/status": { + "get": { + "description": "Returns status of an upload based on antivirus run", + "produces": [ + "text/event-stream" + ], + "tags": [ + "uploads" + ], + "summary": "Returns status of an upload", + "operationId": "getUploadStatus", + "parameters": [ + { + "type": "string", + "format": "uuid", + "description": "UUID of the upload to return status of", + "name": "uploadID", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "the requested upload status", + "schema": { + "type": "string", + "enum": [ + "INFECTED", + "CLEAN", + "PROCESSING" + ], + "readOnly": true + } + }, + "400": { + "description": "invalid request", + "schema": { + "$ref": "#/definitions/InvalidRequestResponsePayload" + } + }, + "403": { + "description": "not authorized" + }, + "404": { + "description": "not found" + }, + "500": { + "description": "server error" + } + } + } + }, "/uploads/{uploadID}/update": { "patch": { "description": "Uploads represent a single digital file, such as a JPEG or PDF. The rotation is relevant to how it is displayed on the page.", @@ -23459,6 +23511,58 @@ func init() { } } }, + "/uploads/{uploadID}/status": { + "get": { + "description": "Returns status of an upload based on antivirus run", + "produces": [ + "text/event-stream" + ], + "tags": [ + "uploads" + ], + "summary": "Returns status of an upload", + "operationId": "getUploadStatus", + "parameters": [ + { + "type": "string", + "format": "uuid", + "description": "UUID of the upload to return status of", + "name": "uploadID", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "the requested upload status", + "schema": { + "type": "string", + "enum": [ + "INFECTED", + "CLEAN", + "PROCESSING" + ], + "readOnly": true + } + }, + "400": { + "description": "invalid request", + "schema": { + "$ref": "#/definitions/InvalidRequestResponsePayload" + } + }, + "403": { + "description": "not authorized" + }, + "404": { + "description": "not found" + }, + "500": { + "description": "server error" + } + } + } + }, "/uploads/{uploadID}/update": { "patch": { "description": "Uploads represent a single digital file, such as a JPEG or PDF. The rotation is relevant to how it is displayed on the page.", diff --git a/pkg/gen/ghcapi/ghcoperations/mymove_api.go b/pkg/gen/ghcapi/ghcoperations/mymove_api.go index c53c0fec4d7..57a2b196ffc 100644 --- a/pkg/gen/ghcapi/ghcoperations/mymove_api.go +++ b/pkg/gen/ghcapi/ghcoperations/mymove_api.go @@ -7,6 +7,7 @@ package ghcoperations import ( "fmt" + "io" "net/http" "strings" @@ -70,6 +71,9 @@ func NewMymoveAPI(spec *loads.Document) *MymoveAPI { BinProducer: runtime.ByteStreamProducer(), JSONProducer: runtime.JSONProducer(), + TextEventStreamProducer: runtime.ProducerFunc(func(w io.Writer, data interface{}) error { + return errors.NotImplemented("textEventStream producer has not yet been implemented") + }), OrderAcknowledgeExcessUnaccompaniedBaggageWeightRiskHandler: order.AcknowledgeExcessUnaccompaniedBaggageWeightRiskHandlerFunc(func(params order.AcknowledgeExcessUnaccompaniedBaggageWeightRiskParams) middleware.Responder { return middleware.NotImplemented("operation order.AcknowledgeExcessUnaccompaniedBaggageWeightRisk has not yet been implemented") @@ -263,6 +267,9 @@ func NewMymoveAPI(spec *loads.Document) *MymoveAPI { UploadsGetUploadHandler: uploads.GetUploadHandlerFunc(func(params uploads.GetUploadParams) middleware.Responder { return middleware.NotImplemented("operation uploads.GetUpload has not yet been implemented") }), + UploadsGetUploadStatusHandler: uploads.GetUploadStatusHandlerFunc(func(params uploads.GetUploadStatusParams) middleware.Responder { + return middleware.NotImplemented("operation uploads.GetUploadStatus has not yet been implemented") + }), CalendarIsDateWeekendHolidayHandler: calendar.IsDateWeekendHolidayHandlerFunc(func(params calendar.IsDateWeekendHolidayParams) middleware.Responder { return middleware.NotImplemented("operation calendar.IsDateWeekendHoliday has not yet been implemented") }), @@ -440,6 +447,9 @@ type MymoveAPI struct { // JSONProducer registers a producer for the following mime types: // - application/json JSONProducer runtime.Producer + // TextEventStreamProducer registers a producer for the following mime types: + // - text/event-stream + TextEventStreamProducer runtime.Producer // OrderAcknowledgeExcessUnaccompaniedBaggageWeightRiskHandler sets the operation handler for the acknowledge excess unaccompanied baggage weight risk operation OrderAcknowledgeExcessUnaccompaniedBaggageWeightRiskHandler order.AcknowledgeExcessUnaccompaniedBaggageWeightRiskHandler @@ -569,6 +579,8 @@ type MymoveAPI struct { TransportationOfficeGetTransportationOfficesOpenHandler transportation_office.GetTransportationOfficesOpenHandler // UploadsGetUploadHandler sets the operation handler for the get upload operation UploadsGetUploadHandler uploads.GetUploadHandler + // UploadsGetUploadStatusHandler sets the operation handler for the get upload status operation + UploadsGetUploadStatusHandler uploads.GetUploadStatusHandler // CalendarIsDateWeekendHolidayHandler sets the operation handler for the is date weekend holiday operation CalendarIsDateWeekendHolidayHandler calendar.IsDateWeekendHolidayHandler // MtoServiceItemListMTOServiceItemsHandler sets the operation handler for the list m t o service items operation @@ -739,6 +751,9 @@ func (o *MymoveAPI) Validate() error { if o.JSONProducer == nil { unregistered = append(unregistered, "JSONProducer") } + if o.TextEventStreamProducer == nil { + unregistered = append(unregistered, "TextEventStreamProducer") + } if o.OrderAcknowledgeExcessUnaccompaniedBaggageWeightRiskHandler == nil { unregistered = append(unregistered, "order.AcknowledgeExcessUnaccompaniedBaggageWeightRiskHandler") @@ -932,6 +947,9 @@ func (o *MymoveAPI) Validate() error { if o.UploadsGetUploadHandler == nil { unregistered = append(unregistered, "uploads.GetUploadHandler") } + if o.UploadsGetUploadStatusHandler == nil { + unregistered = append(unregistered, "uploads.GetUploadStatusHandler") + } if o.CalendarIsDateWeekendHolidayHandler == nil { unregistered = append(unregistered, "calendar.IsDateWeekendHolidayHandler") } @@ -1116,6 +1134,8 @@ func (o *MymoveAPI) ProducersFor(mediaTypes []string) map[string]runtime.Produce result["application/pdf"] = o.BinProducer case "application/json": result["application/json"] = o.JSONProducer + case "text/event-stream": + result["text/event-stream"] = o.TextEventStreamProducer } if p, ok := o.customProducers[mt]; ok { @@ -1415,6 +1435,10 @@ func (o *MymoveAPI) initHandlerCache() { if o.handlers["GET"] == nil { o.handlers["GET"] = make(map[string]http.Handler) } + o.handlers["GET"]["/uploads/{uploadID}/status"] = uploads.NewGetUploadStatus(o.context, o.UploadsGetUploadStatusHandler) + if o.handlers["GET"] == nil { + o.handlers["GET"] = make(map[string]http.Handler) + } o.handlers["GET"]["/calendar/{countryCode}/is-weekend-holiday/{date}"] = calendar.NewIsDateWeekendHoliday(o.context, o.CalendarIsDateWeekendHolidayHandler) if o.handlers["GET"] == nil { o.handlers["GET"] = make(map[string]http.Handler) diff --git a/pkg/gen/ghcapi/ghcoperations/uploads/get_upload_status.go b/pkg/gen/ghcapi/ghcoperations/uploads/get_upload_status.go new file mode 100644 index 00000000000..b893657d488 --- /dev/null +++ b/pkg/gen/ghcapi/ghcoperations/uploads/get_upload_status.go @@ -0,0 +1,58 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package uploads + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the generate command + +import ( + "net/http" + + "github.com/go-openapi/runtime/middleware" +) + +// GetUploadStatusHandlerFunc turns a function with the right signature into a get upload status handler +type GetUploadStatusHandlerFunc func(GetUploadStatusParams) middleware.Responder + +// Handle executing the request and returning a response +func (fn GetUploadStatusHandlerFunc) Handle(params GetUploadStatusParams) middleware.Responder { + return fn(params) +} + +// GetUploadStatusHandler interface for that can handle valid get upload status params +type GetUploadStatusHandler interface { + Handle(GetUploadStatusParams) middleware.Responder +} + +// NewGetUploadStatus creates a new http.Handler for the get upload status operation +func NewGetUploadStatus(ctx *middleware.Context, handler GetUploadStatusHandler) *GetUploadStatus { + return &GetUploadStatus{Context: ctx, Handler: handler} +} + +/* + GetUploadStatus swagger:route GET /uploads/{uploadID}/status uploads getUploadStatus + +# Returns status of an upload + +Returns status of an upload based on antivirus run +*/ +type GetUploadStatus struct { + Context *middleware.Context + Handler GetUploadStatusHandler +} + +func (o *GetUploadStatus) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + route, rCtx, _ := o.Context.RouteInfo(r) + if rCtx != nil { + *r = *rCtx + } + var Params = NewGetUploadStatusParams() + if err := o.Context.BindValidRequest(r, route, &Params); err != nil { // bind params + o.Context.Respond(rw, r, route.Produces, route, err) + return + } + + res := o.Handler.Handle(Params) // actually handle the request + o.Context.Respond(rw, r, route.Produces, route, res) + +} diff --git a/pkg/gen/ghcapi/ghcoperations/uploads/get_upload_status_parameters.go b/pkg/gen/ghcapi/ghcoperations/uploads/get_upload_status_parameters.go new file mode 100644 index 00000000000..fa1b3ef9329 --- /dev/null +++ b/pkg/gen/ghcapi/ghcoperations/uploads/get_upload_status_parameters.go @@ -0,0 +1,91 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package uploads + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "net/http" + + "github.com/go-openapi/errors" + "github.com/go-openapi/runtime/middleware" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/validate" +) + +// NewGetUploadStatusParams creates a new GetUploadStatusParams object +// +// There are no default values defined in the spec. +func NewGetUploadStatusParams() GetUploadStatusParams { + + return GetUploadStatusParams{} +} + +// GetUploadStatusParams contains all the bound params for the get upload status operation +// typically these are obtained from a http.Request +// +// swagger:parameters getUploadStatus +type GetUploadStatusParams struct { + + // HTTP Request Object + HTTPRequest *http.Request `json:"-"` + + /*UUID of the upload to return status of + Required: true + In: path + */ + UploadID strfmt.UUID +} + +// BindRequest both binds and validates a request, it assumes that complex things implement a Validatable(strfmt.Registry) error interface +// for simple values it will use straight method calls. +// +// To ensure default values, the struct must have been initialized with NewGetUploadStatusParams() beforehand. +func (o *GetUploadStatusParams) BindRequest(r *http.Request, route *middleware.MatchedRoute) error { + var res []error + + o.HTTPRequest = r + + rUploadID, rhkUploadID, _ := route.Params.GetOK("uploadID") + if err := o.bindUploadID(rUploadID, rhkUploadID, route.Formats); err != nil { + res = append(res, err) + } + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +// bindUploadID binds and validates parameter UploadID from path. +func (o *GetUploadStatusParams) bindUploadID(rawData []string, hasKey bool, formats strfmt.Registry) error { + var raw string + if len(rawData) > 0 { + raw = rawData[len(rawData)-1] + } + + // Required: true + // Parameter is provided by construction from the route + + // Format: uuid + value, err := formats.Parse("uuid", raw) + if err != nil { + return errors.InvalidType("uploadID", "path", "strfmt.UUID", raw) + } + o.UploadID = *(value.(*strfmt.UUID)) + + if err := o.validateUploadID(formats); err != nil { + return err + } + + return nil +} + +// validateUploadID carries on validations for parameter UploadID +func (o *GetUploadStatusParams) validateUploadID(formats strfmt.Registry) error { + + if err := validate.FormatOf("uploadID", "path", "uuid", o.UploadID.String(), formats); err != nil { + return err + } + return nil +} diff --git a/pkg/gen/ghcapi/ghcoperations/uploads/get_upload_status_responses.go b/pkg/gen/ghcapi/ghcoperations/uploads/get_upload_status_responses.go new file mode 100644 index 00000000000..894980d6a2b --- /dev/null +++ b/pkg/gen/ghcapi/ghcoperations/uploads/get_upload_status_responses.go @@ -0,0 +1,177 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package uploads + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "net/http" + + "github.com/go-openapi/runtime" + + "github.com/transcom/mymove/pkg/gen/ghcmessages" +) + +// GetUploadStatusOKCode is the HTTP code returned for type GetUploadStatusOK +const GetUploadStatusOKCode int = 200 + +/* +GetUploadStatusOK the requested upload status + +swagger:response getUploadStatusOK +*/ +type GetUploadStatusOK struct { + + /* + In: Body + */ + Payload string `json:"body,omitempty"` +} + +// NewGetUploadStatusOK creates GetUploadStatusOK with default headers values +func NewGetUploadStatusOK() *GetUploadStatusOK { + + return &GetUploadStatusOK{} +} + +// WithPayload adds the payload to the get upload status o k response +func (o *GetUploadStatusOK) WithPayload(payload string) *GetUploadStatusOK { + o.Payload = payload + return o +} + +// SetPayload sets the payload to the get upload status o k response +func (o *GetUploadStatusOK) SetPayload(payload string) { + o.Payload = payload +} + +// WriteResponse to the client +func (o *GetUploadStatusOK) WriteResponse(rw http.ResponseWriter, producer runtime.Producer) { + + rw.WriteHeader(200) + payload := o.Payload + if err := producer.Produce(rw, payload); err != nil { + panic(err) // let the recovery middleware deal with this + } +} + +// GetUploadStatusBadRequestCode is the HTTP code returned for type GetUploadStatusBadRequest +const GetUploadStatusBadRequestCode int = 400 + +/* +GetUploadStatusBadRequest invalid request + +swagger:response getUploadStatusBadRequest +*/ +type GetUploadStatusBadRequest struct { + + /* + In: Body + */ + Payload *ghcmessages.InvalidRequestResponsePayload `json:"body,omitempty"` +} + +// NewGetUploadStatusBadRequest creates GetUploadStatusBadRequest with default headers values +func NewGetUploadStatusBadRequest() *GetUploadStatusBadRequest { + + return &GetUploadStatusBadRequest{} +} + +// WithPayload adds the payload to the get upload status bad request response +func (o *GetUploadStatusBadRequest) WithPayload(payload *ghcmessages.InvalidRequestResponsePayload) *GetUploadStatusBadRequest { + o.Payload = payload + return o +} + +// SetPayload sets the payload to the get upload status bad request response +func (o *GetUploadStatusBadRequest) SetPayload(payload *ghcmessages.InvalidRequestResponsePayload) { + o.Payload = payload +} + +// WriteResponse to the client +func (o *GetUploadStatusBadRequest) WriteResponse(rw http.ResponseWriter, producer runtime.Producer) { + + rw.WriteHeader(400) + if o.Payload != nil { + payload := o.Payload + if err := producer.Produce(rw, payload); err != nil { + panic(err) // let the recovery middleware deal with this + } + } +} + +// GetUploadStatusForbiddenCode is the HTTP code returned for type GetUploadStatusForbidden +const GetUploadStatusForbiddenCode int = 403 + +/* +GetUploadStatusForbidden not authorized + +swagger:response getUploadStatusForbidden +*/ +type GetUploadStatusForbidden struct { +} + +// NewGetUploadStatusForbidden creates GetUploadStatusForbidden with default headers values +func NewGetUploadStatusForbidden() *GetUploadStatusForbidden { + + return &GetUploadStatusForbidden{} +} + +// WriteResponse to the client +func (o *GetUploadStatusForbidden) WriteResponse(rw http.ResponseWriter, producer runtime.Producer) { + + rw.Header().Del(runtime.HeaderContentType) //Remove Content-Type on empty responses + + rw.WriteHeader(403) +} + +// GetUploadStatusNotFoundCode is the HTTP code returned for type GetUploadStatusNotFound +const GetUploadStatusNotFoundCode int = 404 + +/* +GetUploadStatusNotFound not found + +swagger:response getUploadStatusNotFound +*/ +type GetUploadStatusNotFound struct { +} + +// NewGetUploadStatusNotFound creates GetUploadStatusNotFound with default headers values +func NewGetUploadStatusNotFound() *GetUploadStatusNotFound { + + return &GetUploadStatusNotFound{} +} + +// WriteResponse to the client +func (o *GetUploadStatusNotFound) WriteResponse(rw http.ResponseWriter, producer runtime.Producer) { + + rw.Header().Del(runtime.HeaderContentType) //Remove Content-Type on empty responses + + rw.WriteHeader(404) +} + +// GetUploadStatusInternalServerErrorCode is the HTTP code returned for type GetUploadStatusInternalServerError +const GetUploadStatusInternalServerErrorCode int = 500 + +/* +GetUploadStatusInternalServerError server error + +swagger:response getUploadStatusInternalServerError +*/ +type GetUploadStatusInternalServerError struct { +} + +// NewGetUploadStatusInternalServerError creates GetUploadStatusInternalServerError with default headers values +func NewGetUploadStatusInternalServerError() *GetUploadStatusInternalServerError { + + return &GetUploadStatusInternalServerError{} +} + +// WriteResponse to the client +func (o *GetUploadStatusInternalServerError) WriteResponse(rw http.ResponseWriter, producer runtime.Producer) { + + rw.Header().Del(runtime.HeaderContentType) //Remove Content-Type on empty responses + + rw.WriteHeader(500) +} diff --git a/pkg/gen/ghcapi/ghcoperations/uploads/get_upload_status_urlbuilder.go b/pkg/gen/ghcapi/ghcoperations/uploads/get_upload_status_urlbuilder.go new file mode 100644 index 00000000000..edd3c2fd6f8 --- /dev/null +++ b/pkg/gen/ghcapi/ghcoperations/uploads/get_upload_status_urlbuilder.go @@ -0,0 +1,101 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package uploads + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the generate command + +import ( + "errors" + "net/url" + golangswaggerpaths "path" + "strings" + + "github.com/go-openapi/strfmt" +) + +// GetUploadStatusURL generates an URL for the get upload status operation +type GetUploadStatusURL struct { + UploadID strfmt.UUID + + _basePath string + // avoid unkeyed usage + _ struct{} +} + +// WithBasePath sets the base path for this url builder, only required when it's different from the +// base path specified in the swagger spec. +// When the value of the base path is an empty string +func (o *GetUploadStatusURL) WithBasePath(bp string) *GetUploadStatusURL { + o.SetBasePath(bp) + return o +} + +// SetBasePath sets the base path for this url builder, only required when it's different from the +// base path specified in the swagger spec. +// When the value of the base path is an empty string +func (o *GetUploadStatusURL) SetBasePath(bp string) { + o._basePath = bp +} + +// Build a url path and query string +func (o *GetUploadStatusURL) Build() (*url.URL, error) { + var _result url.URL + + var _path = "/uploads/{uploadID}/status" + + uploadID := o.UploadID.String() + if uploadID != "" { + _path = strings.Replace(_path, "{uploadID}", uploadID, -1) + } else { + return nil, errors.New("uploadId is required on GetUploadStatusURL") + } + + _basePath := o._basePath + if _basePath == "" { + _basePath = "/ghc/v1" + } + _result.Path = golangswaggerpaths.Join(_basePath, _path) + + return &_result, nil +} + +// Must is a helper function to panic when the url builder returns an error +func (o *GetUploadStatusURL) Must(u *url.URL, err error) *url.URL { + if err != nil { + panic(err) + } + if u == nil { + panic("url can't be nil") + } + return u +} + +// String returns the string representation of the path with query string +func (o *GetUploadStatusURL) String() string { + return o.Must(o.Build()).String() +} + +// BuildFull builds a full url with scheme, host, path and query string +func (o *GetUploadStatusURL) BuildFull(scheme, host string) (*url.URL, error) { + if scheme == "" { + return nil, errors.New("scheme is required for a full url on GetUploadStatusURL") + } + if host == "" { + return nil, errors.New("host is required for a full url on GetUploadStatusURL") + } + + base, err := o.Build() + if err != nil { + return nil, err + } + + base.Scheme = scheme + base.Host = host + return base, nil +} + +// StringFull returns the string representation of a complete url +func (o *GetUploadStatusURL) StringFull(scheme, host string) string { + return o.Must(o.BuildFull(scheme, host)).String() +} diff --git a/pkg/handlers/apitests.go b/pkg/handlers/apitests.go index a84a6627f2c..a540d37e1f3 100644 --- a/pkg/handlers/apitests.go +++ b/pkg/handlers/apitests.go @@ -9,6 +9,7 @@ import ( "path" "path/filepath" "runtime/debug" + "strings" "time" "github.com/go-openapi/runtime" @@ -148,6 +149,11 @@ func (suite *BaseHandlerTestSuite) TestNotificationSender() notifications.Notifi return suite.notificationSender } +// TestNotificationReceiver returns the notification sender to use in the suite +func (suite *BaseHandlerTestSuite) TestNotificationReceiver() notifications.NotificationReceiver { + return notifications.NewStubNotificationReceiver() +} + // HasWebhookNotification checks that there's a record on the WebhookNotifications table for the object and trace IDs func (suite *BaseHandlerTestSuite) HasWebhookNotification(objectID uuid.UUID, traceID uuid.UUID) { notification := &models.WebhookNotification{} @@ -277,8 +283,12 @@ func (suite *BaseHandlerTestSuite) Fixture(name string) *runtime.File { if err != nil { suite.T().Error(err) } + cdRouting := "" + if strings.Contains(cwd, "routing") { + cdRouting = ".." + } - fixturePath := path.Join(cwd, "..", "..", fixtureDir, name) + fixturePath := path.Join(cwd, "..", "..", cdRouting, fixtureDir, name) file, err := os.Open(filepath.Clean(fixturePath)) if err != nil { diff --git a/pkg/handlers/authentication/auth.go b/pkg/handlers/authentication/auth.go index a01f499de5e..8e59132c750 100644 --- a/pkg/handlers/authentication/auth.go +++ b/pkg/handlers/authentication/auth.go @@ -221,6 +221,7 @@ var allowedRoutes = map[string]bool{ "uploads.deleteUpload": true, "users.showLoggedInUser": true, "okta_profile.showOktaInfo": true, + "uploads.getUploadStatus": true, } // checkIfRouteIsAllowed checks to see if the route is one of the ones that should be allowed through without stricter diff --git a/pkg/handlers/config.go b/pkg/handlers/config.go index b4bb2026915..50d45ee1978 100644 --- a/pkg/handlers/config.go +++ b/pkg/handlers/config.go @@ -39,6 +39,7 @@ type HandlerConfig interface { ) http.Handler FileStorer() storage.FileStorer NotificationSender() notifications.NotificationSender + NotificationReceiver() notifications.NotificationReceiver HHGPlanner() route.Planner DTODPlanner() route.Planner CookieSecret() string @@ -66,6 +67,7 @@ type Config struct { dtodPlanner route.Planner storage storage.FileStorer notificationSender notifications.NotificationSender + notificationReceiver notifications.NotificationReceiver iwsPersonLookup iws.PersonLookup sendProductionInvoice bool senderToGex services.GexSender @@ -86,6 +88,7 @@ func NewHandlerConfig( dtodPlanner route.Planner, storage storage.FileStorer, notificationSender notifications.NotificationSender, + notificationReceiver notifications.NotificationReceiver, iwsPersonLookup iws.PersonLookup, sendProductionInvoice bool, senderToGex services.GexSender, @@ -103,6 +106,7 @@ func NewHandlerConfig( dtodPlanner: dtodPlanner, storage: storage, notificationSender: notificationSender, + notificationReceiver: notificationReceiver, iwsPersonLookup: iwsPersonLookup, sendProductionInvoice: sendProductionInvoice, senderToGex: senderToGex, @@ -247,6 +251,16 @@ func (c *Config) SetNotificationSender(sender notifications.NotificationSender) c.notificationSender = sender } +// NotificationReceiver returns the sender to use in the current context +func (c *Config) NotificationReceiver() notifications.NotificationReceiver { + return c.notificationReceiver +} + +// SetNotificationSender is a simple setter for AWS SQS private field +func (c *Config) SetNotificationReceiver(receiver notifications.NotificationReceiver) { + c.notificationReceiver = receiver +} + // SetPlanner is a simple setter for the route.Planner private field func (c *Config) SetPlanner(planner route.Planner) { c.planner = planner diff --git a/pkg/handlers/config_test.go b/pkg/handlers/config_test.go index 26595daea29..85c9ccbff7c 100644 --- a/pkg/handlers/config_test.go +++ b/pkg/handlers/config_test.go @@ -30,7 +30,7 @@ func (suite *ConfigSuite) TestConfigHandler() { appCtx := suite.AppContextForTest() sessionManagers := auth.SetupSessionManagers(nil, false, time.Duration(180*time.Second), time.Duration(180*time.Second)) - handler := NewHandlerConfig(appCtx.DB(), nil, "", nil, nil, nil, nil, nil, false, nil, nil, false, ApplicationTestServername(), sessionManagers, nil) + handler := NewHandlerConfig(appCtx.DB(), nil, "", nil, nil, nil, nil, nil, nil, false, nil, nil, false, ApplicationTestServername(), sessionManagers, nil) req, err := http.NewRequest("GET", "/", nil) suite.NoError(err) myMethodCalled := false diff --git a/pkg/handlers/ghcapi/api.go b/pkg/handlers/ghcapi/api.go index 4a6bed9c52d..a3552d5edf5 100644 --- a/pkg/handlers/ghcapi/api.go +++ b/pkg/handlers/ghcapi/api.go @@ -4,6 +4,7 @@ import ( "log" "github.com/go-openapi/loads" + "github.com/go-openapi/runtime" "github.com/transcom/mymove/pkg/gen/ghcapi" ghcops "github.com/transcom/mymove/pkg/gen/ghcapi/ghcoperations" @@ -683,6 +684,8 @@ func NewGhcAPIHandler(handlerConfig handlers.HandlerConfig) *ghcops.MymoveAPI { ghcAPI.UploadsCreateUploadHandler = CreateUploadHandler{handlerConfig} ghcAPI.UploadsUpdateUploadHandler = UpdateUploadHandler{handlerConfig, upload.NewUploadInformationFetcher()} ghcAPI.UploadsDeleteUploadHandler = DeleteUploadHandler{handlerConfig, upload.NewUploadInformationFetcher()} + ghcAPI.UploadsGetUploadStatusHandler = GetUploadStatusHandler{handlerConfig, upload.NewUploadInformationFetcher()} + ghcAPI.TextEventStreamProducer = runtime.ByteStreamProducer() // GetUploadStatus produces Event Stream ghcAPI.CustomerSearchCustomersHandler = SearchCustomersHandler{ HandlerConfig: handlerConfig, diff --git a/pkg/handlers/ghcapi/internal/payloads/model_to_payload.go b/pkg/handlers/ghcapi/internal/payloads/model_to_payload.go index 87d56d72748..ca9a75bb8eb 100644 --- a/pkg/handlers/ghcapi/internal/payloads/model_to_payload.go +++ b/pkg/handlers/ghcapi/internal/payloads/model_to_payload.go @@ -2035,10 +2035,10 @@ func Upload(storer storage.FileStorer, upload models.Upload, url string) *ghcmes } tags, err := storer.Tags(upload.StorageKey) - if err != nil || len(tags) == 0 { - uploadPayload.Status = "PROCESSING" + if err != nil { + uploadPayload.Status = string(models.AVStatusPROCESSING) } else { - uploadPayload.Status = tags["av-status"] + uploadPayload.Status = string(models.GetAVStatusFromTags(tags)) } return uploadPayload } @@ -2057,10 +2057,10 @@ func WeightTicketUpload(storer storage.FileStorer, upload models.Upload, url str IsWeightTicket: isWeightTicket, } tags, err := storer.Tags(upload.StorageKey) - if err != nil || len(tags) == 0 { - uploadPayload.Status = "PROCESSING" + if err != nil { + uploadPayload.Status = string(models.AVStatusPROCESSING) } else { - uploadPayload.Status = tags["av-status"] + uploadPayload.Status = string(models.GetAVStatusFromTags(tags)) } return uploadPayload } @@ -2113,10 +2113,10 @@ func PayloadForUploadModel( } tags, err := storer.Tags(upload.StorageKey) - if err != nil || len(tags) == 0 { - uploadPayload.Status = "PROCESSING" + if err != nil { + uploadPayload.Status = string(models.AVStatusPROCESSING) } else { - uploadPayload.Status = tags["av-status"] + uploadPayload.Status = string(models.GetAVStatusFromTags(tags)) } return uploadPayload } diff --git a/pkg/handlers/ghcapi/move.go b/pkg/handlers/ghcapi/move.go index aaf96dde91e..f4abb0b549a 100644 --- a/pkg/handlers/ghcapi/move.go +++ b/pkg/handlers/ghcapi/move.go @@ -429,10 +429,10 @@ func payloadForUploadModelFromAdditionalDocumentsUpload(storer storage.FileStore UpdatedAt: strfmt.DateTime(upload.UpdatedAt), } tags, err := storer.Tags(upload.StorageKey) - if err != nil || len(tags) == 0 { - uploadPayload.Status = "PROCESSING" + if err != nil { + uploadPayload.Status = string(models.AVStatusPROCESSING) } else { - uploadPayload.Status = tags["av-status"] + uploadPayload.Status = string(models.GetAVStatusFromTags(tags)) } return uploadPayload, nil } diff --git a/pkg/handlers/ghcapi/orders.go b/pkg/handlers/ghcapi/orders.go index f6b559513a2..af6d5385bb4 100644 --- a/pkg/handlers/ghcapi/orders.go +++ b/pkg/handlers/ghcapi/orders.go @@ -933,10 +933,10 @@ func payloadForUploadModelFromAmendedOrdersUpload(storer storage.FileStorer, upl UpdatedAt: strfmt.DateTime(upload.UpdatedAt), } tags, err := storer.Tags(upload.StorageKey) - if err != nil || len(tags) == 0 { - uploadPayload.Status = "PROCESSING" + if err != nil { + uploadPayload.Status = string(models.AVStatusPROCESSING) } else { - uploadPayload.Status = tags["av-status"] + uploadPayload.Status = string(models.GetAVStatusFromTags(tags)) } return uploadPayload, nil } diff --git a/pkg/handlers/ghcapi/uploads.go b/pkg/handlers/ghcapi/uploads.go index a74e5d48498..70660150326 100644 --- a/pkg/handlers/ghcapi/uploads.go +++ b/pkg/handlers/ghcapi/uploads.go @@ -1,9 +1,16 @@ package ghcapi import ( + "context" + "fmt" + "net/http" + "strconv" + "time" + "github.com/go-openapi/runtime" "github.com/go-openapi/runtime/middleware" "github.com/gofrs/uuid" + "github.com/pkg/errors" "go.uber.org/zap" "github.com/transcom/mymove/pkg/appcontext" @@ -12,8 +19,10 @@ import ( "github.com/transcom/mymove/pkg/handlers" "github.com/transcom/mymove/pkg/handlers/ghcapi/internal/payloads" "github.com/transcom/mymove/pkg/models" + "github.com/transcom/mymove/pkg/notifications" "github.com/transcom/mymove/pkg/services" "github.com/transcom/mymove/pkg/services/upload" + "github.com/transcom/mymove/pkg/storage" uploaderpkg "github.com/transcom/mymove/pkg/uploader" ) @@ -157,3 +166,189 @@ func (h DeleteUploadHandler) Handle(params uploadop.DeleteUploadParams) middlewa }) } + +// UploadStatusHandler returns status of an upload +type GetUploadStatusHandler struct { + handlers.HandlerConfig + services.UploadInformationFetcher +} + +type CustomGetUploadStatusResponse struct { + params uploadop.GetUploadStatusParams + storageKey string + appCtx appcontext.AppContext + receiver notifications.NotificationReceiver + storer storage.FileStorer +} + +func (o *CustomGetUploadStatusResponse) writeEventStreamMessage(rw http.ResponseWriter, producer runtime.Producer, id int, event string, data string) { + resProcess := []byte(fmt.Sprintf("id: %s\nevent: %s\ndata: %s\n\n", strconv.Itoa(id), event, data)) + if produceErr := producer.Produce(rw, resProcess); produceErr != nil { + o.appCtx.Logger().Error(produceErr.Error()) + } + if f, ok := rw.(http.Flusher); ok { + f.Flush() + } +} + +func (o *CustomGetUploadStatusResponse) WriteResponse(rw http.ResponseWriter, producer runtime.Producer) { + + // Check current tag before event-driven wait for anti-virus + tags, err := o.storer.Tags(o.storageKey) + var uploadStatus models.AVStatusType + if err != nil { + uploadStatus = models.AVStatusPROCESSING + } else { + uploadStatus = models.GetAVStatusFromTags(tags) + } + + // Limitation: once the status code header has been written (first response), we are not able to update the status for subsequent responses. + // Standard 200 OK used with common SSE paradigm + rw.WriteHeader(http.StatusOK) + if uploadStatus == models.AVStatusCLEAN || uploadStatus == models.AVStatusINFECTED { + o.writeEventStreamMessage(rw, producer, 0, "message", string(uploadStatus)) + o.writeEventStreamMessage(rw, producer, 1, "close", "Connection closed") + return // skip notification loop since object already tagged from anti-virus + } else { + o.writeEventStreamMessage(rw, producer, 0, "message", string(uploadStatus)) + } + + // Start waiting for tag updates + topicName, err := o.receiver.GetDefaultTopic() + if err != nil { + o.appCtx.Logger().Error(err.Error()) + } + + filterPolicy := fmt.Sprintf(`{ + "detail": { + "object": { + "key": [ + {"suffix": "%s"} + ] + } + } + }`, o.params.UploadID) + + notificationParams := notifications.NotificationQueueParams{ + SubscriptionTopicName: topicName, + NamePrefix: notifications.QueuePrefixObjectTagsAdded, + FilterPolicy: filterPolicy, + } + + queueUrl, err := o.receiver.CreateQueueWithSubscription(o.appCtx, notificationParams) + if err != nil { + o.appCtx.Logger().Error(err.Error()) + } + + id_counter := 1 + + // For loop over 120 seconds, cancel context when done and it breaks the loop + totalReceiverContext, totalReceiverContextCancelFunc := context.WithTimeout(context.Background(), 120*time.Second) + defer func() { + id_counter++ + o.writeEventStreamMessage(rw, producer, id_counter, "close", "Connection closed") + totalReceiverContextCancelFunc() + }() + + // Cleanup if client closes connection + go func() { + <-o.params.HTTPRequest.Context().Done() + totalReceiverContextCancelFunc() + }() + + // Cleanup at end of work + go func() { + <-totalReceiverContext.Done() + _ = o.receiver.CloseoutQueue(o.appCtx, queueUrl) + }() + + for { + o.appCtx.Logger().Info("Receiving Messages...") + messages, errs := o.receiver.ReceiveMessages(o.appCtx, queueUrl, totalReceiverContext) + + if errors.Is(errs, context.Canceled) || errors.Is(errs, context.DeadlineExceeded) { + return + } + if errs != nil { + o.appCtx.Logger().Error(err.Error()) + return + } + + if len(messages) != 0 { + errTransaction := o.appCtx.NewTransaction(func(txnAppCtx appcontext.AppContext) error { + + tags, err := o.storer.Tags(o.storageKey) + + if err != nil { + uploadStatus = models.AVStatusPROCESSING + } else { + uploadStatus = models.GetAVStatusFromTags(tags) + } + + o.writeEventStreamMessage(rw, producer, id_counter, "message", string(uploadStatus)) + + if uploadStatus == models.AVStatusCLEAN || uploadStatus == models.AVStatusINFECTED { + return errors.New("connection_closed") + } + + return err + }) + + if errTransaction != nil && errTransaction.Error() == "connection_closed" { + return + } + + if errTransaction != nil { + o.appCtx.Logger().Error(err.Error()) + return + } + } + id_counter++ + + select { + case <-totalReceiverContext.Done(): + return + default: + time.Sleep(1 * time.Second) // Throttle as a precaution against hounding of the SDK + continue + } + } +} + +// Handle returns status of an upload +func (h GetUploadStatusHandler) Handle(params uploadop.GetUploadStatusParams) middleware.Responder { + return h.AuditableAppContextFromRequestWithErrors(params.HTTPRequest, + func(appCtx appcontext.AppContext) (middleware.Responder, error) { + + handleError := func(err error) (middleware.Responder, error) { + appCtx.Logger().Error("GetUploadStatusHandler error", zap.Error(err)) + switch errors.Cause(err) { + case models.ErrFetchForbidden: + return uploadop.NewGetUploadStatusForbidden(), err + case models.ErrFetchNotFound: + return uploadop.NewGetUploadStatusNotFound(), err + default: + return uploadop.NewGetUploadStatusInternalServerError(), err + } + } + + uploadId := params.UploadID.String() + uploadUUID, err := uuid.FromString(uploadId) + if err != nil { + return handleError(err) + } + + uploaded, err := models.FetchUserUploadFromUploadID(appCtx.DB(), appCtx.Session(), uploadUUID) + if err != nil { + return handleError(err) + } + + return &CustomGetUploadStatusResponse{ + params: params, + storageKey: uploaded.Upload.StorageKey, + appCtx: h.AppContextFromRequest(params.HTTPRequest), + receiver: h.NotificationReceiver(), + storer: h.FileStorer(), + }, nil + }) +} diff --git a/pkg/handlers/ghcapi/uploads_test.go b/pkg/handlers/ghcapi/uploads_test.go index 94830bdb5bf..0a22ea6b87a 100644 --- a/pkg/handlers/ghcapi/uploads_test.go +++ b/pkg/handlers/ghcapi/uploads_test.go @@ -4,13 +4,17 @@ import ( "net/http" "github.com/go-openapi/runtime/middleware" + "github.com/go-openapi/strfmt" "github.com/gofrs/uuid" "github.com/transcom/mymove/pkg/factory" uploadop "github.com/transcom/mymove/pkg/gen/ghcapi/ghcoperations/uploads" "github.com/transcom/mymove/pkg/handlers" "github.com/transcom/mymove/pkg/models" + "github.com/transcom/mymove/pkg/notifications" + "github.com/transcom/mymove/pkg/services/upload" storageTest "github.com/transcom/mymove/pkg/storage/test" + "github.com/transcom/mymove/pkg/uploader" ) const FixturePDF = "test.pdf" @@ -156,3 +160,127 @@ func (suite *HandlerSuite) TestCreateUploadsHandlerFailure() { t.Fatalf("Wrong number of uploads in database: expected %d, got %d", currentCount, count) } } + +func (suite *HandlerSuite) TestGetUploadStatusHandlerSuccess() { + fakeS3 := storageTest.NewFakeS3Storage(true) + localReceiver := notifications.StubNotificationReceiver{} + + orders := factory.BuildOrder(suite.DB(), nil, nil) + uploadUser1 := factory.BuildUserUpload(suite.DB(), []factory.Customization{ + { + Model: orders.UploadedOrders, + LinkOnly: true, + }, + { + Model: models.Upload{ + Filename: "FileName", + Bytes: int64(15), + ContentType: uploader.FileTypePDF, + }, + }, + }, nil) + + file := suite.Fixture(FixturePDF) + _, err := fakeS3.Store(uploadUser1.Upload.StorageKey, file.Data, "somehash", nil) + suite.NoError(err) + + params := uploadop.NewGetUploadStatusParams() + params.UploadID = strfmt.UUID(uploadUser1.Upload.ID.String()) + + req := &http.Request{} + req = suite.AuthenticateRequest(req, uploadUser1.Document.ServiceMember) + params.HTTPRequest = req + + handlerConfig := suite.HandlerConfig() + handlerConfig.SetFileStorer(fakeS3) + handlerConfig.SetNotificationReceiver(localReceiver) + uploadInformationFetcher := upload.NewUploadInformationFetcher() + handler := GetUploadStatusHandler{handlerConfig, uploadInformationFetcher} + + response := handler.Handle(params) + _, ok := response.(*CustomGetUploadStatusResponse) + suite.True(ok) + + queriedUpload := models.Upload{} + err = suite.DB().Find(&queriedUpload, uploadUser1.Upload.ID) + suite.NoError(err) +} + +func (suite *HandlerSuite) TestGetUploadStatusHandlerFailure() { + suite.Run("Error on no match for uploadId", func() { + orders := factory.BuildOrder(suite.DB(), factory.GetTraitActiveServiceMemberUser(), nil) + + uploadUUID := uuid.Must(uuid.NewV4()) + + params := uploadop.NewGetUploadStatusParams() + params.UploadID = strfmt.UUID(uploadUUID.String()) + + req := &http.Request{} + req = suite.AuthenticateRequest(req, orders.ServiceMember) + params.HTTPRequest = req + + fakeS3 := storageTest.NewFakeS3Storage(true) + localReceiver := notifications.StubNotificationReceiver{} + + handlerConfig := suite.HandlerConfig() + handlerConfig.SetFileStorer(fakeS3) + handlerConfig.SetNotificationReceiver(localReceiver) + uploadInformationFetcher := upload.NewUploadInformationFetcher() + handler := GetUploadStatusHandler{handlerConfig, uploadInformationFetcher} + + response := handler.Handle(params) + _, ok := response.(*uploadop.GetUploadStatusNotFound) + suite.True(ok) + + queriedUpload := models.Upload{} + err := suite.DB().Find(&queriedUpload, uploadUUID) + suite.Error(err) + }) + + suite.Run("Error when attempting access to another service member's upload", func() { + fakeS3 := storageTest.NewFakeS3Storage(true) + localReceiver := notifications.StubNotificationReceiver{} + + otherServiceMember := factory.BuildServiceMember(suite.DB(), nil, nil) + + orders := factory.BuildOrder(suite.DB(), nil, nil) + uploadUser1 := factory.BuildUserUpload(suite.DB(), []factory.Customization{ + { + Model: orders.UploadedOrders, + LinkOnly: true, + }, + { + Model: models.Upload{ + Filename: "FileName", + Bytes: int64(15), + ContentType: uploader.FileTypePDF, + }, + }, + }, nil) + + file := suite.Fixture(FixturePDF) + _, err := fakeS3.Store(uploadUser1.Upload.StorageKey, file.Data, "somehash", nil) + suite.NoError(err) + + params := uploadop.NewGetUploadStatusParams() + params.UploadID = strfmt.UUID(uploadUser1.Upload.ID.String()) + + req := &http.Request{} + req = suite.AuthenticateRequest(req, otherServiceMember) + params.HTTPRequest = req + + handlerConfig := suite.HandlerConfig() + handlerConfig.SetFileStorer(fakeS3) + handlerConfig.SetNotificationReceiver(localReceiver) + uploadInformationFetcher := upload.NewUploadInformationFetcher() + handler := GetUploadStatusHandler{handlerConfig, uploadInformationFetcher} + + response := handler.Handle(params) + _, ok := response.(*uploadop.GetUploadStatusForbidden) + suite.True(ok) + + queriedUpload := models.Upload{} + err = suite.DB().Find(&queriedUpload, uploadUser1.Upload.ID) + suite.NoError(err) + }) +} diff --git a/pkg/handlers/internalapi/internal/payloads/model_to_payload.go b/pkg/handlers/internalapi/internal/payloads/model_to_payload.go index 68e9cd5b576..26b25349e02 100644 --- a/pkg/handlers/internalapi/internal/payloads/model_to_payload.go +++ b/pkg/handlers/internalapi/internal/payloads/model_to_payload.go @@ -453,12 +453,14 @@ func PayloadForUploadModel( CreatedAt: strfmt.DateTime(upload.CreatedAt), UpdatedAt: strfmt.DateTime(upload.UpdatedAt), } + tags, err := storer.Tags(upload.StorageKey) - if err != nil || len(tags) == 0 { - uploadPayload.Status = "PROCESSING" + if err != nil { + uploadPayload.Status = string(models.AVStatusPROCESSING) } else { - uploadPayload.Status = tags["av-status"] + uploadPayload.Status = string(models.GetAVStatusFromTags(tags)) } + return uploadPayload } diff --git a/pkg/handlers/internalapi/moves.go b/pkg/handlers/internalapi/moves.go index 891c990e15e..f431da62850 100644 --- a/pkg/handlers/internalapi/moves.go +++ b/pkg/handlers/internalapi/moves.go @@ -588,10 +588,10 @@ func payloadForUploadModelFromAdditionalDocumentsUpload(storer storage.FileStore UpdatedAt: strfmt.DateTime(upload.UpdatedAt), } tags, err := storer.Tags(upload.StorageKey) - if err != nil || len(tags) == 0 { - uploadPayload.Status = "PROCESSING" + if err != nil { + uploadPayload.Status = string(models.AVStatusPROCESSING) } else { - uploadPayload.Status = tags["av-status"] + uploadPayload.Status = string(models.GetAVStatusFromTags(tags)) } return uploadPayload, nil } diff --git a/pkg/handlers/internalapi/orders.go b/pkg/handlers/internalapi/orders.go index 6e663bff4ea..673b0f55b57 100644 --- a/pkg/handlers/internalapi/orders.go +++ b/pkg/handlers/internalapi/orders.go @@ -35,10 +35,10 @@ func payloadForUploadModelFromAmendedOrdersUpload(storer storage.FileStorer, upl UpdatedAt: strfmt.DateTime(upload.UpdatedAt), } tags, err := storer.Tags(upload.StorageKey) - if err != nil || len(tags) == 0 { - uploadPayload.Status = "PROCESSING" + if err != nil { + uploadPayload.Status = string(models.AVStatusPROCESSING) } else { - uploadPayload.Status = tags["av-status"] + uploadPayload.Status = string(models.GetAVStatusFromTags(tags)) } return uploadPayload, nil } diff --git a/pkg/handlers/routing/base_routing_suite.go b/pkg/handlers/routing/base_routing_suite.go index 23e538792b7..77049e33664 100644 --- a/pkg/handlers/routing/base_routing_suite.go +++ b/pkg/handlers/routing/base_routing_suite.go @@ -85,6 +85,7 @@ func (suite *BaseRoutingSuite) RoutingConfig() *Config { handlerConfig := suite.BaseHandlerTestSuite.HandlerConfig() handlerConfig.SetAppNames(handlers.ApplicationTestServername()) handlerConfig.SetNotificationSender(suite.TestNotificationSender()) + handlerConfig.SetNotificationReceiver(suite.TestNotificationReceiver()) // Need this for any requests that will either retrieve or save files or their info. fakeS3 := storageTest.NewFakeS3Storage(true) diff --git a/pkg/handlers/routing/ghcapi_test/uploads_test.go b/pkg/handlers/routing/ghcapi_test/uploads_test.go new file mode 100644 index 00000000000..5eb27758d00 --- /dev/null +++ b/pkg/handlers/routing/ghcapi_test/uploads_test.go @@ -0,0 +1,85 @@ +package ghcapi_test + +import ( + "net/http" + "net/http/httptest" + + "github.com/transcom/mymove/pkg/factory" + "github.com/transcom/mymove/pkg/models" + "github.com/transcom/mymove/pkg/models/roles" + storageTest "github.com/transcom/mymove/pkg/storage/test" + "github.com/transcom/mymove/pkg/uploader" +) + +func (suite *GhcAPISuite) TestUploads() { + + suite.Run("Received status for upload, read tag without event queue", func() { + orders := factory.BuildOrder(suite.DB(), factory.GetTraitActiveServiceMemberUser(), nil) + uploadUser1 := factory.BuildUserUpload(suite.DB(), []factory.Customization{ + { + Model: orders.UploadedOrders, + LinkOnly: true, + }, + { + Model: models.Upload{ + Filename: "FileName", + Bytes: int64(15), + ContentType: uploader.FileTypePDF, + }, + }, + }, nil) + file := suite.Fixture("test.pdf") + _, err := suite.HandlerConfig().FileStorer().Store(uploadUser1.Upload.StorageKey, file.Data, "somehash", nil) + suite.NoError(err) + + officeUser := factory.BuildOfficeUserWithRoles(suite.DB(), factory.GetTraitActiveOfficeUser(), + []roles.RoleType{roles.RoleTypeTOO}) + req := suite.NewAuthenticatedOfficeRequest("GET", "/ghc/v1/uploads/"+uploadUser1.Upload.ID.String()+"/status", nil, officeUser) + rr := httptest.NewRecorder() + + suite.SetupSiteHandler().ServeHTTP(rr, req) + + suite.Equal(http.StatusOK, rr.Code) + suite.Equal("text/event-stream", rr.Header().Get("content-type")) + suite.Equal("id: 0\nevent: message\ndata: CLEAN\n\nid: 1\nevent: close\ndata: Connection closed\n\n", rr.Body.String()) + }) + + suite.Run("Received statuses for upload, receiving multiple statuses with event queue", func() { + orders := factory.BuildOrder(suite.DB(), factory.GetTraitActiveServiceMemberUser(), nil) + uploadUser1 := factory.BuildUserUpload(suite.DB(), []factory.Customization{ + { + Model: orders.UploadedOrders, + LinkOnly: true, + }, + { + Model: models.Upload{ + Filename: "FileName", + Bytes: int64(15), + ContentType: uploader.FileTypePDF, + }, + }, + }, nil) + file := suite.Fixture("test.pdf") + _, err := suite.HandlerConfig().FileStorer().Store(uploadUser1.Upload.StorageKey, file.Data, "somehash", nil) + suite.NoError(err) + + officeUser := factory.BuildOfficeUserWithRoles(suite.DB(), factory.GetTraitActiveOfficeUser(), + []roles.RoleType{roles.RoleTypeTOO}) + req := suite.NewAuthenticatedOfficeRequest("GET", "/ghc/v1/uploads/"+uploadUser1.Upload.ID.String()+"/status", nil, officeUser) + rr := httptest.NewRecorder() + + fakeS3, ok := suite.HandlerConfig().FileStorer().(*storageTest.FakeS3Storage) + suite.True(ok) + suite.NotNil(fakeS3, "FileStorer should be fakeS3") + + fakeS3.EmptyTags = true + suite.SetupSiteHandler().ServeHTTP(rr, req) + + suite.Equal(http.StatusOK, rr.Code) + suite.Equal("text/event-stream", rr.Header().Get("content-type")) + + suite.Contains(rr.Body.String(), "PROCESSING") + suite.Contains(rr.Body.String(), "CLEAN") + suite.Contains(rr.Body.String(), "Connection closed") + }) +} diff --git a/pkg/models/upload.go b/pkg/models/upload.go index d6afc2d0d4a..c03c4ec2bd2 100644 --- a/pkg/models/upload.go +++ b/pkg/models/upload.go @@ -13,6 +13,26 @@ import ( "github.com/transcom/mymove/pkg/db/utilities" ) +// Used tangentally in association with an Upload to provide status of anti-virus scan +// AVStatusType represents the type of the anti-virus status, whether it is still processing, clean or infected +type AVStatusType string + +const ( + // AVStatusPROCESSING string PROCESSING + AVStatusPROCESSING AVStatusType = "PROCESSING" + // AVStatusCLEAN string CLEAN + AVStatusCLEAN AVStatusType = "CLEAN" + // AVStatusINFECTED string INFECTED + AVStatusINFECTED AVStatusType = "INFECTED" +) + +func GetAVStatusFromTags(tags map[string]string) AVStatusType { + if status, exists := tags["av-status"]; exists { + return AVStatusType(status) + } + return AVStatusType(AVStatusPROCESSING) +} + // UploadType represents the type of upload this is, whether is it uploaded for a User or for the Prime type UploadType string diff --git a/pkg/notifications/notification_receiver.go b/pkg/notifications/notification_receiver.go new file mode 100644 index 00000000000..6dfab1b5d74 --- /dev/null +++ b/pkg/notifications/notification_receiver.go @@ -0,0 +1,334 @@ +package notifications + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/sns" + "github.com/aws/aws-sdk-go-v2/service/sqs" + "github.com/gofrs/uuid" + "go.uber.org/zap" + + "github.com/transcom/mymove/pkg/appcontext" + "github.com/transcom/mymove/pkg/cli" +) + +// NotificationQueueParams stores the params for queue creation +type NotificationQueueParams struct { + SubscriptionTopicName string + NamePrefix QueuePrefixType + FilterPolicy string +} + +// NotificationReceiver is an interface for receiving notifications +type NotificationReceiver interface { + CreateQueueWithSubscription(appCtx appcontext.AppContext, params NotificationQueueParams) (string, error) + ReceiveMessages(appCtx appcontext.AppContext, queueUrl string, timerContext context.Context) ([]ReceivedMessage, error) + CloseoutQueue(appCtx appcontext.AppContext, queueUrl string) error + GetDefaultTopic() (string, error) +} + +// NotificationReceiverConext provides context to a notification Receiver. Maps use queueUrl for key +type NotificationReceiverContext struct { + viper ViperType + snsService SnsClient + sqsService SqsClient + awsRegion string + awsAccountId string + queueSubscriptionMap map[string]string + receiverCancelMap map[string]context.CancelFunc +} + +// QueuePrefixType represents a prefix identifier given to a name of dynamic notification queues +type QueuePrefixType string + +const ( + QueuePrefixObjectTagsAdded QueuePrefixType = "ObjectTagsAdded" +) + +//go:generate mockery --name SnsClient --output ./receiverMocks +type SnsClient interface { + Subscribe(ctx context.Context, params *sns.SubscribeInput, optFns ...func(*sns.Options)) (*sns.SubscribeOutput, error) + Unsubscribe(ctx context.Context, params *sns.UnsubscribeInput, optFns ...func(*sns.Options)) (*sns.UnsubscribeOutput, error) + ListSubscriptionsByTopic(context.Context, *sns.ListSubscriptionsByTopicInput, ...func(*sns.Options)) (*sns.ListSubscriptionsByTopicOutput, error) +} + +//go:generate mockery --name SqsClient --output ./receiverMocks +type SqsClient interface { + CreateQueue(ctx context.Context, params *sqs.CreateQueueInput, optFns ...func(*sqs.Options)) (*sqs.CreateQueueOutput, error) + ReceiveMessage(ctx context.Context, params *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) + DeleteMessage(ctx context.Context, params *sqs.DeleteMessageInput, optFns ...func(*sqs.Options)) (*sqs.DeleteMessageOutput, error) + DeleteQueue(ctx context.Context, params *sqs.DeleteQueueInput, optFns ...func(*sqs.Options)) (*sqs.DeleteQueueOutput, error) + ListQueues(ctx context.Context, params *sqs.ListQueuesInput, optFns ...func(*sqs.Options)) (*sqs.ListQueuesOutput, error) +} + +//go:generate mockery --name ViperType --output ./receiverMocks +type ViperType interface { + GetString(string) string + SetEnvKeyReplacer(*strings.Replacer) +} + +// ReceivedMessage standardizes the format of the received message +type ReceivedMessage struct { + MessageId string + Body *string +} + +// NewNotificationReceiver returns a new NotificationReceiverContext +func NewNotificationReceiver(v ViperType, snsService SnsClient, sqsService SqsClient, awsRegion string, awsAccountId string) NotificationReceiverContext { + return NotificationReceiverContext{ + viper: v, + snsService: snsService, + sqsService: sqsService, + awsRegion: awsRegion, + awsAccountId: awsAccountId, + queueSubscriptionMap: make(map[string]string), + receiverCancelMap: make(map[string]context.CancelFunc), + } +} + +// CreateQueueWithSubscription first creates a new queue, then subscribes an AWS topic to it +func (n NotificationReceiverContext) CreateQueueWithSubscription(appCtx appcontext.AppContext, params NotificationQueueParams) (string, error) { + + queueUUID := uuid.Must(uuid.NewV4()) + + queueName := fmt.Sprintf("%s_%s", params.NamePrefix, queueUUID) + queueArn := n.constructArn("sqs", queueName) + topicArn := n.constructArn("sns", params.SubscriptionTopicName) + + accessPolicy := fmt.Sprintf(`{ + "Version": "2012-10-17", + "Statement": [{ + "Sid": "AllowSNSPublish", + "Effect": "Allow", + "Principal": { + "Service": "sns.amazonaws.com" + }, + "Action": ["sqs:SendMessage"], + "Resource": "%s", + "Condition": { + "ArnEquals": { + "aws:SourceArn": "%s" + } + } + }, { + "Sid": "DenyNonSSLAccess", + "Effect": "Deny", + "Principal": "*", + "Action": "sqs:*", + "Resource": "%s", + "Condition": { + "Bool": { + "aws:SecureTransport": "false" + } + } + }] + }`, queueArn, topicArn, queueArn) + + input := &sqs.CreateQueueInput{ + QueueName: &queueName, + Attributes: map[string]string{ + "MessageRetentionPeriod": "120", + "Policy": accessPolicy, + }, + } + + result, err := n.sqsService.CreateQueue(context.Background(), input) + if err != nil { + appCtx.Logger().Error("Failed to create SQS queue, %v", zap.Error(err)) + return "", err + } + + subscribeInput := &sns.SubscribeInput{ + TopicArn: &topicArn, + Protocol: aws.String("sqs"), + Endpoint: &queueArn, + Attributes: map[string]string{ + "FilterPolicy": params.FilterPolicy, + "FilterPolicyScope": "MessageBody", + }, + } + subscribeOutput, err := n.snsService.Subscribe(context.Background(), subscribeInput) + if err != nil { + appCtx.Logger().Error("Failed to create subscription, %v", zap.Error(err)) + return "", err + } + + n.queueSubscriptionMap[*result.QueueUrl] = *subscribeOutput.SubscriptionArn + + return *result.QueueUrl, nil +} + +// ReceiveMessages polls given queue continuously for messages for up to 20 seconds +func (n NotificationReceiverContext) ReceiveMessages(appCtx appcontext.AppContext, queueUrl string, timerContext context.Context) ([]ReceivedMessage, error) { + recCtx, cancelRecCtx := context.WithCancel(timerContext) + defer cancelRecCtx() + n.receiverCancelMap[queueUrl] = cancelRecCtx + + result, err := n.sqsService.ReceiveMessage(recCtx, &sqs.ReceiveMessageInput{ + QueueUrl: &queueUrl, + MaxNumberOfMessages: 1, + WaitTimeSeconds: 20, + }) + if errors.Is(recCtx.Err(), context.Canceled) || errors.Is(recCtx.Err(), context.DeadlineExceeded) { + return nil, recCtx.Err() + } + + if err != nil { + appCtx.Logger().Info("Couldn't get messages from queue. Error: %v\n", zap.Error(err)) + return nil, err + } + + receivedMessages := make([]ReceivedMessage, len(result.Messages)) + for index, value := range result.Messages { + receivedMessages[index] = ReceivedMessage{ + MessageId: *value.MessageId, + Body: value.Body, + } + + appCtx.Logger().Info("Message received.", zap.String("messageId", *value.MessageId)) + + _, err := n.sqsService.DeleteMessage(recCtx, &sqs.DeleteMessageInput{ + QueueUrl: &queueUrl, + ReceiptHandle: value.ReceiptHandle, + }) + if err != nil { + appCtx.Logger().Info("Couldn't delete message from queue. Error: %v\n", zap.Error(err)) + } + } + + return receivedMessages, recCtx.Err() +} + +// CloseoutQueue stops receiving messages and cleans up the queue and its subscriptions +func (n NotificationReceiverContext) CloseoutQueue(appCtx appcontext.AppContext, queueUrl string) error { + appCtx.Logger().Info("Closing out queue: ", zap.String("queueUrl", queueUrl)) + + if cancelFunc, exists := n.receiverCancelMap[queueUrl]; exists { + cancelFunc() + delete(n.receiverCancelMap, queueUrl) + } + + if subscriptionArn, exists := n.queueSubscriptionMap[queueUrl]; exists { + _, err := n.snsService.Unsubscribe(context.Background(), &sns.UnsubscribeInput{ + SubscriptionArn: &subscriptionArn, + }) + if err != nil { + return err + } + delete(n.queueSubscriptionMap, queueUrl) + } + + _, err := n.sqsService.DeleteQueue(context.Background(), &sqs.DeleteQueueInput{ + QueueUrl: &queueUrl, + }) + + return err +} + +// GetDefaultTopic returns the topic value set within the environment +func (n NotificationReceiverContext) GetDefaultTopic() (string, error) { + topicName := n.viper.GetString(cli.SNSTagsUpdatedTopicFlag) + receiverBackend := n.viper.GetString(cli.ReceiverBackendFlag) + if topicName == "" && receiverBackend == "sns_sqs" { + return "", errors.New("sns_tags_updated_topic key not available") + } + return topicName, nil +} + +// InitReceiver initializes the receiver backend, only call this once +func InitReceiver(v ViperType, logger *zap.Logger, wipeAllNotificationQueues bool) (NotificationReceiver, error) { + + if v.GetString(cli.ReceiverBackendFlag) == "sns_sqs" { + // Setup notification receiver service with SNS & SQS backend dependencies + awsSNSRegion := v.GetString(cli.SNSRegionFlag) + awsAccountId := v.GetString(cli.SNSAccountId) + + logger.Info("Using aws sns_sqs receiver backend", zap.String("region", awsSNSRegion)) + + cfg, err := config.LoadDefaultConfig(context.Background(), + config.WithRegion(awsSNSRegion), + ) + if err != nil { + logger.Fatal("error loading sns aws config", zap.Error(err)) + return nil, err + } + + snsService := sns.NewFromConfig(cfg) + sqsService := sqs.NewFromConfig(cfg) + + notificationReceiver := NewNotificationReceiver(v, snsService, sqsService, awsSNSRegion, awsAccountId) + + // Remove any remaining previous notification queues on server start + if wipeAllNotificationQueues { + err = notificationReceiver.wipeAllNotificationQueues(logger) + if err != nil { + return nil, err + } + } + + return notificationReceiver, nil + } + + logger.Info("Using local notification receiver backend", zap.String("receiver_backend", v.GetString(cli.ReceiverBackendFlag))) + + return NewStubNotificationReceiver(), nil +} + +func (n NotificationReceiverContext) constructArn(awsService string, endpointName string) string { + return fmt.Sprintf("arn:aws-us-gov:%s:%s:%s:%s", awsService, n.awsRegion, n.awsAccountId, endpointName) +} + +// Removes ALL previously created notification queues +func (n *NotificationReceiverContext) wipeAllNotificationQueues(logger *zap.Logger) error { + defaultTopic, err := n.GetDefaultTopic() + if err != nil { + return err + } + + logger.Info("Receiver cleanup - Removing previous subscriptions...") + paginator := sns.NewListSubscriptionsByTopicPaginator(n.snsService, &sns.ListSubscriptionsByTopicInput{ + TopicArn: aws.String(n.constructArn("sns", defaultTopic)), + }) + + for paginator.HasMorePages() { + output, err := paginator.NextPage(context.Background()) + if err != nil { + return err + } + for _, subscription := range output.Subscriptions { + if strings.Contains(*subscription.Endpoint, string(QueuePrefixObjectTagsAdded)) { + logger.Info("Subscription ARN: ", zap.String("subscription arn", *subscription.SubscriptionArn)) + logger.Info("Endpoint ARN: ", zap.String("endpoint arn", *subscription.Endpoint)) + _, err = n.snsService.Unsubscribe(context.Background(), &sns.UnsubscribeInput{ + SubscriptionArn: subscription.SubscriptionArn, + }) + if err != nil { + return err + } + } + } + } + + logger.Info("Receiver cleanup - Removing previous queues...") + result, err := n.sqsService.ListQueues(context.Background(), &sqs.ListQueuesInput{ + QueueNamePrefix: aws.String(string(QueuePrefixObjectTagsAdded)), + }) + if err != nil { + return err + } + + for _, url := range result.QueueUrls { + _, err = n.sqsService.DeleteQueue(context.Background(), &sqs.DeleteQueueInput{ + QueueUrl: &url, + }) + if err != nil { + return err + } + } + return nil +} diff --git a/pkg/notifications/notification_receiver_stub.go b/pkg/notifications/notification_receiver_stub.go new file mode 100644 index 00000000000..e98f0c8aa1e --- /dev/null +++ b/pkg/notifications/notification_receiver_stub.go @@ -0,0 +1,51 @@ +package notifications + +import ( + "context" + "time" + + "go.uber.org/zap" + + "github.com/transcom/mymove/pkg/appcontext" +) + +// StubNotificationReceiver mocks an SNS & SQS client for local usage +type StubNotificationReceiver NotificationReceiverContext + +// NewStubNotificationReceiver returns a new StubNotificationReceiver +func NewStubNotificationReceiver() StubNotificationReceiver { + return StubNotificationReceiver{ + snsService: nil, + sqsService: nil, + awsRegion: "", + awsAccountId: "", + queueSubscriptionMap: make(map[string]string), + receiverCancelMap: make(map[string]context.CancelFunc), + } +} + +func (n StubNotificationReceiver) CreateQueueWithSubscription(appCtx appcontext.AppContext, params NotificationQueueParams) (string, error) { + return "stubQueueName", nil +} + +func (n StubNotificationReceiver) ReceiveMessages(appCtx appcontext.AppContext, queueUrl string, timerContext context.Context) ([]ReceivedMessage, error) { + time.Sleep(3 * time.Second) + messageId := "stubMessageId" + body := queueUrl + ":stubMessageBody" + mockMessages := make([]ReceivedMessage, 1) + mockMessages[0] = ReceivedMessage{ + MessageId: messageId, + Body: &body, + } + appCtx.Logger().Debug("Receiving a stubbed message for queue: %v", zap.String("queueUrl", queueUrl)) + return mockMessages, nil +} + +func (n StubNotificationReceiver) CloseoutQueue(appCtx appcontext.AppContext, queueUrl string) error { + appCtx.Logger().Debug("Closing out the stubbed queue.") + return nil +} + +func (n StubNotificationReceiver) GetDefaultTopic() (string, error) { + return "stubDefaultTopic", nil +} diff --git a/pkg/notifications/notification_receiver_test.go b/pkg/notifications/notification_receiver_test.go new file mode 100644 index 00000000000..f7dab5a91b7 --- /dev/null +++ b/pkg/notifications/notification_receiver_test.go @@ -0,0 +1,146 @@ +package notifications + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sns" + "github.com/aws/aws-sdk-go-v2/service/sqs" + "github.com/aws/aws-sdk-go-v2/service/sqs/types" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + + "github.com/transcom/mymove/pkg/cli" + mocks "github.com/transcom/mymove/pkg/notifications/receiverMocks" + "github.com/transcom/mymove/pkg/testingsuite" +) + +type notificationReceiverSuite struct { + *testingsuite.PopTestSuite +} + +func TestNotificationReceiverSuite(t *testing.T) { + + hs := ¬ificationReceiverSuite{ + PopTestSuite: testingsuite.NewPopTestSuite(testingsuite.CurrentPackage(), + testingsuite.WithPerTestTransaction()), + } + suite.Run(t, hs) + hs.PopTestSuite.TearDown() +} + +func (suite *notificationReceiverSuite) TestSuccessPath() { + + suite.Run("local backend - notification receiver stub", func() { + // Setup mocks + mockedViper := mocks.ViperType{} + mockedViper.On("GetString", cli.ReceiverBackendFlag).Return("local") + mockedViper.On("GetString", cli.SNSRegionFlag).Return("us-gov-west-1") + mockedViper.On("GetString", cli.SNSAccountId).Return("12345") + mockedViper.On("GetString", cli.SNSTagsUpdatedTopicFlag).Return("fake_sns_topic") + localReceiver, err := InitReceiver(&mockedViper, suite.Logger(), true) + + suite.NoError(err) + suite.IsType(StubNotificationReceiver{}, localReceiver) + + defaultTopic, err := localReceiver.GetDefaultTopic() + suite.Equal("stubDefaultTopic", defaultTopic) + suite.NoError(err) + + queueParams := NotificationQueueParams{ + NamePrefix: "testPrefix", + } + createdQueueUrl, err := localReceiver.CreateQueueWithSubscription(suite.AppContextForTest(), queueParams) + suite.NoError(err) + suite.NotContains(createdQueueUrl, queueParams.NamePrefix) + suite.Equal(createdQueueUrl, "stubQueueName") + + timerContext, cancelTimerContext := context.WithTimeout(context.Background(), 2*time.Second) + defer cancelTimerContext() + + receivedMessages, err := localReceiver.ReceiveMessages(suite.AppContextForTest(), createdQueueUrl, timerContext) + suite.NoError(err) + suite.Len(receivedMessages, 1) + suite.Equal(receivedMessages[0].MessageId, "stubMessageId") + suite.Equal(*receivedMessages[0].Body, fmt.Sprintf("%s:stubMessageBody", createdQueueUrl)) + }) + + suite.Run("aws backend - notification receiver InitReceiver", func() { + // Setup mocks + mockedViper := mocks.ViperType{} + mockedViper.On("GetString", cli.ReceiverBackendFlag).Return("sns_sqs") + mockedViper.On("GetString", cli.SNSRegionFlag).Return("us-gov-west-1") + mockedViper.On("GetString", cli.SNSAccountId).Return("12345") + mockedViper.On("GetString", cli.SNSTagsUpdatedTopicFlag).Return("fake_sns_topic") + + receiver, err := InitReceiver(&mockedViper, suite.Logger(), false) + + suite.NoError(err) + suite.IsType(NotificationReceiverContext{}, receiver) + defaultTopic, err := receiver.GetDefaultTopic() + suite.Equal("fake_sns_topic", defaultTopic) + suite.NoError(err) + }) + + suite.Run("aws backend - notification receiver with mock services", func() { + // Setup mocks + mockedViper := mocks.ViperType{} + mockedViper.On("GetString", cli.ReceiverBackendFlag).Return("sns_sqs") + mockedViper.On("GetString", cli.SNSRegionFlag).Return("us-gov-west-1") + mockedViper.On("GetString", cli.SNSAccountId).Return("12345") + mockedViper.On("GetString", cli.SNSTagsUpdatedTopicFlag).Return("fake_sns_topic") + + mockedSns := mocks.SnsClient{} + mockedSns.On("Subscribe", mock.Anything, mock.AnythingOfType("*sns.SubscribeInput")).Return(&sns.SubscribeOutput{ + SubscriptionArn: aws.String("FakeSubscriptionArn"), + }, nil) + mockedSns.On("Unsubscribe", mock.Anything, mock.AnythingOfType("*sns.UnsubscribeInput")).Return(&sns.UnsubscribeOutput{}, nil) + mockedSns.On("ListSubscriptionsByTopic", mock.Anything, mock.AnythingOfType("*sns.ListSubscriptionsByTopicInput")).Return(&sns.ListSubscriptionsByTopicOutput{}, nil) + + mockedSqs := mocks.SqsClient{} + mockedSqs.On("CreateQueue", mock.Anything, mock.AnythingOfType("*sqs.CreateQueueInput")).Return(&sqs.CreateQueueOutput{ + QueueUrl: aws.String("fakeQueueUrl"), + }, nil) + mockedSqs.On("ReceiveMessage", mock.Anything, mock.AnythingOfType("*sqs.ReceiveMessageInput")).Return(&sqs.ReceiveMessageOutput{ + Messages: []types.Message{ + { + MessageId: aws.String("fakeMessageId"), + Body: aws.String("fakeQueueUrl:fakeMessageBody"), + }, + }, + }, nil) + mockedSqs.On("DeleteMessage", mock.Anything, mock.AnythingOfType("*sqs.DeleteMessageInput")).Return(&sqs.DeleteMessageOutput{}, nil) + mockedSqs.On("DeleteQueue", mock.Anything, mock.AnythingOfType("*sqs.DeleteQueueInput")).Return(&sqs.DeleteQueueOutput{}, nil) + mockedSqs.On("ListQueues", mock.Anything, mock.AnythingOfType("*sqs.ListQueuesInput")).Return(&sqs.ListQueuesOutput{}, nil) + + // Run test + receiver := NewNotificationReceiver(&mockedViper, &mockedSns, &mockedSqs, "", "") + suite.IsType(NotificationReceiverContext{}, receiver) + + defaultTopic, err := receiver.GetDefaultTopic() + suite.Equal("fake_sns_topic", defaultTopic) + suite.NoError(err) + + queueParams := NotificationQueueParams{ + NamePrefix: "testPrefix", + } + createdQueueUrl, err := receiver.CreateQueueWithSubscription(suite.AppContextForTest(), queueParams) + suite.NoError(err) + suite.Equal("fakeQueueUrl", createdQueueUrl) + + timerContext, cancelTimerContext := context.WithTimeout(context.Background(), 2*time.Second) + defer cancelTimerContext() + + receivedMessages, err := receiver.ReceiveMessages(suite.AppContextForTest(), createdQueueUrl, timerContext) + suite.NoError(err) + suite.Len(receivedMessages, 1) + suite.Equal(receivedMessages[0].MessageId, "fakeMessageId") + suite.Equal(*receivedMessages[0].Body, fmt.Sprintf("%s:fakeMessageBody", createdQueueUrl)) + + err = receiver.CloseoutQueue(suite.AppContextForTest(), createdQueueUrl) + suite.NoError(err) + }) +} diff --git a/pkg/notifications/notification_stub.go b/pkg/notifications/notification_sender_stub.go similarity index 100% rename from pkg/notifications/notification_stub.go rename to pkg/notifications/notification_sender_stub.go diff --git a/pkg/notifications/notification_test.go b/pkg/notifications/notification_sender_test.go similarity index 100% rename from pkg/notifications/notification_test.go rename to pkg/notifications/notification_sender_test.go diff --git a/pkg/notifications/receiverMocks/SnsClient.go b/pkg/notifications/receiverMocks/SnsClient.go new file mode 100644 index 00000000000..0c562896a0d --- /dev/null +++ b/pkg/notifications/receiverMocks/SnsClient.go @@ -0,0 +1,141 @@ +// Code generated by mockery. DO NOT EDIT. + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + + sns "github.com/aws/aws-sdk-go-v2/service/sns" +) + +// SnsClient is an autogenerated mock type for the SnsClient type +type SnsClient struct { + mock.Mock +} + +// ListSubscriptionsByTopic provides a mock function with given fields: _a0, _a1, _a2 +func (_m *SnsClient) ListSubscriptionsByTopic(_a0 context.Context, _a1 *sns.ListSubscriptionsByTopicInput, _a2 ...func(*sns.Options)) (*sns.ListSubscriptionsByTopicOutput, error) { + _va := make([]interface{}, len(_a2)) + for _i := range _a2 { + _va[_i] = _a2[_i] + } + var _ca []interface{} + _ca = append(_ca, _a0, _a1) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for ListSubscriptionsByTopic") + } + + var r0 *sns.ListSubscriptionsByTopicOutput + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *sns.ListSubscriptionsByTopicInput, ...func(*sns.Options)) (*sns.ListSubscriptionsByTopicOutput, error)); ok { + return rf(_a0, _a1, _a2...) + } + if rf, ok := ret.Get(0).(func(context.Context, *sns.ListSubscriptionsByTopicInput, ...func(*sns.Options)) *sns.ListSubscriptionsByTopicOutput); ok { + r0 = rf(_a0, _a1, _a2...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*sns.ListSubscriptionsByTopicOutput) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *sns.ListSubscriptionsByTopicInput, ...func(*sns.Options)) error); ok { + r1 = rf(_a0, _a1, _a2...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Subscribe provides a mock function with given fields: ctx, params, optFns +func (_m *SnsClient) Subscribe(ctx context.Context, params *sns.SubscribeInput, optFns ...func(*sns.Options)) (*sns.SubscribeOutput, error) { + _va := make([]interface{}, len(optFns)) + for _i := range optFns { + _va[_i] = optFns[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, params) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for Subscribe") + } + + var r0 *sns.SubscribeOutput + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *sns.SubscribeInput, ...func(*sns.Options)) (*sns.SubscribeOutput, error)); ok { + return rf(ctx, params, optFns...) + } + if rf, ok := ret.Get(0).(func(context.Context, *sns.SubscribeInput, ...func(*sns.Options)) *sns.SubscribeOutput); ok { + r0 = rf(ctx, params, optFns...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*sns.SubscribeOutput) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *sns.SubscribeInput, ...func(*sns.Options)) error); ok { + r1 = rf(ctx, params, optFns...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Unsubscribe provides a mock function with given fields: ctx, params, optFns +func (_m *SnsClient) Unsubscribe(ctx context.Context, params *sns.UnsubscribeInput, optFns ...func(*sns.Options)) (*sns.UnsubscribeOutput, error) { + _va := make([]interface{}, len(optFns)) + for _i := range optFns { + _va[_i] = optFns[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, params) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for Unsubscribe") + } + + var r0 *sns.UnsubscribeOutput + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *sns.UnsubscribeInput, ...func(*sns.Options)) (*sns.UnsubscribeOutput, error)); ok { + return rf(ctx, params, optFns...) + } + if rf, ok := ret.Get(0).(func(context.Context, *sns.UnsubscribeInput, ...func(*sns.Options)) *sns.UnsubscribeOutput); ok { + r0 = rf(ctx, params, optFns...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*sns.UnsubscribeOutput) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *sns.UnsubscribeInput, ...func(*sns.Options)) error); ok { + r1 = rf(ctx, params, optFns...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewSnsClient creates a new instance of SnsClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewSnsClient(t interface { + mock.TestingT + Cleanup(func()) +}) *SnsClient { + mock := &SnsClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/notifications/receiverMocks/SqsClient.go b/pkg/notifications/receiverMocks/SqsClient.go new file mode 100644 index 00000000000..c8e6e6aa284 --- /dev/null +++ b/pkg/notifications/receiverMocks/SqsClient.go @@ -0,0 +1,215 @@ +// Code generated by mockery. DO NOT EDIT. + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + + sqs "github.com/aws/aws-sdk-go-v2/service/sqs" +) + +// SqsClient is an autogenerated mock type for the SqsClient type +type SqsClient struct { + mock.Mock +} + +// CreateQueue provides a mock function with given fields: ctx, params, optFns +func (_m *SqsClient) CreateQueue(ctx context.Context, params *sqs.CreateQueueInput, optFns ...func(*sqs.Options)) (*sqs.CreateQueueOutput, error) { + _va := make([]interface{}, len(optFns)) + for _i := range optFns { + _va[_i] = optFns[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, params) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for CreateQueue") + } + + var r0 *sqs.CreateQueueOutput + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *sqs.CreateQueueInput, ...func(*sqs.Options)) (*sqs.CreateQueueOutput, error)); ok { + return rf(ctx, params, optFns...) + } + if rf, ok := ret.Get(0).(func(context.Context, *sqs.CreateQueueInput, ...func(*sqs.Options)) *sqs.CreateQueueOutput); ok { + r0 = rf(ctx, params, optFns...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*sqs.CreateQueueOutput) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *sqs.CreateQueueInput, ...func(*sqs.Options)) error); ok { + r1 = rf(ctx, params, optFns...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// DeleteMessage provides a mock function with given fields: ctx, params, optFns +func (_m *SqsClient) DeleteMessage(ctx context.Context, params *sqs.DeleteMessageInput, optFns ...func(*sqs.Options)) (*sqs.DeleteMessageOutput, error) { + _va := make([]interface{}, len(optFns)) + for _i := range optFns { + _va[_i] = optFns[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, params) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for DeleteMessage") + } + + var r0 *sqs.DeleteMessageOutput + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *sqs.DeleteMessageInput, ...func(*sqs.Options)) (*sqs.DeleteMessageOutput, error)); ok { + return rf(ctx, params, optFns...) + } + if rf, ok := ret.Get(0).(func(context.Context, *sqs.DeleteMessageInput, ...func(*sqs.Options)) *sqs.DeleteMessageOutput); ok { + r0 = rf(ctx, params, optFns...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*sqs.DeleteMessageOutput) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *sqs.DeleteMessageInput, ...func(*sqs.Options)) error); ok { + r1 = rf(ctx, params, optFns...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// DeleteQueue provides a mock function with given fields: ctx, params, optFns +func (_m *SqsClient) DeleteQueue(ctx context.Context, params *sqs.DeleteQueueInput, optFns ...func(*sqs.Options)) (*sqs.DeleteQueueOutput, error) { + _va := make([]interface{}, len(optFns)) + for _i := range optFns { + _va[_i] = optFns[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, params) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for DeleteQueue") + } + + var r0 *sqs.DeleteQueueOutput + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *sqs.DeleteQueueInput, ...func(*sqs.Options)) (*sqs.DeleteQueueOutput, error)); ok { + return rf(ctx, params, optFns...) + } + if rf, ok := ret.Get(0).(func(context.Context, *sqs.DeleteQueueInput, ...func(*sqs.Options)) *sqs.DeleteQueueOutput); ok { + r0 = rf(ctx, params, optFns...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*sqs.DeleteQueueOutput) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *sqs.DeleteQueueInput, ...func(*sqs.Options)) error); ok { + r1 = rf(ctx, params, optFns...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ListQueues provides a mock function with given fields: ctx, params, optFns +func (_m *SqsClient) ListQueues(ctx context.Context, params *sqs.ListQueuesInput, optFns ...func(*sqs.Options)) (*sqs.ListQueuesOutput, error) { + _va := make([]interface{}, len(optFns)) + for _i := range optFns { + _va[_i] = optFns[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, params) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for ListQueues") + } + + var r0 *sqs.ListQueuesOutput + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *sqs.ListQueuesInput, ...func(*sqs.Options)) (*sqs.ListQueuesOutput, error)); ok { + return rf(ctx, params, optFns...) + } + if rf, ok := ret.Get(0).(func(context.Context, *sqs.ListQueuesInput, ...func(*sqs.Options)) *sqs.ListQueuesOutput); ok { + r0 = rf(ctx, params, optFns...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*sqs.ListQueuesOutput) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *sqs.ListQueuesInput, ...func(*sqs.Options)) error); ok { + r1 = rf(ctx, params, optFns...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ReceiveMessage provides a mock function with given fields: ctx, params, optFns +func (_m *SqsClient) ReceiveMessage(ctx context.Context, params *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { + _va := make([]interface{}, len(optFns)) + for _i := range optFns { + _va[_i] = optFns[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, params) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for ReceiveMessage") + } + + var r0 *sqs.ReceiveMessageOutput + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *sqs.ReceiveMessageInput, ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error)); ok { + return rf(ctx, params, optFns...) + } + if rf, ok := ret.Get(0).(func(context.Context, *sqs.ReceiveMessageInput, ...func(*sqs.Options)) *sqs.ReceiveMessageOutput); ok { + r0 = rf(ctx, params, optFns...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*sqs.ReceiveMessageOutput) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *sqs.ReceiveMessageInput, ...func(*sqs.Options)) error); ok { + r1 = rf(ctx, params, optFns...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewSqsClient creates a new instance of SqsClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewSqsClient(t interface { + mock.TestingT + Cleanup(func()) +}) *SqsClient { + mock := &SqsClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/notifications/receiverMocks/ViperType.go b/pkg/notifications/receiverMocks/ViperType.go new file mode 100644 index 00000000000..bf5e6f84090 --- /dev/null +++ b/pkg/notifications/receiverMocks/ViperType.go @@ -0,0 +1,51 @@ +// Code generated by mockery. DO NOT EDIT. + +package mocks + +import ( + mock "github.com/stretchr/testify/mock" + + strings "strings" +) + +// ViperType is an autogenerated mock type for the ViperType type +type ViperType struct { + mock.Mock +} + +// GetString provides a mock function with given fields: _a0 +func (_m *ViperType) GetString(_a0 string) string { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for GetString") + } + + var r0 string + if rf, ok := ret.Get(0).(func(string) string); ok { + r0 = rf(_a0) + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// SetEnvKeyReplacer provides a mock function with given fields: _a0 +func (_m *ViperType) SetEnvKeyReplacer(_a0 *strings.Replacer) { + _m.Called(_a0) +} + +// NewViperType creates a new instance of ViperType. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewViperType(t interface { + mock.TestingT + Cleanup(func()) +}) *ViperType { + mock := &ViperType{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/storage/filesystem.go b/pkg/storage/filesystem.go index 259fd4ee8ab..f6e43583420 100644 --- a/pkg/storage/filesystem.go +++ b/pkg/storage/filesystem.go @@ -116,6 +116,8 @@ func (fs *Filesystem) Fetch(key string) (io.ReadCloser, error) { // Tags returns the tags for a specified key func (fs *Filesystem) Tags(_ string) (map[string]string, error) { tags := make(map[string]string) + // Assume anti-virus complete + tags["av-status"] = "CLEAN" return tags, nil } diff --git a/pkg/storage/filesystem_test.go b/pkg/storage/filesystem_test.go index 27ecc5e951c..9c37b9204c8 100644 --- a/pkg/storage/filesystem_test.go +++ b/pkg/storage/filesystem_test.go @@ -1,6 +1,8 @@ package storage import ( + "io" + "strings" "testing" ) @@ -21,3 +23,62 @@ func TestFilesystemPresignedURL(t *testing.T) { t.Errorf("wrong presigned url: expected %s, got %s", expected, url) } } + +func TestFilesystemReturnsSuccessful(t *testing.T) { + fsParams := FilesystemParams{ + root: "./", + webRoot: "https://example.text/files", + } + filesystem := NewFilesystem(fsParams) + if filesystem == nil { + t.Fatal("could not create new filesystem") + } + + storeValue := strings.NewReader("anyValue") + _, err := filesystem.Store("anyKey", storeValue, "", nil) + if err != nil { + t.Fatalf("could not store in filesystem: %s", err) + } + + retReader, err := filesystem.Fetch("anyKey") + if err != nil { + t.Fatalf("could not fetch from filesystem: %s", err) + } + + err = filesystem.Delete("anyKey") + if err != nil { + t.Fatalf("could not delete on filesystem: %s", err) + } + + retValue, err := io.ReadAll(retReader) + if strings.Compare(string(retValue[:]), "anyValue") != 0 { + t.Fatalf("could not fetch from filesystem: %s", err) + } + + fileSystem := filesystem.FileSystem() + if fileSystem == nil { + t.Fatal("could not retrieve filesystem from filesystem") + } + + tempFileSystem := filesystem.TempFileSystem() + if tempFileSystem == nil { + t.Fatal("could not retrieve filesystem from filesystem") + } +} + +func TestFilesystemTags(t *testing.T) { + fsParams := FilesystemParams{ + root: "/home/username", + webRoot: "https://example.text/files", + } + fs := NewFilesystem(fsParams) + + tags, err := fs.Tags("anyKey") + if err != nil { + t.Fatalf("could not get tags: %s", err) + } + + if tag, exists := tags["av-status"]; exists && strings.Compare(tag, "CLEAN") != 0 { + t.Fatal("tag 'av-status' should return CLEAN") + } +} diff --git a/pkg/storage/memory.go b/pkg/storage/memory.go index 2f06ed6b96e..4e171e40e9d 100644 --- a/pkg/storage/memory.go +++ b/pkg/storage/memory.go @@ -116,6 +116,8 @@ func (fs *Memory) Fetch(key string) (io.ReadCloser, error) { // Tags returns the tags for a specified key func (fs *Memory) Tags(_ string) (map[string]string, error) { tags := make(map[string]string) + // Assume anti-virus complete + tags["av-status"] = "CLEAN" return tags, nil } diff --git a/pkg/storage/memory_test.go b/pkg/storage/memory_test.go index 59384c5acee..bdf3133e9c8 100644 --- a/pkg/storage/memory_test.go +++ b/pkg/storage/memory_test.go @@ -1,6 +1,8 @@ package storage import ( + "io" + "strings" "testing" ) @@ -21,3 +23,62 @@ func TestMemoryPresignedURL(t *testing.T) { t.Errorf("wrong presigned url: expected %s, got %s", expected, url) } } + +func TestMemoryReturnsSuccessful(t *testing.T) { + fsParams := MemoryParams{ + root: "/home/username", + webRoot: "https://example.text/files", + } + memory := NewMemory(fsParams) + if memory == nil { + t.Fatal("could not create new memory") + } + + storeValue := strings.NewReader("anyValue") + _, err := memory.Store("anyKey", storeValue, "", nil) + if err != nil { + t.Fatalf("could not store in memory: %s", err) + } + + retReader, err := memory.Fetch("anyKey") + if err != nil { + t.Fatalf("could not fetch from memory: %s", err) + } + + err = memory.Delete("anyKey") + if err != nil { + t.Fatalf("could not delete on memory: %s", err) + } + + retValue, err := io.ReadAll(retReader) + if strings.Compare(string(retValue[:]), "anyValue") != 0 { + t.Fatalf("could not fetch from memory: %s", err) + } + + fileSystem := memory.FileSystem() + if fileSystem == nil { + t.Fatal("could not retrieve filesystem from memory") + } + + tempFileSystem := memory.TempFileSystem() + if tempFileSystem == nil { + t.Fatal("could not retrieve filesystem from memory") + } +} + +func TestMemoryTags(t *testing.T) { + fsParams := MemoryParams{ + root: "/home/username", + webRoot: "https://example.text/files", + } + fs := NewMemory(fsParams) + + tags, err := fs.Tags("anyKey") + if err != nil { + t.Fatalf("could not get tags: %s", err) + } + + if tag, exists := tags["av-status"]; exists && strings.Compare(tag, "CLEAN") != 0 { + t.Fatal("tag 'av-status' should return CLEAN") + } +} diff --git a/pkg/storage/test/s3.go b/pkg/storage/test/s3.go index 97d06e7733d..56fbac83564 100644 --- a/pkg/storage/test/s3.go +++ b/pkg/storage/test/s3.go @@ -18,6 +18,7 @@ type FakeS3Storage struct { willSucceed bool fs *afero.Afero tempFs *afero.Afero + EmptyTags bool // Used for testing only } // Delete removes a file. @@ -95,7 +96,11 @@ func (fake *FakeS3Storage) TempFileSystem() *afero.Afero { // Tags returns the tags for a specified key func (fake *FakeS3Storage) Tags(_ string) (map[string]string, error) { tags := map[string]string{ - "tagName": "tagValue", + "av-status": "CLEAN", // Assume anti-virus run + } + if fake.EmptyTags { + tags = map[string]string{} + fake.EmptyTags = false // Reset after initial return, so future calls (tests) have filled tags } return tags, nil } diff --git a/pkg/storage/test/s3_test.go b/pkg/storage/test/s3_test.go new file mode 100644 index 00000000000..3c2f63bbeff --- /dev/null +++ b/pkg/storage/test/s3_test.go @@ -0,0 +1,101 @@ +package test + +import ( + "errors" + "io" + "strings" + "testing" +) + +// Tests all functions of FakeS3Storage +func TestFakeS3ReturnsSuccessful(t *testing.T) { + fakeS3 := NewFakeS3Storage(true) + if fakeS3 == nil { + t.Fatal("could not create new fakeS3") + } + + storeValue := strings.NewReader("anyValue") + _, err := fakeS3.Store("anyKey", storeValue, "", nil) + if err != nil { + t.Fatalf("could not store in fakeS3: %s", err) + } + + retReader, err := fakeS3.Fetch("anyKey") + if err != nil { + t.Fatalf("could not fetch from fakeS3: %s", err) + } + + err = fakeS3.Delete("anyKey") + if err != nil { + t.Fatalf("could not delete on fakeS3: %s", err) + } + + retValue, err := io.ReadAll(retReader) + if strings.Compare(string(retValue[:]), "anyValue") != 0 { + t.Fatalf("could not fetch from fakeS3: %s", err) + } + + fileSystem := fakeS3.FileSystem() + if fileSystem == nil { + t.Fatal("could not retrieve filesystem from fakeS3") + } + + tempFileSystem := fakeS3.TempFileSystem() + if tempFileSystem == nil { + t.Fatal("could not retrieve filesystem from fakeS3") + } + + tags, err := fakeS3.Tags("anyKey") + if err != nil { + t.Fatalf("could not fetch from fakeS3: %s", err) + } + if len(tags) != 1 { + t.Fatal("return tags must have av-status key assigned for fakeS3") + } + + presignedUrl, err := fakeS3.PresignedURL("anyKey", "anyContentType", "anyFileName") + if err != nil { + t.Fatal("could not retrieve presignedUrl from fakeS3") + } + + if strings.Compare(presignedUrl, "https://example.com/dir/anyKey?response-content-disposition=attachment%3B+filename%3D%22anyFileName%22&response-content-type=anyContentType&signed=test") != 0 { + t.Fatalf("could not retrieve proper presignedUrl from fakeS3 %s", presignedUrl) + } +} + +// Test for willSucceed false +func TestFakeS3WillNotSucceed(t *testing.T) { + fakeS3 := NewFakeS3Storage(false) + if fakeS3 == nil { + t.Fatalf("could not create new fakeS3") + } + + storeValue := strings.NewReader("anyValue") + _, err := fakeS3.Store("anyKey", storeValue, "", nil) + if err == nil || errors.Is(err, errors.New("failed to push")) { + t.Fatalf("should not be able to store when willSucceed false: %s", err) + } + + _, err = fakeS3.Fetch("anyKey") + if err == nil || errors.Is(err, errors.New("failed to fetch file")) { + t.Fatalf("should not find file on Fetch for willSucceed false: %s", err) + } +} + +// Tests empty tag returns empty tags on FakeS3Storage +func TestFakeS3ReturnsEmptyTags(t *testing.T) { + fakeS3 := NewFakeS3Storage(true) + if fakeS3 == nil { + t.Fatal("could not create new fakeS3") + } + + fakeS3.EmptyTags = true + + tags, err := fakeS3.Tags("anyKey") + if err != nil { + t.Fatalf("could not fetch from fakeS3: %s", err) + } + if len(tags) != 0 { + t.Fatal("return tags must be empty for FakeS3 when EmptyTags set to true") + } +} diff --git a/swagger-def/ghc.yaml b/swagger-def/ghc.yaml index df06a4ca220..e429a430bcd 100644 --- a/swagger-def/ghc.yaml +++ b/swagger-def/ghc.yaml @@ -4284,6 +4284,42 @@ paths: description: payload is too large '500': description: server error + /uploads/{uploadID}/status: + get: + summary: Returns status of an upload + description: Returns status of an upload based on antivirus run + operationId: getUploadStatus + produces: + - text/event-stream + tags: + - uploads + parameters: + - in: path + name: uploadID + type: string + format: uuid + required: true + description: UUID of the upload to return status of + responses: + '200': + description: the requested upload status + schema: + type: string + enum: + - INFECTED + - CLEAN + - PROCESSING + readOnly: true + '400': + description: invalid request + schema: + $ref: '#/definitions/InvalidRequestResponsePayload' + '403': + description: not authorized + '404': + description: not found + '500': + description: server error /application_parameters/{parameterName}: get: summary: Searches for an application parameter by name, returns nil if not found diff --git a/swagger/ghc.yaml b/swagger/ghc.yaml index 81bca009269..78ba66adc8d 100644 --- a/swagger/ghc.yaml +++ b/swagger/ghc.yaml @@ -4501,6 +4501,42 @@ paths: description: payload is too large '500': description: server error + /uploads/{uploadID}/status: + get: + summary: Returns status of an upload + description: Returns status of an upload based on antivirus run + operationId: getUploadStatus + produces: + - text/event-stream + tags: + - uploads + parameters: + - in: path + name: uploadID + type: string + format: uuid + required: true + description: UUID of the upload to return status of + responses: + '200': + description: the requested upload status + schema: + type: string + enum: + - INFECTED + - CLEAN + - PROCESSING + readOnly: true + '400': + description: invalid request + schema: + $ref: '#/definitions/InvalidRequestResponsePayload' + '403': + description: not authorized + '404': + description: not found + '500': + description: server error /application_parameters/{parameterName}: get: summary: Searches for an application parameter by name, returns nil if not found