diff --git a/README.md b/README.md index 4bc9e8d..568d7f9 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # rdscheck + Copy command will: - - Copy snapshot(s) to a different AWS region in the same account + - Copy snapshot(s) to a different AWS region in the same account. This will copy only automated snapshots. - Cleanup old snapshots based on retention setup in the yaml config file + Check command will: - Creates new rds instance(s) with the snapshots @@ -20,6 +20,7 @@ - password: `the password that we will use to connect to the database. It doesn't need to be the original one. We will use this one to reset the original password` - retention: `how many days we want to keep the copied snapshot around` - destination: `the aws region where we will copy/restore the snapshot` + - kmsid: `the id (ARN) of the kms key that you want to use on the destination region. This is needed if your original snapshot is encrypted` - queries: `all the sql queries we want to run on the restored snapshot to validate it and the expected results as regex` - query: `the sql query to run` - regex: `the regex of the expected result` @@ -33,6 +34,7 @@ instances: password: thisisatest retention: 1 destination: us-east-1 + kmsid: "arn:aws:kms:us-east-1:1234567890:key/123456-7890-123456" queries: - query: "SELECT tablename FROM pg_catalog.pg_tables;" regex: "^pg_statistic$" diff --git a/checks/commands.go b/checks/commands.go index d332632..4746ab7 100644 --- a/checks/commands.go +++ b/checks/commands.go @@ -3,7 +3,6 @@ package checks import ( "errors" "sort" - "strings" "time" "github.com/aws/aws-sdk-go/aws" @@ -42,10 +41,7 @@ func (c *Client) GetSnapshots(DBInstanceIdentifier string) ([]*rds.DBSnapshot, e // CopySnapshots copies the snapshots either to the same region as the original // or to a new region -func (c *Client) CopySnapshots(snapshot *rds.DBSnapshot, destination string) error { - - arn := strings.SplitN(*snapshot.DBSnapshotArn, ":", 8) - cleanArn := arn[len(arn)-1] +func (c *Client) CopySnapshots(snapshot *rds.DBSnapshot, destination, kmsid, preSignedUrl, cleanArn string) error { input := &rds.CopyDBSnapshotInput{ SourceRegion: aws.String(config.AWSRegionSource), @@ -71,6 +67,12 @@ func (c *Client) CopySnapshots(snapshot *rds.DBSnapshot, destination string) err }, }, } + + if *snapshot.Encrypted { + input.PreSignedUrl = aws.String(preSignedUrl) + input.KmsKeyId = aws.String(kmsid) + } + _, err := c.RDS.CopyDBSnapshot(input) if err != nil { if aerr, ok := err.(awserr.Error); ok { @@ -86,10 +88,33 @@ func (c *Client) CopySnapshots(snapshot *rds.DBSnapshot, destination string) err } else { return err } + } else { + log.WithFields(log.Fields{ + "Snapshot": *snapshot.DBSnapshotIdentifier, + "From": config.AWSRegionSource, + "Destination": destination, + }).Info("Snapshot copied") } return nil } +// PreSignUrl presigned an aws url so that we can copy an encrypted snapshot from a region to another +func (c *Client) PreSignUrl(destinationRegion, snapshotArn, kmsid, cleanArn string) (string, error) { + input := &rds.CopyDBSnapshotInput{ + SourceRegion: aws.String(config.AWSRegionSource), + DestinationRegion: aws.String(destinationRegion), + SourceDBSnapshotIdentifier: aws.String(snapshotArn), + KmsKeyId: aws.String(kmsid), + TargetDBSnapshotIdentifier: aws.String(cleanArn), + } + req, _ := c.RDS.CopyDBSnapshotRequest(input) + url, err := req.Presign(time.Duration(5) * time.Minute) + if err != nil { + return "", err + } + return url, nil +} + // GetOldSnapshots gets old snapshots based on the retention policy // retentionDays is a integer of the number of days we want to keep the snapshots. func (c *Client) GetOldSnapshots(snapshots []*rds.DBSnapshot, retentionDays int) ([]*rds.DBSnapshot, error) { diff --git a/checks/commands_test.go b/checks/commands_test.go index 27e528d..8395db1 100644 --- a/checks/commands_test.go +++ b/checks/commands_test.go @@ -1,10 +1,13 @@ package checks import ( + "net/http" + "net/url" "testing" "time" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/rds" "github.com/aws/aws-sdk-go/service/rds/rdsiface" "github.com/stretchr/testify/assert" @@ -81,6 +84,11 @@ func (m *mockRDS) ModifyDBInstance(input *rds.ModifyDBInstanceInput) (*rds.Modif return args.Get(0).(*rds.ModifyDBInstanceOutput), args.Error(1) } +func (m *mockRDS) CopyDBSnapshotRequest(input *rds.CopyDBSnapshotInput) (*request.Request, *rds.CopyDBSnapshotOutput) { + args := m.Called(input) + return args.Get(0).(*request.Request), args.Get(1).(*rds.CopyDBSnapshotOutput) +} + func TestGetSnapshots(t *testing.T) { rdsc := &mockRDS{} @@ -109,7 +117,30 @@ func TestGetSnapshots(t *testing.T) { rdsc.AssertExpectations(t) } -func TestCopySnapshots(t *testing.T) { +func TestCopySnapshotsNoKms(t *testing.T) { + rdsc := &mockRDS{} + + c := &Client{ + RDS: rdsc, + } + + input := &rds.DBSnapshot{ + DBSnapshotIdentifier: aws.String("test"), + DBSnapshotArn: aws.String("arn:aws:rds:us-west-2:123456789012:snapshot:test"), + Encrypted: aws.Bool(false), + } + + rdsc.On("CopyDBSnapshot", mock.Anything).Return(&rds.CopyDBSnapshotOutput{ + DBSnapshot: &rds.DBSnapshot{}, + }, nil) + + err := c.CopySnapshots(input, "us-west-2", "", "", "test") + assert.Nil(t, err) + rdsc.AssertExpectations(t) + +} + +func TestCopySnapshotsWithKms(t *testing.T) { rdsc := &mockRDS{} c := &Client{ @@ -119,13 +150,15 @@ func TestCopySnapshots(t *testing.T) { input := &rds.DBSnapshot{ DBSnapshotIdentifier: aws.String("test"), DBSnapshotArn: aws.String("arn:aws:rds:us-west-2:123456789012:snapshot:test"), + KmsKeyId: aws.String("arn:aws:kms:us-east-1:1234567890:key/123456-7890-123456"), + Encrypted: aws.Bool(true), } rdsc.On("CopyDBSnapshot", mock.Anything).Return(&rds.CopyDBSnapshotOutput{ DBSnapshot: &rds.DBSnapshot{}, }, nil) - err := c.CopySnapshots(input, "us-west-2") + err := c.CopySnapshots(input, "us-west-2", "arn:aws:kms:us-east-1:1234567890:key/123456-7890-123456", "https://url.local", "test") assert.Nil(t, err) rdsc.AssertExpectations(t) @@ -432,3 +465,32 @@ func TestGetTagValue(t *testing.T) { assert.Equal(t, value, "restore") rdsc.AssertExpectations(t) } + +func TestPreSignUrl(t *testing.T) { + rdsc := &mockRDS{} + + c := &Client{ + RDS: rdsc, + } + + u := &url.URL{ + Scheme: "http", + Host: "fakeurl.aws.com", + } + + req := &request.Request{ + HTTPRequest: &http.Request{ + URL: u, + }, + Operation: &request.Operation{}, + } + + output := &rds.CopyDBSnapshotOutput{} + + rdsc.On("CopyDBSnapshotRequest", mock.Anything).Return(req, output) + + value, err := c.PreSignUrl("us-east-2", "arn:aws:rds:us-west-2:123456789012:snapshot:test", "arn:aws:kms:us-east-1:1234567890:key/123456-7890-123456", "test") + assert.Nil(t, err) + assert.Equal(t, value, "http://fakeurl.aws.com") + rdsc.AssertExpectations(t) +} diff --git a/checks/common.go b/checks/common.go index 75d1cdb..d4eb90c 100644 --- a/checks/common.go +++ b/checks/common.go @@ -4,6 +4,7 @@ import ( "database/sql" "io" "io/ioutil" + "strings" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" @@ -23,9 +24,9 @@ type DefaultChecks interface { GetYamlFileFromS3(bucket, key string) (io.Reader, error) UnmarshalYamlFile(body io.Reader) (Doc, error) DataDogSession(apiKey, applicationKey string) *datadog.Client - PostDatadogChecks(snapshot *rds.DBSnapshot, metricName, status string) error + PostDatadogChecks(snapshot *rds.DBSnapshot, metricName, status, cmdName string) error GetSnapshots(DBInstanceIdentifier string) ([]*rds.DBSnapshot, error) - CopySnapshots(snapshot *rds.DBSnapshot, destination string) error + CopySnapshots(snapshot *rds.DBSnapshot, destination, kmsid, preSignedUrl, cleanArn string) error GetOldSnapshots(snapshots []*rds.DBSnapshot, retention int) ([]*rds.DBSnapshot, error) DeleteOldSnapshots(snapshots []*rds.DBSnapshot) error CheckIfDatabaseSubnetGroupExist(snapshot *rds.DBSnapshot) bool @@ -41,6 +42,8 @@ type DefaultChecks interface { GetTagValue(arn, key string) string InitDb(db *rds.DBInstance, password, dbname string) error CheckRegexAgainstRow(query, regex string) bool + PreSignUrl(destinationRegion, snapshotArn, kmsid, cleanArn string) (string, error) + CleanArn(snapshot *rds.DBSnapshot) string } type Client struct { @@ -62,6 +65,7 @@ type Instances struct { Password string Retention int Destination string + KmsID string Queries []Queries } @@ -132,11 +136,12 @@ func (c *Client) UnmarshalYamlFile(body io.Reader) (Doc, error) { } // PostDatadogChecks posts to datadog the status of a check -func (c *Client) PostDatadogChecks(snapshot *rds.DBSnapshot, metricName, status string) error { +func (c *Client) PostDatadogChecks(snapshot *rds.DBSnapshot, metricName, status, cmdName string) error { tags := []string{ "database:" + *snapshot.DBInstanceIdentifier, "snapshot:" + *snapshot.DBSnapshotIdentifier, + "command" + cmdName, } timeNow := utils.GetUnixTimeAsString() @@ -157,3 +162,9 @@ func (c *Client) PostDatadogChecks(snapshot *rds.DBSnapshot, metricName, status } return nil } + +func (c *Client) CleanArn(snapshot *rds.DBSnapshot) string { + arn := strings.SplitN(*snapshot.DBSnapshotArn, ":", 8) + cleanArn := arn[len(arn)-1] + return cleanArn +} diff --git a/checks/common_test.go b/checks/common_test.go index d647dc5..3e2c783 100644 --- a/checks/common_test.go +++ b/checks/common_test.go @@ -65,6 +65,7 @@ func TestUnmarshalYamlFile(t *testing.T) { Password: "thisisatest", Retention: 1, Destination: "us-east-1", + KmsID: "arn:aws:kms:us-east-1:1234567890:key/123456-7890-123456", Queries: []Queries{ Queries{ Query: "SELECT tablename FROM pg_catalog.pg_tables;", @@ -116,6 +117,17 @@ func TestPostDatadogChecks(t *testing.T) { DBSnapshotIdentifier: aws.String("test"), } - err := c.PostDatadogChecks(input, "rdscheck.status", "ok") + err := c.PostDatadogChecks(input, "rdscheck.status", "ok", "check") assert.Nil(t, err) } + +func TestCleanArn(t *testing.T) { + c := &Client{} + + input := &rds.DBSnapshot{ + DBSnapshotArn: aws.String("arn:aws:rds:us-west-2:123456789012:snapshot:test"), + } + + value := c.CleanArn(input) + assert.Equal(t, value, "test") +} diff --git a/cmd/check/main.go b/cmd/check/main.go index 37c9f89..f0a9f9c 100644 --- a/cmd/check/main.go +++ b/cmd/check/main.go @@ -108,7 +108,7 @@ func process(destination checks.DefaultChecks, snapshot *rds.DBSnapshot, instanc } func caseReady(destination checks.DefaultChecks, snapshot *rds.DBSnapshot) error { - err := destination.PostDatadogChecks(snapshot, "rdscheck.status", "ok") + err := destination.PostDatadogChecks(snapshot, "rdscheck.status", "ok", "check") if err != nil { log.WithError(err).Error("Could not update datadog status") return err @@ -215,33 +215,31 @@ func caseVerify(destination checks.DefaultChecks, snapshot *rds.DBSnapshot, inst return err } - if instance.Name == *dbInfo.DBName { - for _, query := range instance.Queries { - if destination.CheckRegexAgainstRow(query.Query, query.Regex) { - err := destination.UpdateTag(snapshot, "Status", "clean") - if err != nil { - return err - } - } else { - log.WithFields(log.Fields{ - "RDS Instance": string(*snapshot.DBInstanceIdentifier + "-" + *snapshot.DBSnapshotIdentifier), - "DB Name": *dbInfo.DBName, - "Query": query.Query, - "Regex": query.Regex, - }).Errorf("Query matched failed: %s", err) - errors := destination.UpdateTag(snapshot, "Status", "alarm") - if errors != nil { - return err - } + for _, query := range instance.Queries { + if destination.CheckRegexAgainstRow(query.Query, query.Regex) { + err := destination.UpdateTag(snapshot, "Status", "clean") + if err != nil { return err } + } else { + log.WithFields(log.Fields{ + "RDS Instance": string(*snapshot.DBInstanceIdentifier + "-" + *snapshot.DBSnapshotIdentifier), + "DB Name": *dbInfo.DBName, + "Query": query.Query, + "Regex": query.Regex, + }).Errorf("Query matched failed: %s", err) + errors := destination.UpdateTag(snapshot, "Status", "alarm") + if errors != nil { + return err + } + return err } } return nil } func caseAlarm(destination checks.DefaultChecks, snapshot *rds.DBSnapshot) error { - err := destination.PostDatadogChecks(snapshot, "rdscheck.status", "critical") + err := destination.PostDatadogChecks(snapshot, "rdscheck.status", "critical", "check") if err != nil { log.WithError(err).Error("Could not update datadog status") return err diff --git a/cmd/check/main_test.go b/cmd/check/main_test.go index 7317e8a..6f5c873 100644 --- a/cmd/check/main_test.go +++ b/cmd/check/main_test.go @@ -63,8 +63,8 @@ func (m *mockDefaultChecks) GetTagValue(arn, key string) string { return args.Get(0).(string) } -func (m *mockDefaultChecks) PostDatadogChecks(snapshot *rds.DBSnapshot, metricName, status string) error { - args := m.Called(snapshot, metricName, status) +func (m *mockDefaultChecks) PostDatadogChecks(snapshot *rds.DBSnapshot, metricName, status, cmdName string) error { + args := m.Called(snapshot, metricName, status, cmdName) return args.Error(0) } @@ -169,7 +169,7 @@ func TestGetDoc(t *testing.T) { func TestCaseReady(t *testing.T) { c := &mockDefaultChecks{} - c.On("PostDatadogChecks", mock.Anything, mock.Anything, mock.Anything).Return(nil) + c.On("PostDatadogChecks", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) c.On("CreateDatabaseSubnetGroup", mock.Anything, mock.Anything).Return(nil) c.On("UpdateTag", mock.Anything, mock.Anything, mock.Anything).Return(nil) @@ -223,7 +223,7 @@ func TestCaseVerify(t *testing.T) { func TestCaseAlarm(t *testing.T) { c := &mockDefaultChecks{} - c.On("PostDatadogChecks", mock.Anything, mock.Anything, mock.Anything).Return(nil) + c.On("PostDatadogChecks", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) c.On("UpdateTag", mock.Anything, mock.Anything, mock.Anything).Return(nil) err := caseAlarm(c, singleSnapshot) diff --git a/cmd/copy/main.go b/cmd/copy/main.go index baf627d..5c103ec 100644 --- a/cmd/copy/main.go +++ b/cmd/copy/main.go @@ -50,12 +50,40 @@ func copy(source checks.DefaultChecks, destination checks.DefaultChecks) error { return err } for _, snapshot := range snapshots { - err := destination.CopySnapshots(snapshot, instance.Destination) - if err != nil { - log.WithFields(log.Fields{ - "Snapshot": *snapshot.DBSnapshotIdentifier, - }).Errorf("Could not copy snapshot: %s", err) - return err + if *snapshot.SnapshotType == "automated" { + err := destination.PostDatadogChecks(snapshot, "rdscheck.status", "ok", "copy") + if err != nil { + log.WithError(err).Error("Could not update datadog status") + return err + } + var preSignedUrl string + cleanArn := destination.CleanArn(snapshot) + if *snapshot.Encrypted { + preSignedUrl, err = source.PreSignUrl(instance.Destination, *snapshot.DBSnapshotArn, instance.KmsID, cleanArn) + if err != nil { + log.WithFields(log.Fields{ + "snapshot": *snapshot.DBSnapshotIdentifier, + }).WithError(err).Error("Could not presigned the url") + err := destination.PostDatadogChecks(snapshot, "rdscheck.status", "critical", "copy") + if err != nil { + log.WithError(err).Error("Could not update datadog status") + return err + } + return err + } + } + err = destination.CopySnapshots(snapshot, instance.Destination, instance.KmsID, preSignedUrl, cleanArn) + if err != nil { + log.WithFields(log.Fields{ + "Snapshot": *snapshot.DBSnapshotIdentifier, + }).WithError(err).Error("Could not copy snapshot") + err := destination.PostDatadogChecks(snapshot, "rdscheck.status", "critical", "copy") + if err != nil { + log.WithError(err).Error("Could not update datadog status") + return err + } + return err + } } } @@ -67,17 +95,20 @@ func copy(source checks.DefaultChecks, destination checks.DefaultChecks) error { }).Errorf("Could not get snapshots: %s", err) return err } + for _, snapshot := range snapshots { + if destination.CheckTag(*snapshot.DBSnapshotArn, "CreatedBy", "rdscheck") { + oldSnapshots, err := destination.GetOldSnapshots(snapshots, instance.Retention) + if err != nil { + log.WithError(err).Error("Could not get old snapshots") + return err + } - oldSnapshots, err := destination.GetOldSnapshots(snapshots, instance.Retention) - if err != nil { - log.WithError(err).Error("Could not get old snapshots") - return err - } - - err = destination.DeleteOldSnapshots(oldSnapshots) - if err != nil { - log.WithError(err).Error("Could not delete old snapshots") - return err + err = destination.DeleteOldSnapshots(oldSnapshots) + if err != nil { + log.WithError(err).Error("Could not delete old snapshots") + return err + } + } } } return nil diff --git a/cmd/copy/main_test.go b/cmd/copy/main_test.go index 4f6ee73..6bff934 100644 --- a/cmd/copy/main_test.go +++ b/cmd/copy/main_test.go @@ -41,17 +41,21 @@ var snapshots = []*rds.DBSnapshot{ DBSnapshotIdentifier: aws.String("test"), SnapshotCreateTime: aws.Time(time.Now().AddDate(0, 0, -10)), DBSnapshotArn: aws.String("arn:aws:rds:us-west-2:123456789012:snapshot:test"), + SnapshotType: aws.String("automated"), + Encrypted: aws.Bool(true), }, &rds.DBSnapshot{ Status: aws.String("available"), DBSnapshotIdentifier: aws.String("test-2"), SnapshotCreateTime: aws.Time(time.Now()), DBSnapshotArn: aws.String("arn:aws:rds:us-west-2:123456789012:snapshot:test-2"), + SnapshotType: aws.String("automated"), + Encrypted: aws.Bool(true), }, } -func (m *mockDefaultChecks) CopySnapshots(snapshot *rds.DBSnapshot, destination string) error { - args := m.Called(snapshot, destination) +func (m *mockDefaultChecks) CopySnapshots(snapshot *rds.DBSnapshot, destination, kmsid, preSignedUrl, cleanArn string) error { + args := m.Called(snapshot, destination, kmsid) return args.Error(0) } @@ -84,6 +88,26 @@ func (m *mockDefaultChecks) UnmarshalYamlFile(body io.Reader) (checks.Doc, error return args.Get(0).(checks.Doc), args.Error(1) } +func (m *mockDefaultChecks) PreSignUrl(destinationRegion, snapshotArn, kmsid, cleanArn string) (string, error) { + args := m.Called(destinationRegion, snapshotArn, kmsid, cleanArn) + return args.Get(0).(string), args.Error(1) +} + +func (m *mockDefaultChecks) CheckTag(arn string, key string, value string) bool { + args := m.Called(arn, key, value) + return args.Bool(0) +} + +func (m *mockDefaultChecks) CleanArn(snapshot *rds.DBSnapshot) string { + args := m.Called(snapshot) + return args.Get(0).(string) +} + +func (m *mockDefaultChecks) PostDatadogChecks(snapshot *rds.DBSnapshot, metricName, status, cmdName string) error { + args := m.Called(snapshot, metricName, status, cmdName) + return args.Error(0) +} + func TestCopy(t *testing.T) { c := &mockDefaultChecks{} @@ -94,7 +118,11 @@ func TestCopy(t *testing.T) { c.On("GetYamlFileFromS3", mock.Anything, mock.Anything).Return(input, nil) c.On("UnmarshalYamlFile", mock.Anything).Return(doc, nil) c.On("GetSnapshots", mock.Anything).Return(snapshots, nil) - c.On("CopySnapshots", mock.Anything, mock.Anything).Return(nil) + c.On("PostDatadogChecks", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + c.On("CleanArn", mock.Anything).Return("test") + c.On("PreSignUrl", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return("https://url.local", nil) + c.On("CopySnapshots", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + c.On("CheckTag", mock.Anything, mock.Anything, mock.Anything).Return(true) c.On("GetOldSnapshots", mock.Anything, mock.Anything).Return(snapshots, nil) c.On("DeleteOldSnapshots", mock.Anything).Return(nil) diff --git a/example/checks.yml b/example/checks.yml index 26b7258..8a7480f 100644 --- a/example/checks.yml +++ b/example/checks.yml @@ -5,6 +5,7 @@ instances: password: thisisatest retention: 1 destination: us-east-1 + kmsid: "arn:aws:kms:us-east-1:1234567890:key/123456-7890-123456" queries: - query: "SELECT tablename FROM pg_catalog.pg_tables;" regex: "^pg_statistic$" diff --git a/terraform/terraform.tf b/terraform/terraform.tf index d57a448..33703c2 100644 --- a/terraform/terraform.tf +++ b/terraform/terraform.tf @@ -30,11 +30,10 @@ resource "null_resource" "get_release" { } data "archive_file" "lambda_code" { - type = "zip" + type = "zip" source_file = "${path.module}/lambda-files/main" output_path = "${path.module}/lambda-files/main.zip" - - depends_on = ["null_resource.get_release"] + depends_on = ["null_resource.get_release"] } resource "aws_lambda_function" "rdscheck_lambda_copy" {