diff --git a/sdk/component/store/local/enum.go b/sdk/component/store/local/enum.go index 49375d6d1..f2858df14 100644 --- a/sdk/component/store/local/enum.go +++ b/sdk/component/store/local/enum.go @@ -1,6 +1,6 @@ package localstore type ( - // ENUM(findings, instance_id, updated_at) + // ENUM(id, details, instance_id, updated_at) localStoreColumnName string ) diff --git a/sdk/component/store/local/enum_enum.go b/sdk/component/store/local/enum_enum.go index 1c79e08cf..a01cb1f8b 100644 --- a/sdk/component/store/local/enum_enum.go +++ b/sdk/component/store/local/enum_enum.go @@ -12,8 +12,10 @@ import ( ) const ( - // LocalStoreColumnNameFindings is a localStoreColumnName of type findings. - LocalStoreColumnNameFindings localStoreColumnName = "findings" + // LocalStoreColumnNameId is a localStoreColumnName of type id. + LocalStoreColumnNameId localStoreColumnName = "id" + // LocalStoreColumnNameDetails is a localStoreColumnName of type details. + LocalStoreColumnNameDetails localStoreColumnName = "details" // LocalStoreColumnNameInstanceId is a localStoreColumnName of type instance_id. LocalStoreColumnNameInstanceId localStoreColumnName = "instance_id" // LocalStoreColumnNameUpdatedAt is a localStoreColumnName of type updated_at. @@ -35,7 +37,8 @@ func (x localStoreColumnName) IsValid() bool { } var _localStoreColumnNameValue = map[string]localStoreColumnName{ - "findings": LocalStoreColumnNameFindings, + "id": LocalStoreColumnNameId, + "details": LocalStoreColumnNameDetails, "instance_id": LocalStoreColumnNameInstanceId, "updated_at": LocalStoreColumnNameUpdatedAt, } diff --git a/sdk/component/store/local/store.go b/sdk/component/store/local/store.go index fe98043a2..b579b1798 100644 --- a/sdk/component/store/local/store.go +++ b/sdk/component/store/local/store.go @@ -3,9 +3,10 @@ package localstore import ( "context" "database/sql" - "encoding/json" "fmt" + "log" "os" + "strings" "time" "github.com/go-errors/errors" @@ -16,6 +17,7 @@ import ( "github.com/smithy-security/smithy/sdk/component/internal/utils" "github.com/smithy-security/smithy/sdk/component/store" "github.com/smithy-security/smithy/sdk/component/uuid" + vf "github.com/smithy-security/smithy/sdk/component/vulnerability-finding" ocsf "github.com/smithy-security/smithy/sdk/gen/ocsf_schema/v1" ) @@ -97,9 +99,9 @@ func (m *manager) Validate(*ocsf.VulnerabilityFinding) error { // Read finds Vulnerability Findings by instanceID. // It returns ErrNoFindingsFound is not vulnerabilities were found. -func (m *manager) Read(ctx context.Context, instanceID uuid.UUID) ([]*ocsf.VulnerabilityFinding, error) { +func (m *manager) Read(ctx context.Context, instanceID uuid.UUID) ([]*vf.VulnerabilityFinding, error) { stmt, err := m.db.PrepareContext(ctx, ` - SELECT (findings) + SELECT id, details FROM finding WHERE instance_id = :instance_id ; @@ -110,13 +112,11 @@ func (m *manager) Read(ctx context.Context, instanceID uuid.UUID) ([]*ocsf.Vulne defer stmt.Close() - var jsonFindingsStr string - err = stmt. - QueryRowContext( + rows, err := stmt. + QueryContext( ctx, sql.Named(LocalStoreColumnNameInstanceId.String(), instanceID.String()), - ). - Scan(&jsonFindingsStr) + ) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, errors.Errorf("%s: %w", instanceID.String(), store.ErrNoFindingsFound) @@ -124,18 +124,38 @@ func (m *manager) Read(ctx context.Context, instanceID uuid.UUID) ([]*ocsf.Vulne return nil, errors.Errorf("could not select findings: %w", err) } - var jsonFindings []json.RawMessage - if err := json.Unmarshal([]byte(jsonFindingsStr), &jsonFindings); err != nil { - return nil, errors.Errorf("could not unmarshal json findings to []json.RawMessage: %w", err) - } + var findings []*vf.VulnerabilityFinding + for rows.Next() { + var ( + findingID uint64 + jsonFindingDetails string + ) + + if err := rows.Scan(&findingID, &jsonFindingDetails); err != nil { + return nil, errors.Errorf("could not scan row: %w", err) + } - var findings []*ocsf.VulnerabilityFinding - for _, jsonFinding := range jsonFindings { var finding ocsf.VulnerabilityFinding - if err := protojson.Unmarshal(jsonFinding, &finding); err != nil { - return nil, errors.Errorf("failed to unmarshal JSON findings to *ocsf.VulnerabilityFinding: %w", err) + if err := protojson.Unmarshal([]byte(jsonFindingDetails), &finding); err != nil { + return nil, errors.Errorf( + "failed to unmarshal JSON findings to *ocsf.VulnerabilityFinding for id %d: %w", + findingID, + err, + ) } - findings = append(findings, &finding) + + findings = append(findings, &vf.VulnerabilityFinding{ + ID: findingID, + Finding: &finding, + }) + } + + if err := rows.Err(); err != nil { + return nil, errors.Errorf("could not scan rows: %w", err) + } + + if len(findings) == 0 { + return nil, store.ErrNoFindingsFound } return findings, nil @@ -143,26 +163,34 @@ func (m *manager) Read(ctx context.Context, instanceID uuid.UUID) ([]*ocsf.Vulne // Write writes new vulnerabilities in JSON format in the database. func (m *manager) Write(ctx context.Context, instanceID uuid.UUID, findings []*ocsf.VulnerabilityFinding) error { - jsonFindings, err := store.JSONMarshalFindings(findings) - if err != nil { - return err + var ( + placeHolders strings.Builder + values []any + ) + + for i, finding := range findings { + placeHolders.WriteString("(?, ?)") + if i != len(findings)-1 { + placeHolders.WriteString(",") + } + + b, err := protojson.Marshal(finding) + if err != nil { + return errors.Errorf("could not marshal JSON finding: %w", err) + } + + values = append(values, instanceID.String(), string(b)) } - stmt, err := m.db.PrepareContext(ctx, ` - INSERT INTO finding (instance_id, findings) - VALUES (:instance_id, :findings) - ; - `) + query := fmt.Sprintf("INSERT INTO finding (instance_id, details) VALUES %s;", placeHolders.String()) + stmt, err := m.db.PrepareContext(ctx, query) if err != nil { return errors.Errorf("could not prepare write statement: %w", err) } defer stmt.Close() - if _, err = stmt.Exec( - sql.Named(LocalStoreColumnNameInstanceId.String(), instanceID.String()), - sql.Named(LocalStoreColumnNameFindings.String(), jsonFindings), - ); err != nil { + if _, err = stmt.Exec(values...); err != nil { return errors.Errorf("could not insert findings: %w", err) } @@ -171,46 +199,90 @@ func (m *manager) Write(ctx context.Context, instanceID uuid.UUID, findings []*o // Update updates existing vulnerabilities in the underlying database. // It returns ErrNoFindingsFound if the passed instanceID is not found. -func (m *manager) Update(ctx context.Context, instanceID uuid.UUID, findings []*ocsf.VulnerabilityFinding) error { - jsonFindings, err := store.JSONMarshalFindings(findings) +func (m *manager) Update(ctx context.Context, instanceID uuid.UUID, findings []*vf.VulnerabilityFinding) error { + tx, err := m.db.BeginTx(ctx, nil) if err != nil { - return err + return errors.Errorf("could not start update transaction: %w", err) } - stmt, err := m.db.PrepareContext(ctx, ` + rollback := func(tx *sql.Tx, err error) error { + if txErr := tx.Rollback(); txErr != nil { + return errors.Errorf("could not rollback transaction for error %w: %w", err, txErr) + } + return errors.Errorf("unexpected update error, rolled back: %w", err) + } + + defer func() { + if err := tx.Rollback(); err != nil { + // TODO: replace with logger. + log.Printf("failed to rollback update transaction: %s", err) + } + }() + + for _, finding := range findings { + b, err := protojson.Marshal(finding.Finding) + if err != nil { + return rollback( + tx, + errors.Errorf("could not marshal JSON finding: %w", err), + ) + } + + stmt, err := tx.PrepareContext(ctx, ` UPDATE finding SET - findings = :findings, + details = :details, updated_at = :updated_at - WHERE + WHERE + id = :id AND instance_id = :instance_id ; - `) - if err != nil { - return errors.Errorf("could not prepare update statement: %w", err) - } + `) + if err != nil { + return rollback( + tx, + errors.Errorf("could not prepare update statement: %w", err), + ) + } - defer stmt.Close() + defer stmt.Close() - res, err := stmt.Exec( - sql.Named(LocalStoreColumnNameInstanceId.String(), instanceID.String()), - sql.Named(LocalStoreColumnNameUpdatedAt.String(), m.clock.Now().UTC().Format(time.RFC3339)), - sql.Named(LocalStoreColumnNameFindings.String(), jsonFindings), - ) - if err != nil { - return errors.Errorf("could not update findings: %w", err) - } - - r, err := res.RowsAffected() - switch { - case err != nil: - return errors.Errorf("could not get rows affected: %w", err) - case r <= 0: - return errors.Errorf( - "could not update findings for instance '%s': %w", - instanceID.String(), - store.ErrNoFindingsFound, + res, err := stmt.Exec( + sql.Named(LocalStoreColumnNameInstanceId.String(), instanceID.String()), + sql.Named(LocalStoreColumnNameUpdatedAt.String(), m.clock.Now().UTC().Format(time.RFC3339)), + sql.Named(LocalStoreColumnNameDetails.String(), string(b)), + sql.Named(LocalStoreColumnNameId.String(), finding.ID), ) + if err != nil { + return rollback( + tx, + errors.Errorf("could not update findings: %w", err), + ) + } + + r, err := res.RowsAffected() + switch { + case err != nil: + return rollback( + tx, + errors.Errorf( + "could not get rows affected for finding with id %d: %w", finding.ID, err), + ) + case r <= 0: + return rollback( + tx, + errors.Errorf( + "could not update findings for instance '%s' with id %d: %w", + instanceID.String(), + finding.ID, + store.ErrNoFindingsFound, + ), + ) + } + } + + if err := tx.Commit(); err != nil { + return rollback(tx, errors.Errorf("could not commit update transaction: %w", err)) } return nil @@ -233,8 +305,8 @@ func (m *manager) migrate() error { stmt, err := m.db.Prepare(` CREATE TABLE IF NOT EXISTS finding ( id INTEGER PRIMARY KEY AUTOINCREMENT UNIQUE, - instance_id UUID NOT NULL UNIQUE, - findings TEXT NOT NULL, + instance_id UUID NOT NULL, + details TEXT NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); diff --git a/sdk/component/store/local/storer_test.go b/sdk/component/store/local/storer_test.go index 6fbc75c1b..aaf15657f 100644 --- a/sdk/component/store/local/storer_test.go +++ b/sdk/component/store/local/storer_test.go @@ -15,6 +15,7 @@ import ( "github.com/smithy-security/smithy/sdk/component/store" localstore "github.com/smithy-security/smithy/sdk/component/store/local" "github.com/smithy-security/smithy/sdk/component/uuid" + vf "github.com/smithy-security/smithy/sdk/component/vulnerability-finding" ocsf "github.com/smithy-security/smithy/sdk/gen/ocsf_schema/v1" ) @@ -144,28 +145,26 @@ func (mts *ManagerTestSuite) TestManager() { ) }) - mts.t.Run("given findings for an existing instance exist, a second write will fail", func(t *testing.T) { - require.Error( - mts.t, - mts.manager.Write( - ctx, - instanceID, - findings, - ), - ) - }) - mts.t.Run("given two findings are present in the database, I should be able to retrieve them", func(t *testing.T) { resFindings, err := mts.manager.Read(ctx, instanceID) require.NoError(mts.t, err) require.Len(mts.t, resFindings, 2) - assert.EqualValues(mts.t, findings, resFindings) + + assert.Equal(t, uint64(1), resFindings[0].ID) + assert.Equal(mts.t, findings[0], resFindings[0].Finding) + assert.Equal(t, uint64(2), resFindings[1].ID) + assert.Equal(mts.t, findings[1], resFindings[1].Finding) }) mts.t.Run("given a non existing instance id in the database, updating should fail", func(t *testing.T) { require.ErrorIs( mts.t, - mts.manager.Update(ctx, uuid.New(), findings), + mts.manager.Update(ctx, uuid.New(), []*vf.VulnerabilityFinding{ + { + ID: 1, + Finding: findings[0], + }, + }), store.ErrNoFindingsFound, ) }) @@ -182,13 +181,26 @@ func (mts *ManagerTestSuite) TestManager() { require.NoError( mts.t, - mts.manager.Update(ctx, instanceID, copyFindings), + mts.manager.Update(ctx, instanceID, []*vf.VulnerabilityFinding{ + { + ID: 1, + Finding: copyFindings[0], + }, + { + ID: 2, + Finding: copyFindings[1], + }, + }), ) resFindings, err := mts.manager.Read(ctx, instanceID) require.NoError(mts.t, err) require.Len(mts.t, resFindings, 2) - assert.EqualValues(mts.t, copyFindings, resFindings) + + assert.Equal(t, uint64(1), resFindings[0].ID) + assert.Equal(mts.t, copyFindings[0], resFindings[0].Finding) + assert.Equal(t, uint64(2), resFindings[1].ID) + assert.Equal(mts.t, copyFindings[1], resFindings[1].Finding) }) }