Skip to content

Commit

Permalink
Local storage changes for updates to schema.
Browse files Browse the repository at this point in the history
# Conflicts:
#	sdk/component/store/local/store.go
#	sdk/component/store/local/storer_test.go
  • Loading branch information
andream16 committed Dec 15, 2024
1 parent 8497c61 commit fc988ae
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 98 deletions.
2 changes: 1 addition & 1 deletion sdk/component/store/local/enum.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package localstore

type (
// ENUM(findings, instance_id, updated_at)
// ENUM(id, details, instance_id, updated_at)
localStoreColumnName string
)
9 changes: 6 additions & 3 deletions sdk/component/store/local/enum_enum.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

210 changes: 132 additions & 78 deletions sdk/component/store/local/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ package localstore
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"log"
"os"
"strings"
"time"

"github.com/go-errors/errors"
Expand All @@ -16,7 +17,8 @@ import (
"github.com/smithy-security/smithy/sdk/component/internal/utils"
"github.com/smithy-security/smithy/sdk/component/store"
"github.com/smithy-security/smithy/sdk/component/uuid"
ocsf "github.com/smithy-security/smithy/sdk/gen/com/github/ocsf/ocsf_schema/v1"
vf "github.com/smithy-security/smithy/sdk/component/vulnerability-finding"
ocsf "github.com/smithy-security/smithy/sdk/gen/ocsf_schema/v1"
)

type (
Expand Down Expand Up @@ -97,9 +99,9 @@ func (m *manager) Validate(*ocsf.VulnerabilityFinding) error {

// Read finds Vulnerability Findings by instanceID.
// It returns ErrNoFindingsFound is not vulnerabilities were found.
func (m *manager) Read(ctx context.Context, instanceID uuid.UUID) ([]*ocsf.VulnerabilityFinding, error) {
func (m *manager) Read(ctx context.Context, instanceID uuid.UUID) ([]*vf.VulnerabilityFinding, error) {
stmt, err := m.db.PrepareContext(ctx, `
SELECT (findings)
SELECT id, details
FROM finding
WHERE instance_id = :instance_id
;
Expand All @@ -110,59 +112,85 @@ func (m *manager) Read(ctx context.Context, instanceID uuid.UUID) ([]*ocsf.Vulne

defer stmt.Close()

var jsonFindingsStr string
err = stmt.
QueryRowContext(
rows, err := stmt.
QueryContext(
ctx,
sql.Named(LocalStoreColumnNameInstanceId.String(), instanceID.String()),
).
Scan(&jsonFindingsStr)
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, errors.Errorf("%s: %w", instanceID.String(), store.ErrNoFindingsFound)
}
return nil, errors.Errorf("could not select findings: %w", err)
}

var jsonFindings []json.RawMessage
if err := json.Unmarshal([]byte(jsonFindingsStr), &jsonFindings); err != nil {
return nil, errors.Errorf("could not unmarshal json findings to []json.RawMessage: %w", err)
}
var findings []*vf.VulnerabilityFinding
for rows.Next() {
var (
findingID uint64
jsonFindingDetails string
)

if err := rows.Scan(&findingID, &jsonFindingDetails); err != nil {
return nil, errors.Errorf("could not scan row: %w", err)
}

var findings []*ocsf.VulnerabilityFinding
for _, jsonFinding := range jsonFindings {
var finding ocsf.VulnerabilityFinding
if err := protojson.Unmarshal(jsonFinding, &finding); err != nil {
return nil, errors.Errorf("failed to unmarshal JSON findings to *ocsf.VulnerabilityFinding: %w", err)
if err := protojson.Unmarshal([]byte(jsonFindingDetails), &finding); err != nil {
return nil, errors.Errorf(
"failed to unmarshal JSON findings to *ocsf.VulnerabilityFinding for id %d: %w",
findingID,
err,
)
}
findings = append(findings, &finding)

findings = append(findings, &vf.VulnerabilityFinding{
ID: findingID,
Finding: &finding,
})
}

if err := rows.Err(); err != nil {
return nil, errors.Errorf("could not scan rows: %w", err)
}

if len(findings) == 0 {
return nil, store.ErrNoFindingsFound
}

return findings, nil
}

// Write writes new vulnerabilities in JSON format in the database.
func (m *manager) Write(ctx context.Context, instanceID uuid.UUID, findings []*ocsf.VulnerabilityFinding) error {
jsonFindings, err := m.marshalFindings(findings)
if err != nil {
return err
var (
placeHolders strings.Builder
values []any
)

for i, finding := range findings {
placeHolders.WriteString("(?, ?)")
if i != len(findings)-1 {
placeHolders.WriteString(",")
}

b, err := protojson.Marshal(finding)
if err != nil {
return errors.Errorf("could not marshal JSON finding: %w", err)
}

values = append(values, instanceID.String(), string(b))
}

stmt, err := m.db.PrepareContext(ctx, `
INSERT INTO finding (instance_id, findings)
VALUES (:instance_id, :findings)
;
`)
query := fmt.Sprintf("INSERT INTO finding (instance_id, details) VALUES %s;", placeHolders.String())
stmt, err := m.db.PrepareContext(ctx, query)
if err != nil {
return errors.Errorf("could not prepare write statement: %w", err)
}

defer stmt.Close()

if _, err = stmt.Exec(
sql.Named(LocalStoreColumnNameInstanceId.String(), instanceID.String()),
sql.Named(LocalStoreColumnNameFindings.String(), jsonFindings),
); err != nil {
if _, err = stmt.Exec(values...); err != nil {
return errors.Errorf("could not insert findings: %w", err)
}

Expand All @@ -171,46 +199,90 @@ func (m *manager) Write(ctx context.Context, instanceID uuid.UUID, findings []*o

// Update updates existing vulnerabilities in the underlying database.
// It returns ErrNoFindingsFound if the passed instanceID is not found.
func (m *manager) Update(ctx context.Context, instanceID uuid.UUID, findings []*ocsf.VulnerabilityFinding) error {
jsonFindings, err := m.marshalFindings(findings)
func (m *manager) Update(ctx context.Context, instanceID uuid.UUID, findings []*vf.VulnerabilityFinding) error {
tx, err := m.db.BeginTx(ctx, nil)
if err != nil {
return err
return errors.Errorf("could not start update transaction: %w", err)
}

stmt, err := m.db.PrepareContext(ctx, `
rollback := func(tx *sql.Tx, err error) error {
if txErr := tx.Rollback(); txErr != nil {
return errors.Errorf("could not rollback transaction for error %w: %w", err, txErr)
}
return errors.Errorf("unexpected update error, rolled back: %w", err)
}

defer func() {
if err := tx.Rollback(); err != nil {
// TODO: replace with logger.
log.Printf("failed to rollback update transaction: %s", err)
}
}()

for _, finding := range findings {
b, err := protojson.Marshal(finding.Finding)
if err != nil {
return rollback(
tx,
errors.Errorf("could not marshal JSON finding: %w", err),
)
}

stmt, err := tx.PrepareContext(ctx, `
UPDATE finding
SET
findings = :findings,
details = :details,
updated_at = :updated_at
WHERE
WHERE
id = :id AND
instance_id = :instance_id
;
`)
if err != nil {
return errors.Errorf("could not prepare update statement: %w", err)
}
`)
if err != nil {
return rollback(
tx,
errors.Errorf("could not prepare update statement: %w", err),
)
}

defer stmt.Close()
defer stmt.Close()

res, err := stmt.Exec(
sql.Named(LocalStoreColumnNameInstanceId.String(), instanceID.String()),
sql.Named(LocalStoreColumnNameUpdatedAt.String(), m.clock.Now().UTC().Format(time.RFC3339)),
sql.Named(LocalStoreColumnNameFindings.String(), jsonFindings),
)
if err != nil {
return errors.Errorf("could not update findings: %w", err)
res, err := stmt.Exec(
sql.Named(LocalStoreColumnNameInstanceId.String(), instanceID.String()),
sql.Named(LocalStoreColumnNameUpdatedAt.String(), m.clock.Now().UTC().Format(time.RFC3339)),
sql.Named(LocalStoreColumnNameDetails.String(), string(b)),
sql.Named(LocalStoreColumnNameId.String(), finding.ID),
)
if err != nil {
return rollback(
tx,
errors.Errorf("could not update findings: %w", err),
)
}

r, err := res.RowsAffected()
switch {
case err != nil:
return rollback(
tx,
errors.Errorf(
"could not get rows affected for finding with id %d: %w", finding.ID, err),
)
case r <= 0:
return rollback(
tx,
errors.Errorf(
"could not update findings for instance '%s' with id %d: %w",
instanceID.String(),
finding.ID,
store.ErrNoFindingsFound,
),
)
}
}

r, err := res.RowsAffected()
switch {
case err != nil:
return errors.Errorf("could not get rows affected: %w", err)
case r <= 0:
return errors.Errorf(
"could not update findings for instance '%s': %w",
instanceID.String(),
store.ErrNoFindingsFound,
)
if err := tx.Commit(); err != nil {
return rollback(tx, errors.Errorf("could not commit update transaction: %w", err))
}

return nil
Expand All @@ -233,8 +305,8 @@ func (m *manager) migrate() error {
stmt, err := m.db.Prepare(`
CREATE TABLE IF NOT EXISTS finding (
id INTEGER PRIMARY KEY AUTOINCREMENT UNIQUE,
instance_id UUID NOT NULL UNIQUE,
findings TEXT NOT NULL,
instance_id UUID NOT NULL,
details TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
Expand All @@ -249,21 +321,3 @@ func (m *manager) migrate() error {

return stmt.Close()
}

func (m *manager) marshalFindings(findings []*ocsf.VulnerabilityFinding) (string, error) {
var rawFindings []json.RawMessage
for _, finding := range findings {
b, err := protojson.Marshal(finding)
if err != nil {
return "", errors.Errorf("could not json marshal finding: %w", err)
}
rawFindings = append(rawFindings, b)
}

jsonFindings, err := json.Marshal(rawFindings)
if err != nil {
return "", errors.Errorf("could not json marshal findings: %w", err)
}

return string(jsonFindings), nil
}
Loading

0 comments on commit fc988ae

Please sign in to comment.