Skip to content

Commit

Permalink
make cleanup works
Browse files Browse the repository at this point in the history
- move cleanup to a seperate function
- move getDoc to a seperate function
- makes cleanup works properly
  • Loading branch information
matthieudolci committed Dec 5, 2019
1 parent a3d747f commit 079b706
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 50 deletions.
10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
- Creates new rds instance(s) with the snapshots
- Runs a set of queries on the database to validate the content of the backup

## TODO

- Handle things gracefully when there is more than 5 snapshots to copy
- Handle different retentions between automatic and manual backups. (tag automatic snapshot with something like "CopiedBy" "rdscheck" and skip if set)

## check: state machine diagram

Expand All @@ -18,7 +22,7 @@
- database: `the name of the databse that we copied and restored we use this field to initiate the db connection`
- type: `the rds instance type we want to use to restore the snapshot`
- password: `the password that we will use to connect to the database. It doesn't need to be the original one. We will use this one to reset the original password`
- retention: `how many days we want to keep the copied snapshot around`
- retention: `how many days we want to keep the copied snapshot around. Right now it should be equal to the number of days the automatic backups are kept`
- destination: `the aws region where we will copy/restore the snapshot`
- kmsid: `the id (ARN) of the kms key that you want to use on the destination region. This is needed if your original snapshot is encrypted`
- queries: `all the sql queries we want to run on the restored snapshot to validate it and the expected results as regex`
Expand Down Expand Up @@ -60,7 +64,7 @@ By doing so we can then download the command zip file for a release and use it w
```hcl
module "rdscheck-copy" {
source = "github.com/techdroplabs/rdscheck//terraform?ref=v0.0.8"
source = "github.com/techdroplabs/rdscheck//terraform?ref=v0.0.9"
release_version = "v0.0.8"
command = "copy"
Expand All @@ -80,7 +84,7 @@ module "rdscheck-copy" {
```hcl
module "rdscheck-check" {
source = "github.com/techdroplabs/rdscheck//terraform?ref=v0.0.8"
source = "github.com/techdroplabs/rdscheck//terraform?ref=v0.0.9"
lambda_rate = "rate(30 minutes)"
release_version = "v0.0.8"
Expand Down
34 changes: 13 additions & 21 deletions checks/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,29 +134,21 @@ func (c *Client) GetOldSnapshots(snapshots []*rds.DBSnapshot, retentionDays int)
return oldSnapshots, nil
}

// DeleteOldSnapshots deletes snapshots returned by GetOldSnapshots
func (c *Client) DeleteOldSnapshots(snapshots []*rds.DBSnapshot) error {
for _, s := range snapshots {
if s.DBSnapshotIdentifier == nil {
log.Info("No old snapshots to delete")
break
}

input := &rds.DeleteDBSnapshotInput{
DBSnapshotIdentifier: aws.String(*s.DBSnapshotIdentifier),
}
// DeleteOldSnapshot deletes snapshots returned by GetOldSnapshots
func (c *Client) DeleteOldSnapshot(snapshot *rds.DBSnapshot) error {
input := &rds.DeleteDBSnapshotInput{
DBSnapshotIdentifier: aws.String(*snapshot.DBSnapshotIdentifier),
}

_, err := c.RDS.DeleteDBSnapshot(input)
if err != nil {
return err
} else {
log.WithFields(log.Fields{
"Snapshot": *s.DBSnapshotIdentifier,
}).Info("Snapshot deleted")
return nil
}
_, err := c.RDS.DeleteDBSnapshot(input)
if err != nil {
return err
} else {
log.WithFields(log.Fields{
"Snapshot": *snapshot.DBSnapshotIdentifier,
}).Info("Snapshot deleted")
return nil
}
return nil
}

// CheckIfDatabaseSubnetGroupExist return true if the Subnet Group already exist
Expand Down
13 changes: 4 additions & 9 deletions checks/commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,27 +193,22 @@ func TestGetOldSnapshots(t *testing.T) {
rdsc.AssertExpectations(t)
}

func TestDeleteOldSnapshots(t *testing.T) {
func TestDeleteOldSnapshot(t *testing.T) {
rdsc := &mockRDS{}

c := &Client{
RDS: rdsc,
}

input := []*rds.DBSnapshot{
&rds.DBSnapshot{
DBSnapshotIdentifier: aws.String("old-test-1"),
},
&rds.DBSnapshot{
DBSnapshotIdentifier: aws.String("old-test-2"),
},
input := &rds.DBSnapshot{
DBSnapshotIdentifier: aws.String("old-test-1"),
}

rdsc.On("DeleteDBSnapshot", mock.Anything).Return(&rds.DeleteDBSnapshotOutput{
DBSnapshot: &rds.DBSnapshot{},
}, nil)

err := c.DeleteOldSnapshots(input)
err := c.DeleteOldSnapshot(input)
assert.Nil(t, err)
rdsc.AssertExpectations(t)
}
Expand Down
2 changes: 1 addition & 1 deletion checks/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ type DefaultChecks interface {
GetSnapshots(DBInstanceIdentifier string) ([]*rds.DBSnapshot, error)
CopySnapshots(snapshot *rds.DBSnapshot, destination, kmsid, preSignedUrl, cleanArn string) error
GetOldSnapshots(snapshots []*rds.DBSnapshot, retention int) ([]*rds.DBSnapshot, error)
DeleteOldSnapshots(snapshots []*rds.DBSnapshot) error
DeleteOldSnapshot(snapshot *rds.DBSnapshot) error
CheckIfDatabaseSubnetGroupExist(snapshot *rds.DBSnapshot) bool
CreateDatabaseSubnetGroup(snapshot *rds.DBSnapshot, subnetids []string) error
CreateDBFromSnapshot(snapshot *rds.DBSnapshot, instancetype string, vpcsecuritygroupids []string) error
Expand Down
70 changes: 59 additions & 11 deletions cmd/copy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,51 @@ func run() {
source := checks.New()
destination := checks.New()

err := copy(source, destination)
doc, err := getDoc(source)
if err != nil {
log.WithError(err).Error("Run returned:")
log.WithError(err).Error("getDoc returned:")
os.Exit(1)
}

err = copy(source, destination, doc)
if err != nil {
log.WithError(err).Error("copy returned:")
os.Exit(1)
}

err = clean(destination, doc)
if err != nil {
log.WithError(err).Error("clean returned:")
os.Exit(1)
}
}

func copy(source checks.DefaultChecks, destination checks.DefaultChecks) error {
func getDoc(source checks.DefaultChecks) (checks.Doc, error) {
source.SetSessions(config.AWSRegionSource)

doc := checks.Doc{}

yaml, err := source.GetYamlFileFromS3(config.S3Bucket, config.S3Key)
if err != nil {
log.WithError(err).Error("Could not get the yaml file from s3")
return err
return doc, err
}

doc, err := source.UnmarshalYamlFile(yaml)
doc, err = source.UnmarshalYamlFile(yaml)
if err != nil {
log.WithError(err).Error("Could not unmarshal yaml file")
return err
return doc, err
}

return doc, nil
}

func copy(source checks.DefaultChecks, destination checks.DefaultChecks, doc checks.Doc) error {
source.SetSessions(config.AWSRegionSource)

for _, instance := range doc.Instances {
destination.SetSessions(instance.Destination)

snapshots, err := source.GetSnapshots(instance.Name)
if err != nil {
log.WithFields(log.Fields{
Expand All @@ -49,21 +70,25 @@ func copy(source checks.DefaultChecks, destination checks.DefaultChecks) error {
}).WithError(err).Error("Could not get snapshots")
return err
}

for _, snapshot := range snapshots {
if *snapshot.SnapshotType == "automated" {
err := destination.PostDatadogChecks(snapshot, "rdscheck.status", "ok", "copy")
if err != nil {
log.WithError(err).Error("Could not update datadog status")
return err
}

var preSignedUrl string
cleanArn := destination.CleanArn(snapshot)

if *snapshot.Encrypted {
preSignedUrl, err = source.PreSignUrl(instance.Destination, *snapshot.DBSnapshotArn, instance.KmsID, cleanArn)
if err != nil {
log.WithFields(log.Fields{
"snapshot": *snapshot.DBSnapshotIdentifier,
}).WithError(err).Error("Could not presigned the url")

err := destination.PostDatadogChecks(snapshot, "rdscheck.status", "critical", "copy")
if err != nil {
log.WithError(err).Error("Could not update datadog status")
Expand All @@ -72,11 +97,13 @@ func copy(source checks.DefaultChecks, destination checks.DefaultChecks) error {
return err
}
}

err = destination.CopySnapshots(snapshot, instance.Destination, instance.KmsID, preSignedUrl, cleanArn)
if err != nil {
log.WithFields(log.Fields{
"Snapshot": *snapshot.DBSnapshotIdentifier,
}).WithError(err).Error("Could not copy snapshot")

err := destination.PostDatadogChecks(snapshot, "rdscheck.status", "critical", "copy")
if err != nil {
log.WithError(err).Error("Could not update datadog status")
Expand All @@ -86,26 +113,47 @@ func copy(source checks.DefaultChecks, destination checks.DefaultChecks) error {
}
}
}
}
return nil
}

func clean(destination checks.DefaultChecks, doc checks.Doc) error {
for _, instance := range doc.Instances {
destination.SetSessions(instance.Destination)

snapshots, err = destination.GetSnapshots(instance.Name)
snapshots, err := destination.GetSnapshots(instance.Name)
if err != nil {
log.WithFields(log.Fields{
"RDS Instance": instance.Name,
"AWS Region": instance.Destination,
}).WithError(err).Error("Could not get snapshots")
return err
}
for _, snapshot := range snapshots {

oldSnapshots, err := destination.GetOldSnapshots(snapshots, instance.Retention)
if err != nil {
log.WithError(err).Error("Could not get old snapshots")
return err
}

for _, snapshot := range oldSnapshots {
if destination.CheckTag(*snapshot.DBSnapshotArn, "CreatedBy", "rdscheck") {
oldSnapshots, err := destination.GetOldSnapshots(snapshots, instance.Retention)

err := destination.PostDatadogChecks(snapshot, "rdscheck.status", "ok", "copy")
if err != nil {
log.WithError(err).Error("Could not get old snapshots")
log.WithError(err).Error("Could not update datadog status")
return err
}

err = destination.DeleteOldSnapshots(oldSnapshots)
err = destination.DeleteOldSnapshot(snapshot)
if err != nil {
log.WithError(err).Error("Could not delete old snapshots")

err := destination.PostDatadogChecks(snapshot, "rdscheck.status", "critical", "copy")
if err != nil {
log.WithError(err).Error("Could not update datadog status")
return err
}
return err
}
}
Expand Down
29 changes: 24 additions & 5 deletions cmd/copy/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ func (m *mockDefaultChecks) GetOldSnapshots(snapshots []*rds.DBSnapshot, retenti
return args.Get(0).([]*rds.DBSnapshot), args.Error(1)
}

func (m *mockDefaultChecks) DeleteOldSnapshots(snapshots []*rds.DBSnapshot) error {
args := m.Called(snapshots)
func (m *mockDefaultChecks) DeleteOldSnapshot(snapshot *rds.DBSnapshot) error {
args := m.Called(snapshot)
return args.Error(0)
}

Expand Down Expand Up @@ -108,7 +108,7 @@ func (m *mockDefaultChecks) PostDatadogChecks(snapshot *rds.DBSnapshot, metricNa
return args.Error(0)
}

func TestCopy(t *testing.T) {
func TestGetDoc(t *testing.T) {
c := &mockDefaultChecks{}

yaml, _ := ioutil.ReadFile("../../example/checks.yml")
Expand All @@ -117,16 +117,35 @@ func TestCopy(t *testing.T) {
c.On("SetSessions", mock.Anything).Return()
c.On("GetYamlFileFromS3", mock.Anything, mock.Anything).Return(input, nil)
c.On("UnmarshalYamlFile", mock.Anything).Return(doc, nil)
}

func TestCopy(t *testing.T) {
c := &mockDefaultChecks{}

c.On("SetSessions", mock.Anything).Return()
c.On("GetSnapshots", mock.Anything).Return(snapshots, nil)
c.On("PostDatadogChecks", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
c.On("CleanArn", mock.Anything).Return("test")
c.On("PreSignUrl", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return("https://url.local", nil)
c.On("CopySnapshots", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)

err := copy(c, c, doc)

assert.Nil(t, err)
c.AssertExpectations(t)
}

func TestClean(t *testing.T) {
c := &mockDefaultChecks{}

c.On("SetSessions", mock.Anything).Return()
c.On("GetSnapshots", mock.Anything).Return(snapshots, nil)
c.On("PostDatadogChecks", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
c.On("CheckTag", mock.Anything, mock.Anything, mock.Anything).Return(true)
c.On("GetOldSnapshots", mock.Anything, mock.Anything).Return(snapshots, nil)
c.On("DeleteOldSnapshots", mock.Anything).Return(nil)
c.On("DeleteOldSnapshot", mock.Anything).Return(nil)

err := copy(c, c)
err := clean(c, doc)

assert.Nil(t, err)
c.AssertExpectations(t)
Expand Down

0 comments on commit 079b706

Please sign in to comment.