Skip to content

Commit

Permalink
Allow the copy of encrypted snapshots (#4)
Browse files Browse the repository at this point in the history
* Allow the copy of encrypted snapshots

- We can now copy encrypted snapshots from a region to another by providing a kms key ARN.
- If the snapshot is encrypted we generate a Pre-signed url and add it to the CopySnapshots function
  • Loading branch information
matthieudolci authored Nov 28, 2019
1 parent 2912f7f commit 509b8dd
Show file tree
Hide file tree
Showing 11 changed files with 227 additions and 58 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# rdscheck
+ Copy command will:
- Copy snapshot(s) to a different AWS region in the same account
- Copy snapshot(s) to a different AWS region in the same account. This will copy only automated snapshots.
- Cleanup old snapshots based on retention setup in the yaml config file
+ Check command will:
- Creates new rds instance(s) with the snapshots
Expand All @@ -20,6 +20,7 @@
- password: `the password that we will use to connect to the database. It doesn't need to be the original one. We will use this one to reset the original password`
- retention: `how many days we want to keep the copied snapshot around`
- destination: `the aws region where we will copy/restore the snapshot`
- kmsid: `the id (ARN) of the kms key that you want to use on the destination region. This is needed if your original snapshot is encrypted`
- queries: `all the sql queries we want to run on the restored snapshot to validate it and the expected results as regex`
- query: `the sql query to run`
- regex: `the regex of the expected result`
Expand All @@ -33,6 +34,7 @@ instances:
password: thisisatest
retention: 1
destination: us-east-1
kmsid: "arn:aws:kms:us-east-1:1234567890:key/123456-7890-123456"
queries:
- query: "SELECT tablename FROM pg_catalog.pg_tables;"
regex: "^pg_statistic$"
Expand Down
35 changes: 30 additions & 5 deletions checks/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package checks
import (
"errors"
"sort"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
Expand Down Expand Up @@ -42,10 +41,7 @@ func (c *Client) GetSnapshots(DBInstanceIdentifier string) ([]*rds.DBSnapshot, e

// CopySnapshots copies the snapshots either to the same region as the original
// or to a new region
func (c *Client) CopySnapshots(snapshot *rds.DBSnapshot, destination string) error {

arn := strings.SplitN(*snapshot.DBSnapshotArn, ":", 8)
cleanArn := arn[len(arn)-1]
func (c *Client) CopySnapshots(snapshot *rds.DBSnapshot, destination, kmsid, preSignedUrl, cleanArn string) error {

input := &rds.CopyDBSnapshotInput{
SourceRegion: aws.String(config.AWSRegionSource),
Expand All @@ -71,6 +67,12 @@ func (c *Client) CopySnapshots(snapshot *rds.DBSnapshot, destination string) err
},
},
}

if *snapshot.Encrypted {
input.PreSignedUrl = aws.String(preSignedUrl)
input.KmsKeyId = aws.String(kmsid)
}

_, err := c.RDS.CopyDBSnapshot(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
Expand All @@ -86,10 +88,33 @@ func (c *Client) CopySnapshots(snapshot *rds.DBSnapshot, destination string) err
} else {
return err
}
} else {
log.WithFields(log.Fields{
"Snapshot": *snapshot.DBSnapshotIdentifier,
"From": config.AWSRegionSource,
"Destination": destination,
}).Info("Snapshot copied")
}
return nil
}

// PreSignUrl presigned an aws url so that we can copy an encrypted snapshot from a region to another
func (c *Client) PreSignUrl(destinationRegion, snapshotArn, kmsid, cleanArn string) (string, error) {
input := &rds.CopyDBSnapshotInput{
SourceRegion: aws.String(config.AWSRegionSource),
DestinationRegion: aws.String(destinationRegion),
SourceDBSnapshotIdentifier: aws.String(snapshotArn),
KmsKeyId: aws.String(kmsid),
TargetDBSnapshotIdentifier: aws.String(cleanArn),
}
req, _ := c.RDS.CopyDBSnapshotRequest(input)
url, err := req.Presign(time.Duration(5) * time.Minute)
if err != nil {
return "", err
}
return url, nil
}

// GetOldSnapshots gets old snapshots based on the retention policy
// retentionDays is a integer of the number of days we want to keep the snapshots.
func (c *Client) GetOldSnapshots(snapshots []*rds.DBSnapshot, retentionDays int) ([]*rds.DBSnapshot, error) {
Expand Down
66 changes: 64 additions & 2 deletions checks/commands_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package checks

import (
"net/http"
"net/url"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/rds"
"github.com/aws/aws-sdk-go/service/rds/rdsiface"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -81,6 +84,11 @@ func (m *mockRDS) ModifyDBInstance(input *rds.ModifyDBInstanceInput) (*rds.Modif
return args.Get(0).(*rds.ModifyDBInstanceOutput), args.Error(1)
}

func (m *mockRDS) CopyDBSnapshotRequest(input *rds.CopyDBSnapshotInput) (*request.Request, *rds.CopyDBSnapshotOutput) {
args := m.Called(input)
return args.Get(0).(*request.Request), args.Get(1).(*rds.CopyDBSnapshotOutput)
}

func TestGetSnapshots(t *testing.T) {
rdsc := &mockRDS{}

Expand Down Expand Up @@ -109,7 +117,30 @@ func TestGetSnapshots(t *testing.T) {
rdsc.AssertExpectations(t)
}

func TestCopySnapshots(t *testing.T) {
func TestCopySnapshotsNoKms(t *testing.T) {
rdsc := &mockRDS{}

c := &Client{
RDS: rdsc,
}

input := &rds.DBSnapshot{
DBSnapshotIdentifier: aws.String("test"),
DBSnapshotArn: aws.String("arn:aws:rds:us-west-2:123456789012:snapshot:test"),
Encrypted: aws.Bool(false),
}

rdsc.On("CopyDBSnapshot", mock.Anything).Return(&rds.CopyDBSnapshotOutput{
DBSnapshot: &rds.DBSnapshot{},
}, nil)

err := c.CopySnapshots(input, "us-west-2", "", "", "test")
assert.Nil(t, err)
rdsc.AssertExpectations(t)

}

func TestCopySnapshotsWithKms(t *testing.T) {
rdsc := &mockRDS{}

c := &Client{
Expand All @@ -119,13 +150,15 @@ func TestCopySnapshots(t *testing.T) {
input := &rds.DBSnapshot{
DBSnapshotIdentifier: aws.String("test"),
DBSnapshotArn: aws.String("arn:aws:rds:us-west-2:123456789012:snapshot:test"),
KmsKeyId: aws.String("arn:aws:kms:us-east-1:1234567890:key/123456-7890-123456"),
Encrypted: aws.Bool(true),
}

rdsc.On("CopyDBSnapshot", mock.Anything).Return(&rds.CopyDBSnapshotOutput{
DBSnapshot: &rds.DBSnapshot{},
}, nil)

err := c.CopySnapshots(input, "us-west-2")
err := c.CopySnapshots(input, "us-west-2", "arn:aws:kms:us-east-1:1234567890:key/123456-7890-123456", "https://url.local", "test")
assert.Nil(t, err)
rdsc.AssertExpectations(t)

Expand Down Expand Up @@ -432,3 +465,32 @@ func TestGetTagValue(t *testing.T) {
assert.Equal(t, value, "restore")
rdsc.AssertExpectations(t)
}

func TestPreSignUrl(t *testing.T) {
rdsc := &mockRDS{}

c := &Client{
RDS: rdsc,
}

u := &url.URL{
Scheme: "http",
Host: "fakeurl.aws.com",
}

req := &request.Request{
HTTPRequest: &http.Request{
URL: u,
},
Operation: &request.Operation{},
}

output := &rds.CopyDBSnapshotOutput{}

rdsc.On("CopyDBSnapshotRequest", mock.Anything).Return(req, output)

value, err := c.PreSignUrl("us-east-2", "arn:aws:rds:us-west-2:123456789012:snapshot:test", "arn:aws:kms:us-east-1:1234567890:key/123456-7890-123456", "test")
assert.Nil(t, err)
assert.Equal(t, value, "http://fakeurl.aws.com")
rdsc.AssertExpectations(t)
}
17 changes: 14 additions & 3 deletions checks/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql"
"io"
"io/ioutil"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
Expand All @@ -23,9 +24,9 @@ type DefaultChecks interface {
GetYamlFileFromS3(bucket, key string) (io.Reader, error)
UnmarshalYamlFile(body io.Reader) (Doc, error)
DataDogSession(apiKey, applicationKey string) *datadog.Client
PostDatadogChecks(snapshot *rds.DBSnapshot, metricName, status string) error
PostDatadogChecks(snapshot *rds.DBSnapshot, metricName, status, cmdName string) error
GetSnapshots(DBInstanceIdentifier string) ([]*rds.DBSnapshot, error)
CopySnapshots(snapshot *rds.DBSnapshot, destination string) error
CopySnapshots(snapshot *rds.DBSnapshot, destination, kmsid, preSignedUrl, cleanArn string) error
GetOldSnapshots(snapshots []*rds.DBSnapshot, retention int) ([]*rds.DBSnapshot, error)
DeleteOldSnapshots(snapshots []*rds.DBSnapshot) error
CheckIfDatabaseSubnetGroupExist(snapshot *rds.DBSnapshot) bool
Expand All @@ -41,6 +42,8 @@ type DefaultChecks interface {
GetTagValue(arn, key string) string
InitDb(db *rds.DBInstance, password, dbname string) error
CheckRegexAgainstRow(query, regex string) bool
PreSignUrl(destinationRegion, snapshotArn, kmsid, cleanArn string) (string, error)
CleanArn(snapshot *rds.DBSnapshot) string
}

type Client struct {
Expand All @@ -62,6 +65,7 @@ type Instances struct {
Password string
Retention int
Destination string
KmsID string
Queries []Queries
}

Expand Down Expand Up @@ -132,11 +136,12 @@ func (c *Client) UnmarshalYamlFile(body io.Reader) (Doc, error) {
}

// PostDatadogChecks posts to datadog the status of a check
func (c *Client) PostDatadogChecks(snapshot *rds.DBSnapshot, metricName, status string) error {
func (c *Client) PostDatadogChecks(snapshot *rds.DBSnapshot, metricName, status, cmdName string) error {

tags := []string{
"database:" + *snapshot.DBInstanceIdentifier,
"snapshot:" + *snapshot.DBSnapshotIdentifier,
"command" + cmdName,
}

timeNow := utils.GetUnixTimeAsString()
Expand All @@ -157,3 +162,9 @@ func (c *Client) PostDatadogChecks(snapshot *rds.DBSnapshot, metricName, status
}
return nil
}

func (c *Client) CleanArn(snapshot *rds.DBSnapshot) string {
arn := strings.SplitN(*snapshot.DBSnapshotArn, ":", 8)
cleanArn := arn[len(arn)-1]
return cleanArn
}
14 changes: 13 additions & 1 deletion checks/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ func TestUnmarshalYamlFile(t *testing.T) {
Password: "thisisatest",
Retention: 1,
Destination: "us-east-1",
KmsID: "arn:aws:kms:us-east-1:1234567890:key/123456-7890-123456",
Queries: []Queries{
Queries{
Query: "SELECT tablename FROM pg_catalog.pg_tables;",
Expand Down Expand Up @@ -116,6 +117,17 @@ func TestPostDatadogChecks(t *testing.T) {
DBSnapshotIdentifier: aws.String("test"),
}

err := c.PostDatadogChecks(input, "rdscheck.status", "ok")
err := c.PostDatadogChecks(input, "rdscheck.status", "ok", "check")
assert.Nil(t, err)
}

func TestCleanArn(t *testing.T) {
c := &Client{}

input := &rds.DBSnapshot{
DBSnapshotArn: aws.String("arn:aws:rds:us-west-2:123456789012:snapshot:test"),
}

value := c.CleanArn(input)
assert.Equal(t, value, "test")
}
38 changes: 18 additions & 20 deletions cmd/check/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func process(destination checks.DefaultChecks, snapshot *rds.DBSnapshot, instanc
}

func caseReady(destination checks.DefaultChecks, snapshot *rds.DBSnapshot) error {
err := destination.PostDatadogChecks(snapshot, "rdscheck.status", "ok")
err := destination.PostDatadogChecks(snapshot, "rdscheck.status", "ok", "check")
if err != nil {
log.WithError(err).Error("Could not update datadog status")
return err
Expand Down Expand Up @@ -215,33 +215,31 @@ func caseVerify(destination checks.DefaultChecks, snapshot *rds.DBSnapshot, inst
return err
}

if instance.Name == *dbInfo.DBName {
for _, query := range instance.Queries {
if destination.CheckRegexAgainstRow(query.Query, query.Regex) {
err := destination.UpdateTag(snapshot, "Status", "clean")
if err != nil {
return err
}
} else {
log.WithFields(log.Fields{
"RDS Instance": string(*snapshot.DBInstanceIdentifier + "-" + *snapshot.DBSnapshotIdentifier),
"DB Name": *dbInfo.DBName,
"Query": query.Query,
"Regex": query.Regex,
}).Errorf("Query matched failed: %s", err)
errors := destination.UpdateTag(snapshot, "Status", "alarm")
if errors != nil {
return err
}
for _, query := range instance.Queries {
if destination.CheckRegexAgainstRow(query.Query, query.Regex) {
err := destination.UpdateTag(snapshot, "Status", "clean")
if err != nil {
return err
}
} else {
log.WithFields(log.Fields{
"RDS Instance": string(*snapshot.DBInstanceIdentifier + "-" + *snapshot.DBSnapshotIdentifier),
"DB Name": *dbInfo.DBName,
"Query": query.Query,
"Regex": query.Regex,
}).Errorf("Query matched failed: %s", err)
errors := destination.UpdateTag(snapshot, "Status", "alarm")
if errors != nil {
return err
}
return err
}
}
return nil
}

func caseAlarm(destination checks.DefaultChecks, snapshot *rds.DBSnapshot) error {
err := destination.PostDatadogChecks(snapshot, "rdscheck.status", "critical")
err := destination.PostDatadogChecks(snapshot, "rdscheck.status", "critical", "check")
if err != nil {
log.WithError(err).Error("Could not update datadog status")
return err
Expand Down
8 changes: 4 additions & 4 deletions cmd/check/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ func (m *mockDefaultChecks) GetTagValue(arn, key string) string {
return args.Get(0).(string)
}

func (m *mockDefaultChecks) PostDatadogChecks(snapshot *rds.DBSnapshot, metricName, status string) error {
args := m.Called(snapshot, metricName, status)
func (m *mockDefaultChecks) PostDatadogChecks(snapshot *rds.DBSnapshot, metricName, status, cmdName string) error {
args := m.Called(snapshot, metricName, status, cmdName)
return args.Error(0)
}

Expand Down Expand Up @@ -169,7 +169,7 @@ func TestGetDoc(t *testing.T) {
func TestCaseReady(t *testing.T) {
c := &mockDefaultChecks{}

c.On("PostDatadogChecks", mock.Anything, mock.Anything, mock.Anything).Return(nil)
c.On("PostDatadogChecks", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
c.On("CreateDatabaseSubnetGroup", mock.Anything, mock.Anything).Return(nil)
c.On("UpdateTag", mock.Anything, mock.Anything, mock.Anything).Return(nil)

Expand Down Expand Up @@ -223,7 +223,7 @@ func TestCaseVerify(t *testing.T) {
func TestCaseAlarm(t *testing.T) {
c := &mockDefaultChecks{}

c.On("PostDatadogChecks", mock.Anything, mock.Anything, mock.Anything).Return(nil)
c.On("PostDatadogChecks", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
c.On("UpdateTag", mock.Anything, mock.Anything, mock.Anything).Return(nil)

err := caseAlarm(c, singleSnapshot)
Expand Down
Loading

0 comments on commit 509b8dd

Please sign in to comment.