Skip to content

Commit

Permalink
Implementing psql store.
Browse files Browse the repository at this point in the history
  • Loading branch information
andream16 committed Dec 15, 2024
1 parent 42b01f9 commit 5c1784b
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 36 deletions.
93 changes: 67 additions & 26 deletions sdk/component/store/remote/postgresql/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"log"

"github.com/go-errors/errors"
"github.com/jackc/pgx/v5"
Expand All @@ -16,6 +17,7 @@ import (
"github.com/smithy-security/smithy/sdk/component/store/remote/postgresql/sqlc"
_ "github.com/smithy-security/smithy/sdk/component/store/remote/postgresql/sqlc/migrations"
"github.com/smithy-security/smithy/sdk/component/uuid"
vf "github.com/smithy-security/smithy/sdk/component/vulnerability-finding"
ocsf "github.com/smithy-security/smithy/sdk/gen/ocsf_schema/v1"
)

Expand Down Expand Up @@ -99,20 +101,19 @@ func (m *manager) Validate(finding *ocsf.VulnerabilityFinding) error {

// Read finds Vulnerability Findings by instanceID.
// It returns ErrNoFindingsFound is not vulnerabilities were found.
func (m *manager) Read(ctx context.Context, instanceID uuid.UUID) ([]*ocsf.VulnerabilityFinding, error) {
rawFindings, err := m.queries.FindingsByID(ctx, m.newPgUUID(instanceID))
func (m *manager) Read(ctx context.Context, instanceID uuid.UUID) ([]*vf.VulnerabilityFinding, error) {
rows, err := m.queries.FindingsByID(ctx, m.newPgUUID(instanceID))
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, store.ErrNoFindingsFound
}
return nil, errors.Errorf("failed to read findings: %w", err)
}

var findings = make([]*ocsf.VulnerabilityFinding, 0, len(rawFindings))

for _, rawFinding := range rawFindings {
var findings = make([]*vf.VulnerabilityFinding, 0, len(rows))
for _, row := range rows {
var jsonFinding json.RawMessage
if err := json.Unmarshal(rawFinding.Findings, &jsonFinding); err != nil {
if err := json.Unmarshal(row.Details, &jsonFinding); err != nil {
return nil, errors.Errorf("could not unmarshal json findings to json.RawMessage: %w", err)
}

Expand All @@ -121,46 +122,86 @@ func (m *manager) Read(ctx context.Context, instanceID uuid.UUID) ([]*ocsf.Vulne
return nil, errors.Errorf("failed to unmarshal JSON findings to *ocsf.VulnerabilityFinding: %w", err)
}

findings = append(findings, &finding)
findings = append(findings, &vf.VulnerabilityFinding{
ID: uint64(row.ID),
Finding: &finding,
})
}

if len(findings) == 0 {
return nil, store.ErrNoFindingsFound
}

return findings, nil
}

// Write writes new vulnerabilities in JSON format in the database.
func (m *manager) Write(ctx context.Context, instanceID uuid.UUID, findings []*ocsf.VulnerabilityFinding) error {
jsonFindings, err := store.JSONMarshalFindings(findings)
if err != nil {
return err
var createFindingsReq = sqlc.CreateFindingsParams{
DetailsArray: make([][]byte, 0, len(findings)),
}

if _, err := m.queries.CreateFinding(ctx, sqlc.CreateFindingParams{
InstanceID: m.newPgUUID(instanceID),
Findings: []byte(jsonFindings),
}); err != nil {
for _, finding := range findings {
jsonFinding, err := protojson.Marshal(finding)
if err != nil {
return errors.Errorf("could not json marshal finding: %w", err)
}
createFindingsReq.InstanceIDArray = append(createFindingsReq.InstanceIDArray, m.newPgUUID(instanceID))
createFindingsReq.DetailsArray = append(createFindingsReq.DetailsArray, jsonFinding)
}

if err := m.queries.CreateFindings(ctx, createFindingsReq); err != nil {
return errors.Errorf("failed to write findings: %w", err)
}

return nil
}

// Update updates existing vulnerabilities in the underlying database.
// It returns ErrNoFindingsFound if the passed instanceID is not found.
func (m *manager) Update(ctx context.Context, instanceID uuid.UUID, findings []*ocsf.VulnerabilityFinding) error {
jsonFindings, err := store.JSONMarshalFindings(findings)
func (m *manager) Update(ctx context.Context, instanceID uuid.UUID, findings []*vf.VulnerabilityFinding) error {
tx, err := m.conn.Begin(ctx)
if err != nil {
return err
return errors.Errorf("failed to begin update transaction: %w", err)
}

if _, err := m.queries.UpdateFinding(ctx, sqlc.UpdateFindingParams{
Findings: []byte(jsonFindings),
UpdatedAt: pgtype.Timestamp{Time: m.clock.Now()},
InstanceID: m.newPgUUID(instanceID),
}); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return store.ErrNoFindingsFound
rollback := func(tx pgx.Tx, err error) error {
if txErr := tx.Rollback(ctx); txErr != nil {
return errors.Errorf("failed to rollback transaction for error %w: %w", err, txErr)
}
return errors.Errorf("failed to update findings: %w", err)
return errors.Errorf("rolledback transaction for error: %w", err)
}

defer func() {
if err := tx.Rollback(ctx); err != nil && !errors.Is(err, pgx.ErrTxClosed) {
// TODO: replace with logger.
log.Printf("failed to rollback update transaction: %s", err.Error())
}
}()

for _, finding := range findings {
jsonFinding, err := protojson.Marshal(finding.Finding)
if err != nil {
return rollback(tx, errors.Errorf("could not json marshal finding: %w", err))
}

if err := m.queries.WithTx(tx).UpdateFinding(ctx, sqlc.UpdateFindingParams{
InstanceID: m.newPgUUID(instanceID),
ID: int32(finding.ID),
Details: jsonFinding,
}); err != nil {
return rollback(
tx,
errors.Errorf(
"could not update finding %d: %w",
finding.ID,
err,
),
)
}
}

if err := tx.Commit(ctx); err != nil {
return rollback(tx, errors.Errorf("failed to commit update transaction: %w", err))
}

return nil
Expand Down
30 changes: 20 additions & 10 deletions sdk/component/store/remote/postgresql/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/smithy-security/smithy/sdk/component/store"
"github.com/smithy-security/smithy/sdk/component/store/remote/postgresql"
"github.com/smithy-security/smithy/sdk/component/uuid"
vf "github.com/smithy-security/smithy/sdk/component/vulnerability-finding"
ocsf "github.com/smithy-security/smithy/sdk/gen/ocsf_schema/v1"
)

Expand Down Expand Up @@ -265,15 +266,11 @@ func (suite *ManagerTestSuite) TestManager() {
resFindings, err := suite.manager.Read(ctx, instanceID)
require.NoError(t, err)
require.Len(t, resFindings, 2)
assert.EqualValues(t, findings, resFindings)
})

suite.T().Run("given a non existing instance id in the database, updating should fail", func(t *testing.T) {
require.ErrorIs(
t,
suite.manager.Update(ctx, uuid.New(), findings),
store.ErrNoFindingsFound,
)
assert.Equal(t, uint64(1), resFindings[0].ID)
assert.Equal(t, findings[0], resFindings[0].Finding)
assert.Equal(t, uint64(2), resFindings[1].ID)
assert.Equal(t, findings[1], resFindings[1].Finding)
})

suite.T().Run(
Expand All @@ -288,13 +285,26 @@ func (suite *ManagerTestSuite) TestManager() {

require.NoError(
t,
suite.manager.Update(ctx, instanceID, copyFindings),
suite.manager.Update(ctx, instanceID, []*vf.VulnerabilityFinding{
{
ID: 1,
Finding: copyFindings[0],
},
{
ID: 2,
Finding: copyFindings[1],
},
}),
)

resFindings, err := suite.manager.Read(ctx, instanceID)
require.NoError(t, err)
require.Len(t, resFindings, 2)
assert.EqualValues(t, copyFindings, resFindings)

assert.Equal(t, uint64(1), resFindings[0].ID)
assert.Equal(t, copyFindings[0], resFindings[0].Finding)
assert.Equal(t, uint64(2), resFindings[1].ID)
assert.Equal(t, copyFindings[1], resFindings[1].Finding)
})
}

Expand Down

0 comments on commit 5c1784b

Please sign in to comment.