diff --git a/sdk/README.md b/sdk/README.md index 93619511e..97db06146 100644 --- a/sdk/README.md +++ b/sdk/README.md @@ -22,11 +22,12 @@ while taking care of the boring things for you: You can customise a component using the following environment variables: -| Environment Variable | Type | Required | Possible Values | -|----------------------------|--------|----------|--------------------------| -| SMITHY\_COMPONENT\_NAME | string | yes | - | -| SMITHY\_BACKEND\_STORE\_TYPE | string | yes | local, test, \*remote | -| SMITHY\_LOG\_LEVEL | string | false | info, debug, warn, error | +| Environment Variable | Type | Required | Default | Possible Values | +|-------------------------------------|--------|-----------|--------------------------|--------------------------| +| SMITHY\_COMPONENT\_NAME | string | yes | - | - | +| SMITHY\_BACKEND\_STORE\_TYPE | string | yes | - | local, test, \*remote | +| SSMITHY\_BACKEND\_STORE\_DSN | string | no | smithy.db | \* | +| SMITHY\_LOG\_LEVEL | string | false | info, debug, warn, error | For `local` development, an `SQLite` Backend Store Type will be used. diff --git a/sdk/component/component.go b/sdk/component/component.go index 5b8d61aac..398650b46 100644 --- a/sdk/component/component.go +++ b/sdk/component/component.go @@ -18,19 +18,19 @@ type ( // Reader allows reading vulnerability findings from a storage. Reader interface { // Read reads vulnerability findings from a storage. - Read(ctx context.Context, workflowID uuid.UUID) ([]*ocsf.VulnerabilityFinding, error) + Read(ctx context.Context, instanceID uuid.UUID) ([]*ocsf.VulnerabilityFinding, error) } // Updater allows updating vulnerability findings in an underlying storage. Updater interface { // Update updates existing vulnerability findings. - Update(ctx context.Context, workflowID uuid.UUID, findings []*ocsf.VulnerabilityFinding) error + Update(ctx context.Context, instanceID uuid.UUID, findings []*ocsf.VulnerabilityFinding) error } // Writer allows writing non-existent vulnerability findings in an underlying storage. Writer interface { // Write writes non-existing vulnerability findings. - Write(ctx context.Context, workflowID uuid.UUID, findings []*ocsf.VulnerabilityFinding) error + Write(ctx context.Context, instanceID uuid.UUID, findings []*ocsf.VulnerabilityFinding) error } // Closer allows to define behaviours to close component dependencies gracefully. diff --git a/sdk/component/component_test.go b/sdk/component/component_test.go index 7974bdb3a..0c251b8bf 100644 --- a/sdk/component/component_test.go +++ b/sdk/component/component_test.go @@ -25,7 +25,7 @@ type ( func (t testFilter) Read( ctx context.Context, - workflowID uuid.UUID, + instanceID uuid.UUID, ) ([]*ocsf.VulnerabilityFinding, error) { return nil, nil } @@ -43,7 +43,7 @@ func (t testFilter) Close(ctx context.Context) error { func (t testFilter) Update( ctx context.Context, - workflowID uuid.UUID, + instanceID uuid.UUID, findings []*ocsf.VulnerabilityFinding, ) error { return nil @@ -51,7 +51,7 @@ func (t testFilter) Update( func (t testReporter) Read( ctx context.Context, - workflowID uuid.UUID, + instanceID uuid.UUID, ) ([]*ocsf.VulnerabilityFinding, error) { return nil, nil } @@ -66,14 +66,14 @@ func (t testReporter) Close(ctx context.Context) error { func (t testEnricher) Read( ctx context.Context, - workflowID uuid.UUID, + instanceID uuid.UUID, ) ([]*ocsf.VulnerabilityFinding, error) { return nil, nil } func (t testEnricher) Update( ctx context.Context, - workflowID uuid.UUID, + instanceID uuid.UUID, findings []*ocsf.VulnerabilityFinding, ) error { return nil @@ -92,7 +92,7 @@ func (t testEnricher) Close(ctx context.Context) error { func (t testScanner) Write( ctx context.Context, - workflowID uuid.UUID, + instanceID uuid.UUID, findings []*ocsf.VulnerabilityFinding, ) error { return nil diff --git a/sdk/component/conf.go b/sdk/component/conf.go index 4f862474e..98799adbd 100644 --- a/sdk/component/conf.go +++ b/sdk/component/conf.go @@ -19,10 +19,14 @@ const ( errReasonCannotBeNil = "cannot be nil" // Env vars. - envVarKeyComponentName = "SMITHY_COMPONENT_NAME" - envVarKeyWorkflowID = "SMITHY_WORKFLOW_ID" + // -- BASE + envVarKeyComponentName = "SMITHY_COMPONENT_NAME" + envVarKeyInstanceID = "SMITHY_INSTANCE_ID" + // -- LOGGING envVarKeyLoggingLogLevel = "SMITHY_LOG_LEVEL" - envVarKeyBackedStoreType = "SMITHY_BACKEND_STORE_TYPE" + // -- STORE + envVarKeyBackendStoreType = "SMITHY_BACKEND_STORE_TYPE" + envVarKeyBackendStoreDSN = "SMITHY_BACKEND_STORE_DSN" ) type ( @@ -31,7 +35,7 @@ type ( RunnerConfig struct { SDKVersion string ComponentName string - WorkflowID uuid.UUID + InstanceID uuid.UUID Logging RunnerConfigLogging PanicHandler PanicHandler @@ -52,8 +56,8 @@ type ( } runnerConfigStorer struct { - enabled bool storeType storeType + dbDSN string store Storer } @@ -101,9 +105,9 @@ func (rc *RunnerConfig) isValid() error { FieldName: "component_name", Reason: errReasonCannotBeEmpty, } - case rc.WorkflowID.IsNil(): + case rc.InstanceID.IsNil(): return ErrInvalidRunnerConfig{ - FieldName: "workflow_id", + FieldName: "instance_id", Reason: errReasonCannotBeNil, } case rc.Logging.Logger == nil: @@ -116,7 +120,7 @@ func (rc *RunnerConfig) isValid() error { FieldName: "panic_handler", Reason: errReasonCannotBeNil, } - case rc.storerConfig.enabled && rc.storerConfig.store == nil: + case rc.storerConfig.store == nil: return ErrInvalidRunnerConfig{ FieldName: "store_type", Reason: errReasonCannotBeNil, @@ -154,16 +158,16 @@ func RunnerWithComponentName(name string) RunnerOption { } } -// RunnerWithWorkflowID allows customising the workflow id. -func RunnerWithWorkflowID(id uuid.UUID) RunnerOption { +// RunnerWithInstanceID allows customising the instance id. +func RunnerWithInstanceID(id uuid.UUID) RunnerOption { return func(r *runner) error { if id.IsNil() { return ErrRunnerOption{ - OptionName: "workflow id", + OptionName: "instance id", Reason: errReasonCannotBeEmpty, } } - r.config.WorkflowID = id + r.config.InstanceID = id return nil } } @@ -183,7 +187,6 @@ func RunnerWithStorer(stType string, store Storer) RunnerOption { Reason: errReasonCannotBeNil, } } - r.config.storerConfig.enabled = true r.config.storerConfig.store = store r.config.storerConfig.storeType = storeTypeLocal return nil @@ -204,14 +207,14 @@ func newRunnerConfig() (*RunnerConfig, error) { return nil, fmt.Errorf("could not lookup environment for '%s': %w", envVarKeyComponentName, err) } - workflowIDStr, err := fromEnvOrDefault(envVarKeyWorkflowID, "", withFallbackToDefaultOnError(true)) + instanceIDStr, err := fromEnvOrDefault(envVarKeyInstanceID, "", withFallbackToDefaultOnError(true)) if err != nil { - return nil, fmt.Errorf("could not lookup environment for '%s': %w", envVarKeyWorkflowID, err) + return nil, fmt.Errorf("could not lookup environment for '%s': %w", envVarKeyInstanceID, err) } - workflowID, err := uuid.Parse(workflowIDStr) + instanceID, err := uuid.Parse(instanceIDStr) if err != nil { - return nil, fmt.Errorf("could not parse workflow ID '%s': %w", workflowIDStr, err) + return nil, fmt.Errorf("could not parse instance ID '%s': %w", instanceIDStr, err) } // --- END - BASIC ENV - END --- @@ -228,42 +231,46 @@ func newRunnerConfig() (*RunnerConfig, error) { // --- END - LOGGING ENV - END --- // --- BEGIN - STORER ENV - BEGIN --- - st, err := fromEnvOrDefault(envVarKeyBackedStoreType, "", withFallbackToDefaultOnError(true)) + st, err := fromEnvOrDefault(envVarKeyBackendStoreType, "", withFallbackToDefaultOnError(true)) if err != nil { - return nil, fmt.Errorf("could not lookup environment for '%s': %w", envVarKeyBackedStoreType, err) + return nil, fmt.Errorf("could not lookup environment for '%s': %w", envVarKeyBackendStoreType, err) } - var ( - storageType = storeType(st) - store Storer = nil - storeEnabled = false - ) + conf := &RunnerConfig{ + ComponentName: componentName, + SDKVersion: sdk.Version, + InstanceID: instanceID, + Logging: RunnerConfigLogging{ + Level: RunnerConfigLoggingLevel(logLevel), + Logger: logger, + }, + PanicHandler: panicHandler, + } if st != "" { + var storageType = storeType(st) if !isAllowedStoreType(storageType) { - return nil, fmt.Errorf("invalid store type for '%s': %w", envVarKeyBackedStoreType, err) + return nil, fmt.Errorf("invalid store type for '%s': %w", envVarKeyBackendStoreType, err) } - store, err = newStorer(storageType) + + conf.storerConfig.storeType = storageType + + dbDSN, err := fromEnvOrDefault( + envVarKeyBackendStoreDSN, + "smithy.db", + withFallbackToDefaultOnError(true), + ) if err != nil { - return nil, fmt.Errorf("could not initialise store for '%s': %w", envVarKeyBackedStoreType, err) + return nil, fmt.Errorf("could not lookup environment for '%s': %w", envVarKeyBackendStoreDSN, err) + } + + conf.storerConfig.dbDSN = dbDSN + conf.storerConfig.store, err = newStorer(conf.storerConfig) + if err != nil { + return nil, fmt.Errorf("could not initialise store for '%s': %w", envVarKeyBackendStoreType, err) } - storeEnabled = true } // --- END - STORER ENV - END --- - return &RunnerConfig{ - ComponentName: componentName, - SDKVersion: sdk.Version, - WorkflowID: workflowID, - Logging: RunnerConfigLogging{ - Level: RunnerConfigLoggingLevel(logLevel), - Logger: logger, - }, - PanicHandler: panicHandler, - storerConfig: runnerConfigStorer{ - storeType: storageType, - store: store, - enabled: storeEnabled, - }, - }, nil + return conf, nil } diff --git a/sdk/component/enricher.go b/sdk/component/enricher.go index 23f7f6b25..ca1b6b6e2 100644 --- a/sdk/component/enricher.go +++ b/sdk/component/enricher.go @@ -11,7 +11,7 @@ func RunEnricher(ctx context.Context, enricher Enricher, opts ...RunnerOption) e ctx, func(ctx context.Context, cfg *RunnerConfig) error { var ( - workflowID = cfg.WorkflowID + instanceID = cfg.InstanceID logger = LoggerFromContext(ctx).With(logKeyComponentType, "enricher") store = cfg.storerConfig.store ) @@ -25,7 +25,7 @@ func RunEnricher(ctx context.Context, enricher Enricher, opts ...RunnerOption) e logger.Debug("preparing to execute enricher component...") logger.Debug("preparing to execute read step...") - findings, err := store.Read(ctx, workflowID) + findings, err := store.Read(ctx, instanceID) if err != nil { logger.With(logKeyError, err.Error()).Error("reading step failed") return fmt.Errorf("could not read: %w", err) @@ -45,7 +45,7 @@ func RunEnricher(ctx context.Context, enricher Enricher, opts ...RunnerOption) e logger.Debug("enricher step completed!") logger.Debug("preparing to execute update step...") - if err := store.Update(ctx, workflowID, enrichedFindings); err != nil { + if err := store.Update(ctx, instanceID, enrichedFindings); err != nil { logger.With(logKeyError, err.Error()).Error("updating step failed") return fmt.Errorf("could not update: %w", err) } diff --git a/sdk/component/enricher_test.go b/sdk/component/enricher_test.go index 2243cd54d..af9892478 100644 --- a/sdk/component/enricher_test.go +++ b/sdk/component/enricher_test.go @@ -19,7 +19,7 @@ import ( func runEnricherHelper( t *testing.T, ctx context.Context, - workflowID uuid.UUID, + instanceID uuid.UUID, enricher component.Enricher, store component.Storer, ) error { @@ -30,7 +30,7 @@ func runEnricherHelper( enricher, component.RunnerWithLogger(component.NewNoopLogger()), component.RunnerWithComponentName("sample-enricher"), - component.RunnerWithWorkflowID(workflowID), + component.RunnerWithInstanceID(instanceID), component.RunnerWithStorer("local", store), ) } @@ -38,7 +38,7 @@ func runEnricherHelper( func TestRunEnricher(t *testing.T) { var ( ctrl, ctx = gomock.WithContext(context.Background(), t) - workflowID = uuid.New() + instanceID = uuid.New() mockCtx = gomock.AssignableToTypeOf(ctx) mockStore = mocks.NewMockStorer(ctrl) mockEnricher = mocks.NewMockEnricher(ctrl) @@ -50,7 +50,7 @@ func TestRunEnricher(t *testing.T) { gomock.InOrder( mockStore. EXPECT(). - Read(mockCtx, workflowID). + Read(mockCtx, instanceID). Return(vulns, nil), mockEnricher. EXPECT(). @@ -58,7 +58,7 @@ func TestRunEnricher(t *testing.T) { Return(enrichedVulns, nil), mockStore. EXPECT(). - Update(mockCtx, workflowID, enrichedVulns). + Update(mockCtx, instanceID, enrichedVulns). Return(nil), mockStore. EXPECT(). @@ -66,7 +66,7 @@ func TestRunEnricher(t *testing.T) { Return(nil), ) - require.NoError(t, runEnricherHelper(t, ctx, workflowID, mockEnricher, mockStore)) + require.NoError(t, runEnricherHelper(t, ctx, instanceID, mockEnricher, mockStore)) }) t.Run("it should return early when the context is cancelled", func(t *testing.T) { @@ -75,7 +75,7 @@ func TestRunEnricher(t *testing.T) { gomock.InOrder( mockStore. EXPECT(). - Read(mockCtx, workflowID). + Read(mockCtx, instanceID). Return(vulns, nil), mockEnricher. EXPECT(). @@ -87,11 +87,11 @@ func TestRunEnricher(t *testing.T) { }), mockStore. EXPECT(). - Update(mockCtx, workflowID, enrichedVulns). + Update(mockCtx, instanceID, enrichedVulns). DoAndReturn( func( ctx context.Context, - workflowID uuid.UUID, + instanceID uuid.UUID, vulns []*ocsf.VulnerabilityFinding, ) error { <-ctx.Done() @@ -103,7 +103,7 @@ func TestRunEnricher(t *testing.T) { Return(nil), ) - require.NoError(t, runEnricherHelper(t, ctx, workflowID, mockEnricher, mockStore)) + require.NoError(t, runEnricherHelper(t, ctx, instanceID, mockEnricher, mockStore)) }) t.Run("it should return early when reading errors", func(t *testing.T) { @@ -112,7 +112,7 @@ func TestRunEnricher(t *testing.T) { gomock.InOrder( mockStore. EXPECT(). - Read(mockCtx, workflowID). + Read(mockCtx, instanceID). Return(nil, errRead), mockStore. EXPECT(). @@ -120,7 +120,7 @@ func TestRunEnricher(t *testing.T) { Return(nil), ) - err := runEnricherHelper(t, ctx, workflowID, mockEnricher, mockStore) + err := runEnricherHelper(t, ctx, instanceID, mockEnricher, mockStore) require.Error(t, err) assert.ErrorIs(t, err, errRead) }) @@ -131,7 +131,7 @@ func TestRunEnricher(t *testing.T) { gomock.InOrder( mockStore. EXPECT(). - Read(mockCtx, workflowID). + Read(mockCtx, instanceID). Return(vulns, nil), mockEnricher. EXPECT(). @@ -143,9 +143,7 @@ func TestRunEnricher(t *testing.T) { Return(nil), ) - err := runEnricherHelper(t, ctx, workflowID, mockEnricher, mockStore) - require.Error(t, err) - assert.ErrorIs(t, err, errAnnotation) + require.ErrorIs(t, runEnricherHelper(t, ctx, instanceID, mockEnricher, mockStore), errAnnotation) }) t.Run("it should return early when updating errors", func(t *testing.T) { @@ -154,7 +152,7 @@ func TestRunEnricher(t *testing.T) { gomock.InOrder( mockStore. EXPECT(). - Read(mockCtx, workflowID). + Read(mockCtx, instanceID). Return(vulns, nil), mockEnricher. EXPECT(). @@ -162,7 +160,7 @@ func TestRunEnricher(t *testing.T) { Return(enrichedVulns, nil), mockStore. EXPECT(). - Update(mockCtx, workflowID, enrichedVulns). + Update(mockCtx, instanceID, enrichedVulns). Return(errUpdate), mockStore. EXPECT(). @@ -170,9 +168,7 @@ func TestRunEnricher(t *testing.T) { Return(nil), ) - err := runEnricherHelper(t, ctx, workflowID, mockEnricher, mockStore) - require.Error(t, err) - assert.ErrorIs(t, err, errUpdate) + require.ErrorIs(t, runEnricherHelper(t, ctx, instanceID, mockEnricher, mockStore), errUpdate) }) t.Run("it should return early when a panic is detected on enriching", func(t *testing.T) { @@ -181,7 +177,7 @@ func TestRunEnricher(t *testing.T) { gomock.InOrder( mockStore. EXPECT(). - Read(mockCtx, workflowID). + Read(mockCtx, instanceID). Return(vulns, nil), mockEnricher. EXPECT(). @@ -197,8 +193,6 @@ func TestRunEnricher(t *testing.T) { Return(nil), ) - err := runEnricherHelper(t, ctx, workflowID, mockEnricher, mockStore) - require.Error(t, err) - assert.ErrorIs(t, err, errAnnotation) + require.ErrorIs(t, runEnricherHelper(t, ctx, instanceID, mockEnricher, mockStore), errAnnotation) }) } diff --git a/sdk/component/examples/enricher/.env b/sdk/component/examples/enricher/.env new file mode 100644 index 000000000..159cd0e76 --- /dev/null +++ b/sdk/component/examples/enricher/.env @@ -0,0 +1,4 @@ +SMITHY_COMPONENT_NAME=sample-enricher +SMITHY_INSTANCE_ID=8d719c1c-c569-4078-87b3-4951bd4012ee +SMITHY_LOG_LEVEL=debug +SMITHY_BACKEND_STORE_TYPE=local diff --git a/sdk/component/examples/enricher/main.go b/sdk/component/examples/enricher/main.go index c2fd5a424..4b1fad31a 100644 --- a/sdk/component/examples/enricher/main.go +++ b/sdk/component/examples/enricher/main.go @@ -6,8 +6,6 @@ import ( "time" "github.com/smithy-security/smithy/sdk/component" - "github.com/smithy-security/smithy/sdk/component/internal/storer/local" - "github.com/smithy-security/smithy/sdk/component/internal/uuid" ocsf "github.com/smithy-security/smithy/sdk/gen/com/github/ocsf/ocsf_schema/v1" ) @@ -22,18 +20,7 @@ func main() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - storageManager, err := local.NewStoreManager() - if err != nil { - log.Fatalf("failed to create storage manager: %v", err) - } - - if err := component.RunEnricher( - ctx, - sampleEnricher{}, - component.RunnerWithComponentName("sample-enricher"), - component.RunnerWithStorer("local", storageManager), - component.RunnerWithWorkflowID(uuid.New()), - ); err != nil { + if err := component.RunEnricher(ctx, sampleEnricher{}); err != nil { log.Fatalf("unexpected run error: %v", err) } } diff --git a/sdk/component/examples/filter/.env b/sdk/component/examples/filter/.env new file mode 100644 index 000000000..117569576 --- /dev/null +++ b/sdk/component/examples/filter/.env @@ -0,0 +1,4 @@ +SMITHY_COMPONENT_NAME=sample-filter +SMITHY_INSTANCE_ID=8d719c1c-c569-4078-87b3-4951bd4012ee +SMITHY_LOG_LEVEL=debug +SMITHY_BACKEND_STORE_TYPE=local diff --git a/sdk/component/examples/filter/main.go b/sdk/component/examples/filter/main.go index 7d8efb023..54d520d8f 100644 --- a/sdk/component/examples/filter/main.go +++ b/sdk/component/examples/filter/main.go @@ -6,8 +6,6 @@ import ( "time" "github.com/smithy-security/smithy/sdk/component" - "github.com/smithy-security/smithy/sdk/component/internal/storer/local" - "github.com/smithy-security/smithy/sdk/component/internal/uuid" ocsf "github.com/smithy-security/smithy/sdk/gen/com/github/ocsf/ocsf_schema/v1" ) @@ -22,18 +20,7 @@ func main() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - storageManager, err := local.NewStoreManager() - if err != nil { - log.Fatalf("failed to create storage manager: %v", err) - } - - if err := component.RunFilter( - ctx, - sampleFilter{}, - component.RunnerWithComponentName("sample-filter"), - component.RunnerWithStorer("local", storageManager), - component.RunnerWithWorkflowID(uuid.New()), - ); err != nil { + if err := component.RunFilter(ctx, sampleFilter{}); err != nil { log.Fatalf("unexpected run error: %v", err) } } diff --git a/sdk/component/examples/reporter/.env b/sdk/component/examples/reporter/.env new file mode 100644 index 000000000..4d2370a31 --- /dev/null +++ b/sdk/component/examples/reporter/.env @@ -0,0 +1,4 @@ +SMITHY_COMPONENT_NAME=sample-reporter +SMITHY_INSTANCE_ID=8d719c1c-c569-4078-87b3-4951bd4012ee +SMITHY_LOG_LEVEL=debug +SMITHY_BACKEND_STORE_TYPE=local diff --git a/sdk/component/examples/reporter/main.go b/sdk/component/examples/reporter/main.go index c9ce49c5c..cf8d271df 100644 --- a/sdk/component/examples/reporter/main.go +++ b/sdk/component/examples/reporter/main.go @@ -6,8 +6,6 @@ import ( "time" "github.com/smithy-security/smithy/sdk/component" - "github.com/smithy-security/smithy/sdk/component/internal/storer/local" - "github.com/smithy-security/smithy/sdk/component/internal/uuid" ocsf "github.com/smithy-security/smithy/sdk/gen/com/github/ocsf/ocsf_schema/v1" ) @@ -22,18 +20,7 @@ func main() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - storageManager, err := local.NewStoreManager() - if err != nil { - log.Fatalf("failed to create storage manager: %v", err) - } - - if err := component.RunReporter( - ctx, - sampleReporter{}, - component.RunnerWithComponentName("sample-reporter"), - component.RunnerWithStorer("local", storageManager), - component.RunnerWithWorkflowID(uuid.New()), - ); err != nil { + if err := component.RunReporter(ctx, sampleReporter{}); err != nil { log.Fatalf("unexpected run error: %v", err) } } diff --git a/sdk/component/examples/scanner/.env b/sdk/component/examples/scanner/.env new file mode 100644 index 000000000..6622f5936 --- /dev/null +++ b/sdk/component/examples/scanner/.env @@ -0,0 +1,4 @@ +SMITHY_COMPONENT_NAME=sample-scanner +SMITHY_INSTANCE_ID=8d719c1c-c569-4078-87b3-4951bd4012ee +SMITHY_LOG_LEVEL=debug +SMITHY_BACKEND_STORE_TYPE=local diff --git a/sdk/component/examples/scanner/main.go b/sdk/component/examples/scanner/main.go index 432e97055..408fc20f4 100644 --- a/sdk/component/examples/scanner/main.go +++ b/sdk/component/examples/scanner/main.go @@ -6,8 +6,6 @@ import ( "time" "github.com/smithy-security/smithy/sdk/component" - "github.com/smithy-security/smithy/sdk/component/internal/storer/local" - "github.com/smithy-security/smithy/sdk/component/internal/uuid" ocsf "github.com/smithy-security/smithy/sdk/gen/com/github/ocsf/ocsf_schema/v1" ) @@ -22,18 +20,7 @@ func main() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - storageManager, err := local.NewStoreManager() - if err != nil { - log.Fatalf("failed to create storage manager: %v", err) - } - - if err := component.RunScanner( - ctx, - sampleScanner{}, - component.RunnerWithComponentName("sample-scanner"), - component.RunnerWithStorer("local", storageManager), - component.RunnerWithWorkflowID(uuid.New()), - ); err != nil { + if err := component.RunScanner(ctx, sampleScanner{}); err != nil { log.Fatalf("unexpected run error: %v", err) } } diff --git a/sdk/component/examples/target/.env b/sdk/component/examples/target/.env new file mode 100644 index 000000000..f552ba404 --- /dev/null +++ b/sdk/component/examples/target/.env @@ -0,0 +1,4 @@ +SMITHY_COMPONENT_NAME=sample-target +SMITHY_INSTANCE_ID=8d719c1c-c569-4078-87b3-4951bd4012ee +SMITHY_LOG_LEVEL=debug +SMITHY_BACKEND_STORE_TYPE=local diff --git a/sdk/component/examples/target/main.go b/sdk/component/examples/target/main.go index b0fcdd414..2cfed4a27 100644 --- a/sdk/component/examples/target/main.go +++ b/sdk/component/examples/target/main.go @@ -6,7 +6,6 @@ import ( "time" "github.com/smithy-security/smithy/sdk/component" - "github.com/smithy-security/smithy/sdk/component/internal/uuid" ) type sampleTarget struct{} @@ -20,12 +19,7 @@ func main() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - if err := component.RunTarget( - ctx, - sampleTarget{}, - component.RunnerWithComponentName("sample-target"), - component.RunnerWithWorkflowID(uuid.New()), - ); err != nil { + if err := component.RunTarget(ctx, sampleTarget{}); err != nil { log.Fatalf("unexpected run error: %v", err) } } diff --git a/sdk/component/filter.go b/sdk/component/filter.go index 467468625..c99590960 100644 --- a/sdk/component/filter.go +++ b/sdk/component/filter.go @@ -11,7 +11,7 @@ func RunFilter(ctx context.Context, filter Filter, opts ...RunnerOption) error { ctx, func(ctx context.Context, cfg *RunnerConfig) error { var ( - workflowID = cfg.WorkflowID + instanceID = cfg.InstanceID logger = LoggerFromContext(ctx).With(logKeyComponentType, "filter") store = cfg.storerConfig.store ) @@ -25,7 +25,7 @@ func RunFilter(ctx context.Context, filter Filter, opts ...RunnerOption) error { logger.Debug("preparing to execute filter component...") logger.Debug("preparing to execute read step...") - findings, err := store.Read(ctx, workflowID) + findings, err := store.Read(ctx, instanceID) if err != nil { logger.With(logKeyError, err.Error()).Error("reading step failed") return fmt.Errorf("could not read: %w", err) @@ -49,7 +49,7 @@ func RunFilter(ctx context.Context, filter Filter, opts ...RunnerOption) error { logger.Debug("filter step completed!") logger.Debug("preparing to execute update step...") - if err := store.Update(ctx, workflowID, filteredFindings); err != nil { + if err := store.Update(ctx, instanceID, filteredFindings); err != nil { logger.With(logKeyError, err.Error()).Error("updating step failed") return fmt.Errorf("could not update: %w", err) } diff --git a/sdk/component/filter_test.go b/sdk/component/filter_test.go index 3d1bcfdde..2c452f0cd 100644 --- a/sdk/component/filter_test.go +++ b/sdk/component/filter_test.go @@ -5,7 +5,6 @@ import ( "errors" "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" @@ -18,7 +17,7 @@ import ( func runFilterHelper( t *testing.T, ctx context.Context, - workflowID uuid.UUID, + instanceID uuid.UUID, filter component.Filter, store component.Storer, ) error { @@ -29,7 +28,7 @@ func runFilterHelper( filter, component.RunnerWithLogger(component.NewNoopLogger()), component.RunnerWithComponentName("sample-filter"), - component.RunnerWithWorkflowID(workflowID), + component.RunnerWithInstanceID(instanceID), component.RunnerWithStorer("local", store), ) } @@ -37,7 +36,7 @@ func runFilterHelper( func TestRunFilter(t *testing.T) { var ( ctrl, ctx = gomock.WithContext(context.Background(), t) - workflowID = uuid.New() + instanceID = uuid.New() mockCtx = gomock.AssignableToTypeOf(ctx) mockStore = mocks.NewMockStorer(ctrl) mockFilter = mocks.NewMockFilter(ctrl) @@ -49,7 +48,7 @@ func TestRunFilter(t *testing.T) { gomock.InOrder( mockStore. EXPECT(). - Read(mockCtx, workflowID). + Read(mockCtx, instanceID). Return(vulns, nil), mockFilter. EXPECT(). @@ -57,7 +56,7 @@ func TestRunFilter(t *testing.T) { Return(filteredVulns, true, nil), mockStore. EXPECT(). - Update(mockCtx, workflowID, filteredVulns). + Update(mockCtx, instanceID, filteredVulns). Return(nil), mockStore. EXPECT(). @@ -65,14 +64,14 @@ func TestRunFilter(t *testing.T) { Return(nil), ) - require.NoError(t, runFilterHelper(t, ctx, workflowID, mockFilter, mockStore)) + require.NoError(t, runFilterHelper(t, ctx, instanceID, mockFilter, mockStore)) }) t.Run("it should run a filter correctly and return early as no filtering was done", func(t *testing.T) { gomock.InOrder( mockStore. EXPECT(). - Read(mockCtx, workflowID). + Read(mockCtx, instanceID). Return(vulns, nil), mockFilter. EXPECT(). @@ -84,7 +83,7 @@ func TestRunFilter(t *testing.T) { Return(nil), ) - require.NoError(t, runFilterHelper(t, ctx, workflowID, mockFilter, mockStore)) + require.NoError(t, runFilterHelper(t, ctx, instanceID, mockFilter, mockStore)) }) t.Run("it should return early when the context is cancelled", func(t *testing.T) { @@ -93,7 +92,7 @@ func TestRunFilter(t *testing.T) { gomock.InOrder( mockStore. EXPECT(). - Read(mockCtx, workflowID). + Read(mockCtx, instanceID). Return(vulns, nil), mockFilter. EXPECT(). @@ -104,11 +103,11 @@ func TestRunFilter(t *testing.T) { }), mockStore. EXPECT(). - Update(mockCtx, workflowID, filteredVulns). + Update(mockCtx, instanceID, filteredVulns). DoAndReturn( func( ctx context.Context, - workflowID uuid.UUID, + instanceID uuid.UUID, vulns []*ocsf.VulnerabilityFinding, ) error { <-ctx.Done() @@ -120,7 +119,7 @@ func TestRunFilter(t *testing.T) { Return(nil), ) - require.NoError(t, runFilterHelper(t, ctx, workflowID, mockFilter, mockStore)) + require.NoError(t, runFilterHelper(t, ctx, instanceID, mockFilter, mockStore)) }) t.Run("it should return early when reading errors", func(t *testing.T) { @@ -129,7 +128,7 @@ func TestRunFilter(t *testing.T) { gomock.InOrder( mockStore. EXPECT(). - Read(mockCtx, workflowID). + Read(mockCtx, instanceID). Return(nil, errRead), mockStore. EXPECT(). @@ -137,9 +136,7 @@ func TestRunFilter(t *testing.T) { Return(nil), ) - err := runFilterHelper(t, ctx, workflowID, mockFilter, mockStore) - require.Error(t, err) - assert.ErrorIs(t, err, errRead) + require.ErrorIs(t, runFilterHelper(t, ctx, instanceID, mockFilter, mockStore), errRead) }) t.Run("it should return early when filtering errors", func(t *testing.T) { @@ -148,7 +145,7 @@ func TestRunFilter(t *testing.T) { gomock.InOrder( mockStore. EXPECT(). - Read(mockCtx, workflowID). + Read(mockCtx, instanceID). Return(vulns, nil), mockFilter. EXPECT(). @@ -160,9 +157,7 @@ func TestRunFilter(t *testing.T) { Return(nil), ) - err := runFilterHelper(t, ctx, workflowID, mockFilter, mockStore) - require.Error(t, err) - assert.ErrorIs(t, err, errFilter) + require.ErrorIs(t, runFilterHelper(t, ctx, instanceID, mockFilter, mockStore), errFilter) }) t.Run("it should return early when updating errors", func(t *testing.T) { @@ -171,7 +166,7 @@ func TestRunFilter(t *testing.T) { gomock.InOrder( mockStore. EXPECT(). - Read(mockCtx, workflowID). + Read(mockCtx, instanceID). Return(vulns, nil), mockFilter. EXPECT(). @@ -179,7 +174,7 @@ func TestRunFilter(t *testing.T) { Return(filteredVulns, true, nil), mockStore. EXPECT(). - Update(mockCtx, workflowID, filteredVulns). + Update(mockCtx, instanceID, filteredVulns). Return(errUpdate), mockStore. EXPECT(). @@ -187,9 +182,7 @@ func TestRunFilter(t *testing.T) { Return(nil), ) - err := runFilterHelper(t, ctx, workflowID, mockFilter, mockStore) - require.Error(t, err) - assert.ErrorIs(t, err, errUpdate) + require.ErrorIs(t, runFilterHelper(t, ctx, instanceID, mockFilter, mockStore), errUpdate) }) t.Run("it should return early when a panic is detected on filtering", func(t *testing.T) { @@ -198,7 +191,7 @@ func TestRunFilter(t *testing.T) { gomock.InOrder( mockStore. EXPECT(). - Read(mockCtx, workflowID). + Read(mockCtx, instanceID). Return(vulns, nil), mockFilter. EXPECT(). @@ -213,8 +206,6 @@ func TestRunFilter(t *testing.T) { Return(nil), ) - err := runFilterHelper(t, ctx, workflowID, mockFilter, mockStore) - require.Error(t, err) - assert.ErrorIs(t, err, errFilter) + require.ErrorIs(t, runFilterHelper(t, ctx, instanceID, mockFilter, mockStore), errFilter) }) } diff --git a/sdk/component/internal/mocks/component_mock.go b/sdk/component/internal/mocks/component_mock.go index ae4cb02ff..73b27c87d 100644 --- a/sdk/component/internal/mocks/component_mock.go +++ b/sdk/component/internal/mocks/component_mock.go @@ -82,18 +82,18 @@ func (m *MockReader) EXPECT() *MockReaderMockRecorder { } // Read mocks base method. -func (m *MockReader) Read(ctx context.Context, workflowID uuid.UUID) ([]*pb.VulnerabilityFinding, error) { +func (m *MockReader) Read(ctx context.Context, instanceID uuid.UUID) ([]*pb.VulnerabilityFinding, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Read", ctx, workflowID) + ret := m.ctrl.Call(m, "Read", ctx, instanceID) ret0, _ := ret[0].([]*pb.VulnerabilityFinding) ret1, _ := ret[1].(error) return ret0, ret1 } // Read indicates an expected call of Read. -func (mr *MockReaderMockRecorder) Read(ctx, workflowID any) *gomock.Call { +func (mr *MockReaderMockRecorder) Read(ctx, instanceID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockReader)(nil).Read), ctx, workflowID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockReader)(nil).Read), ctx, instanceID) } // MockUpdater is a mock of Updater interface. @@ -121,17 +121,17 @@ func (m *MockUpdater) EXPECT() *MockUpdaterMockRecorder { } // Update mocks base method. -func (m *MockUpdater) Update(ctx context.Context, workflowID uuid.UUID, findings []*pb.VulnerabilityFinding) error { +func (m *MockUpdater) Update(ctx context.Context, instanceID uuid.UUID, findings []*pb.VulnerabilityFinding) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Update", ctx, workflowID, findings) + ret := m.ctrl.Call(m, "Update", ctx, instanceID, findings) ret0, _ := ret[0].(error) return ret0 } // Update indicates an expected call of Update. -func (mr *MockUpdaterMockRecorder) Update(ctx, workflowID, findings any) *gomock.Call { +func (mr *MockUpdaterMockRecorder) Update(ctx, instanceID, findings any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockUpdater)(nil).Update), ctx, workflowID, findings) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockUpdater)(nil).Update), ctx, instanceID, findings) } // MockWriter is a mock of Writer interface. @@ -159,17 +159,17 @@ func (m *MockWriter) EXPECT() *MockWriterMockRecorder { } // Write mocks base method. -func (m *MockWriter) Write(ctx context.Context, workflowID uuid.UUID, findings []*pb.VulnerabilityFinding) error { +func (m *MockWriter) Write(ctx context.Context, instanceID uuid.UUID, findings []*pb.VulnerabilityFinding) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Write", ctx, workflowID, findings) + ret := m.ctrl.Call(m, "Write", ctx, instanceID, findings) ret0, _ := ret[0].(error) return ret0 } // Write indicates an expected call of Write. -func (mr *MockWriterMockRecorder) Write(ctx, workflowID, findings any) *gomock.Call { +func (mr *MockWriterMockRecorder) Write(ctx, instanceID, findings any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockWriter)(nil).Write), ctx, workflowID, findings) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockWriter)(nil).Write), ctx, instanceID, findings) } // MockCloser is a mock of Closer interface. @@ -249,32 +249,32 @@ func (mr *MockStorerMockRecorder) Close(arg0 any) *gomock.Call { } // Read mocks base method. -func (m *MockStorer) Read(ctx context.Context, workflowID uuid.UUID) ([]*pb.VulnerabilityFinding, error) { +func (m *MockStorer) Read(ctx context.Context, instanceID uuid.UUID) ([]*pb.VulnerabilityFinding, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Read", ctx, workflowID) + ret := m.ctrl.Call(m, "Read", ctx, instanceID) ret0, _ := ret[0].([]*pb.VulnerabilityFinding) ret1, _ := ret[1].(error) return ret0, ret1 } // Read indicates an expected call of Read. -func (mr *MockStorerMockRecorder) Read(ctx, workflowID any) *gomock.Call { +func (mr *MockStorerMockRecorder) Read(ctx, instanceID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStorer)(nil).Read), ctx, workflowID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStorer)(nil).Read), ctx, instanceID) } // Update mocks base method. -func (m *MockStorer) Update(ctx context.Context, workflowID uuid.UUID, findings []*pb.VulnerabilityFinding) error { +func (m *MockStorer) Update(ctx context.Context, instanceID uuid.UUID, findings []*pb.VulnerabilityFinding) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Update", ctx, workflowID, findings) + ret := m.ctrl.Call(m, "Update", ctx, instanceID, findings) ret0, _ := ret[0].(error) return ret0 } // Update indicates an expected call of Update. -func (mr *MockStorerMockRecorder) Update(ctx, workflowID, findings any) *gomock.Call { +func (mr *MockStorerMockRecorder) Update(ctx, instanceID, findings any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockStorer)(nil).Update), ctx, workflowID, findings) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockStorer)(nil).Update), ctx, instanceID, findings) } // Validate mocks base method. @@ -292,17 +292,17 @@ func (mr *MockStorerMockRecorder) Validate(finding any) *gomock.Call { } // Write mocks base method. -func (m *MockStorer) Write(ctx context.Context, workflowID uuid.UUID, findings []*pb.VulnerabilityFinding) error { +func (m *MockStorer) Write(ctx context.Context, instanceID uuid.UUID, findings []*pb.VulnerabilityFinding) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Write", ctx, workflowID, findings) + ret := m.ctrl.Call(m, "Write", ctx, instanceID, findings) ret0, _ := ret[0].(error) return ret0 } // Write indicates an expected call of Write. -func (mr *MockStorerMockRecorder) Write(ctx, workflowID, findings any) *gomock.Call { +func (mr *MockStorerMockRecorder) Write(ctx, instanceID, findings any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStorer)(nil).Write), ctx, workflowID, findings) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStorer)(nil).Write), ctx, instanceID, findings) } // MockTarget is a mock of Target interface. diff --git a/sdk/component/internal/storer/local/local.go b/sdk/component/internal/storer/local/local.go deleted file mode 100644 index dcf034271..000000000 --- a/sdk/component/internal/storer/local/local.go +++ /dev/null @@ -1,52 +0,0 @@ -package local - -import ( - "context" - - "github.com/smithy-security/smithy/sdk/component/internal/uuid" - ocsf "github.com/smithy-security/smithy/sdk/gen/com/github/ocsf/ocsf_schema/v1" -) - -// storeManager is going to be the local storage manager backed by SQLite. -type storeManager struct{} - -// NewStoreManager returns a new store manager. -func NewStoreManager() (*storeManager, error) { - return &storeManager{}, nil -} - -// TODO - implement me. -func (s *storeManager) Close(ctx context.Context) error { - return nil -} - -// TODO - implement me. -func (s *storeManager) Validate(finding *ocsf.VulnerabilityFinding) error { - return nil -} - -// TODO - implement me. -func (s *storeManager) Read( - ctx context.Context, - workflowID uuid.UUID, -) ([]*ocsf.VulnerabilityFinding, error) { - return nil, nil -} - -// TODO - implement me. -func (s *storeManager) Write( - ctx context.Context, - workflowID uuid.UUID, - findings []*ocsf.VulnerabilityFinding, -) error { - return nil -} - -// TODO - implement me. -func (s *storeManager) Update( - ctx context.Context, - workflowID uuid.UUID, - findings []*ocsf.VulnerabilityFinding, -) error { - return nil -} diff --git a/sdk/component/internal/storer/local/sqlite/export_test.go b/sdk/component/internal/storer/local/sqlite/export_test.go new file mode 100644 index 000000000..a83aff129 --- /dev/null +++ b/sdk/component/internal/storer/local/sqlite/export_test.go @@ -0,0 +1,27 @@ +package sqlite + +import ( + "fmt" +) + +// CreateTable is used to create a table in testing settings. +func (m *manager) CreateTable() error { + stmt, err := m.db.Prepare(` + CREATE TABLE finding ( + id INTEGER PRIMARY KEY AUTOINCREMENT UNIQUE, + instance_id UUID NOT NULL UNIQUE, + findings TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + `) + if err != nil { + return fmt.Errorf("could not prepare statement for creating table: %w", err) + } + + if _, err := stmt.Exec(); err != nil { + return fmt.Errorf("could not create table: %w", err) + } + + return stmt.Close() +} diff --git a/sdk/component/internal/storer/local/sqlite/sqlite.go b/sdk/component/internal/storer/local/sqlite/sqlite.go index 58d1ceb1d..dc742c720 100644 --- a/sdk/component/internal/storer/local/sqlite/sqlite.go +++ b/sdk/component/internal/storer/local/sqlite/sqlite.go @@ -3,22 +3,38 @@ package sqlite import ( "context" "database/sql" + "encoding/json" + "errors" "fmt" - "path" + "time" + "github.com/jonboulle/clockwork" _ "github.com/mattn/go-sqlite3" + "google.golang.org/protobuf/encoding/protojson" + "github.com/smithy-security/smithy/sdk/component/internal/storer" + "github.com/smithy-security/smithy/sdk/component/internal/uuid" ocsf "github.com/smithy-security/smithy/sdk/gen/com/github/ocsf/ocsf_schema/v1" ) -const errInvalidConstructorEmptyReason = "cannot be empty" +const ( + errInvalidConstructorEmptyReason = "cannot be empty" + + columnNameFindings columnName = "findings" + columnNameinstanceID columnName = "instance_id" + columnNameUpdatedAt columnName = "updated_at" +) type ( manager struct { - db *sql.DB - dbName string + clock clockwork.Clock + db *sql.DB } + managerOption func(*manager) error + + columnName string + // ErrInvalidConstructor should be used for invalid manager constructor errors. ErrInvalidConstructor struct { argName string @@ -26,55 +42,201 @@ type ( } ) +// ManagerWithClock allows customising manager's clock. +func ManagerWithClock(clock clockwork.Clock) managerOption { + return func(m *manager) error { + if clock == nil { + return errors.New("cannot set clock on nil clock") + } + m.clock = clock + return nil + } +} + +func (cn columnName) String() string { + return string(cn) +} + func (e ErrInvalidConstructor) Error() string { return fmt.Sprintf("invalid argument '%s': %s", e.argName, e.reason) } // NewManager returns a new SQLite database manager. -func NewManager(dbPath string, dbName string) (*manager, error) { - switch { - case dbPath == "": - return nil, ErrInvalidConstructor{ - argName: "db path", - reason: errInvalidConstructorEmptyReason, - } - case dbName == "": +func NewManager(dsn string, opts ...managerOption) (*manager, error) { + if dsn == "" { return nil, ErrInvalidConstructor{ - argName: "db name", + argName: "db dsn", reason: errInvalidConstructorEmptyReason, } } - db, err := sql.Open("sqlite3", path.Join(dbPath, dbName)) + db, err := sql.Open("sqlite3", dsn) if err != nil { return nil, fmt.Errorf("could not open sqlite db: %w", err) } - return &manager{ - db: db, - dbName: dbName, - }, nil + mgr := &manager{ + clock: clockwork.NewRealClock(), + db: db, + } + + for _, opt := range opts { + if err := opt(mgr); err != nil { + return nil, fmt.Errorf("could not apply option: %w", err) + } + } + + return mgr, nil } -// TODO - implement me. -func (m *manager) Read(ctx context.Context) ([]*ocsf.VulnerabilityFinding, error) { - return nil, nil +// Validate. TODO - implement. +func (m *manager) Validate(*ocsf.VulnerabilityFinding) error { + return nil } -// TODO - implement me. -func (m *manager) Write(ctx context.Context, findings []*ocsf.VulnerabilityFinding) error { +// Read finds Vulnerability Findings by instanceID. +// It returns storer.ErrNoFindingsFound is not vulnerabilities were found. +func (m *manager) Read(ctx context.Context, instanceID uuid.UUID) ([]*ocsf.VulnerabilityFinding, error) { + stmt, err := m.db.PrepareContext(ctx, ` + SELECT (findings) + FROM finding + WHERE instance_id = :instance_id + ; + `) + if err != nil { + return nil, fmt.Errorf("could not prepare select statement: %w", err) + } + + defer stmt.Close() + + var jsonFindingsStr string + err = stmt. + QueryRowContext( + ctx, + sql.Named(columnNameinstanceID.String(), instanceID.String()), + ). + Scan(&jsonFindingsStr) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("%s: %w", instanceID.String(), storer.ErrNoFindingsFound) + } + return nil, fmt.Errorf("could not select findings: %w", err) + } + + var jsonFindings []json.RawMessage + if err := json.Unmarshal([]byte(jsonFindingsStr), &jsonFindings); err != nil { + return nil, fmt.Errorf("could not unmarshal json findings to []json.RawMessage: %w", err) + } + + var findings []*ocsf.VulnerabilityFinding + for _, jsonFinding := range jsonFindings { + var finding ocsf.VulnerabilityFinding + if err := protojson.Unmarshal(jsonFinding, &finding); err != nil { + return nil, fmt.Errorf("failed to unmarshal JSON findings to *ocsf.VulnerabilityFinding: %w", err) + } + findings = append(findings, &finding) + } + + return findings, nil +} + +// Write writes new vulnerabilities in JSON format in the database. +func (m *manager) Write(ctx context.Context, instanceID uuid.UUID, findings []*ocsf.VulnerabilityFinding) error { + jsonFindings, err := m.marshalFindings(findings) + if err != nil { + return err + } + + stmt, err := m.db.PrepareContext(ctx, ` + INSERT INTO finding (instance_id, findings) + VALUES (:instance_id, :findings) + ; + `) + if err != nil { + return fmt.Errorf("could not prepare write statement: %w", err) + } + + defer stmt.Close() + + if _, err = stmt.Exec( + sql.Named(columnNameinstanceID.String(), instanceID.String()), + sql.Named(columnNameFindings.String(), jsonFindings), + ); err != nil { + return fmt.Errorf("could not insert findings: %w", err) + } + return nil } -// TODO - implement me. -func (m *manager) Update(ctx context.Context, findings []*ocsf.VulnerabilityFinding) error { +// Update updates existing vulnerabilities in the underlying database. +// It returns storer.ErrNoFindingsFound if the passed instanceID is not found. +func (m *manager) Update(ctx context.Context, instanceID uuid.UUID, findings []*ocsf.VulnerabilityFinding) error { + jsonFindings, err := m.marshalFindings(findings) + if err != nil { + return err + } + + stmt, err := m.db.PrepareContext(ctx, ` + UPDATE finding + SET + findings = :findings, + updated_at = :updated_at + WHERE + instance_id = :instance_id + ; + `) + if err != nil { + return fmt.Errorf("could not prepare update statement: %w", err) + } + + defer stmt.Close() + + res, err := stmt.Exec( + sql.Named(columnNameinstanceID.String(), instanceID.String()), + sql.Named(columnNameUpdatedAt.String(), m.clock.Now().UTC().Format(time.RFC3339)), + sql.Named(columnNameFindings.String(), jsonFindings), + ) + if err != nil { + return fmt.Errorf("could not update findings: %w", err) + } + + r, err := res.RowsAffected() + switch { + case err != nil: + return fmt.Errorf("could not get rows affected: %w", err) + case r <= 0: + return fmt.Errorf( + "could not update findings for instance '%s': %w", + instanceID.String(), + storer.ErrNoFindingsFound, + ) + } + return nil } -// TODO - implement me. +// Close closes the connection to the underlying database. func (m *manager) Close(ctx context.Context) error { if err := m.db.Close(); err != nil { return fmt.Errorf("could not close sqlite db: %w", err) } return nil } + +func (m *manager) marshalFindings(findings []*ocsf.VulnerabilityFinding) (string, error) { + var rawFindings []json.RawMessage + for _, finding := range findings { + b, err := protojson.Marshal(finding) + if err != nil { + return "", fmt.Errorf("could not json marshal finding: %w", err) + } + rawFindings = append(rawFindings, b) + } + + jsonFindings, err := json.Marshal(rawFindings) + if err != nil { + return "", fmt.Errorf("could not json marshal findings: %w", err) + } + + return string(jsonFindings), nil +} diff --git a/sdk/component/internal/storer/local/sqlite/sqlite_test.go b/sdk/component/internal/storer/local/sqlite/sqlite_test.go new file mode 100644 index 000000000..bd483276b --- /dev/null +++ b/sdk/component/internal/storer/local/sqlite/sqlite_test.go @@ -0,0 +1,221 @@ +package sqlite_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/smithy-security/smithy/sdk/component" + "github.com/smithy-security/smithy/sdk/component/internal/storer" + "github.com/smithy-security/smithy/sdk/component/internal/storer/local/sqlite" + "github.com/smithy-security/smithy/sdk/component/internal/uuid" + ocsf "github.com/smithy-security/smithy/sdk/gen/com/github/ocsf/ocsf_schema/v1" +) + +const ( + dbName = "smithy.db" +) + +type ( + localStorer interface { + component.Closer + component.Reader + component.Updater + component.Writer + + CreateTable() error + } + + ManagerTestSuite struct { + suite.Suite + + t *testing.T + manager localStorer + } +) + +func (mts *ManagerTestSuite) SetupTest() { + mts.t = mts.T() + var ( + err error + clock = clockwork.NewFakeClock() + ) + + f, err := os.Create(dbName) + require.NoError(mts.t, err) + require.NoError(mts.t, f.Close()) + + mts.manager, err = sqlite.NewManager("smithy.db", sqlite.ManagerWithClock(clock)) + require.NoError(mts.t, err) + require.NoError(mts.T(), mts.manager.CreateTable()) +} + +func (mts *ManagerTestSuite) TearDownTest() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + require.NoError(mts.t, mts.manager.Close(ctx)) + require.NoError(mts.t, os.Remove(dbName)) +} + +func (mts *ManagerTestSuite) TestManager() { + var ( + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + instanceID = uuid.New() + findings = []*ocsf.VulnerabilityFinding{ + { + ActivityId: ocsf.VulnerabilityFinding_ACTIVITY_ID_CREATE, + ActivityName: ptr("Activity 1"), + CategoryName: ptr("Category A"), + CategoryUid: ocsf.VulnerabilityFinding_CATEGORY_UID_FINDINGS, + ClassName: ptr("Class A"), + ClassUid: ocsf.VulnerabilityFinding_CLASS_UID_VULNERABILITY_FINDING, + Cloud: &ocsf.Cloud{Provider: "AWS", Region: ptr("us-west-2")}, + Comment: ptr("This is a comment for finding 1."), + Confidence: ptr("High"), + ConfidenceId: ptr(ocsf.VulnerabilityFinding_CONFIDENCE_ID_HIGH), + ConfidenceScore: ptr(int32(95)), + Count: ptr(int32(1)), + Duration: ptr(int32(3600)), + EndTime: ptr(time.Now().Unix()), + EndTimeDt: timestamppb.New(time.Now()), + Enrichments: []*ocsf.Enrichment{{Type: ptr("Type1"), Value: "Value1"}}, + Message: ptr("Vulnerability finding message 1"), + Metadata: &ocsf.Metadata{ + Version: "v1.0.1", + }, + RawData: ptr(`{"foo" : "bar"}`), + Severity: ptr("Critical"), + SeverityId: ocsf.VulnerabilityFinding_SEVERITY_ID_CRITICAL, + StartTime: ptr(time.Now().Add(-time.Hour).Unix()), + StartTimeDt: timestamppb.New(time.Now().Add(-time.Hour)), + Status: ptr("Open"), + StatusCode: ptr("200"), + Time: time.Now().Unix(), + TimeDt: timestamppb.New(time.Now()), + TimezoneOffset: ptr(int32(-7)), + TypeName: ptr("Type 1"), + TypeUid: 1, + Vulnerabilities: []*ocsf.Vulnerability{{Severity: ptr("Critical")}}, + }, + { + ActivityId: ocsf.VulnerabilityFinding_ACTIVITY_ID_CREATE, + ActivityName: ptr("Activity 2"), + CategoryName: ptr("Category B"), + CategoryUid: ocsf.VulnerabilityFinding_CATEGORY_UID_FINDINGS, + ClassName: ptr("Class B"), + ClassUid: ocsf.VulnerabilityFinding_CLASS_UID_VULNERABILITY_FINDING, + Cloud: &ocsf.Cloud{Provider: "AWS", Region: ptr("us-east-2")}, + Comment: ptr("This is a comment for finding 2."), + Confidence: ptr("High"), + ConfidenceId: ptr(ocsf.VulnerabilityFinding_CONFIDENCE_ID_HIGH), + ConfidenceScore: ptr(int32(100)), + Count: ptr(int32(5)), + Duration: ptr(int32(3600)), + EndTime: ptr(time.Now().Unix()), + EndTimeDt: timestamppb.New(time.Now()), + Enrichments: []*ocsf.Enrichment{{Type: ptr("Type2"), Value: "Value2"}}, + Message: ptr("Vulnerability finding message 2"), + Metadata: &ocsf.Metadata{ + Version: "v1.0.1", + }, + RawData: ptr(`{"bar" : "baz"}`), + Severity: ptr("Critical"), + SeverityId: ocsf.VulnerabilityFinding_SEVERITY_ID_CRITICAL, + StartTime: ptr(time.Now().Add(-time.Hour).Unix()), + StartTimeDt: timestamppb.New(time.Now().Add(-time.Hour)), + Status: ptr("Closed"), + StatusCode: ptr("200"), + Time: time.Now().Unix(), + TimeDt: timestamppb.New(time.Now()), + TimezoneOffset: ptr(int32(-7)), + TypeName: ptr("Type 2"), + TypeUid: 2, + Vulnerabilities: []*ocsf.Vulnerability{{Severity: ptr("Critical")}}, + }, + } + ) + + defer cancel() + + mts.t.Run("given an empty database, when I look for findings, I get none back", func(t *testing.T) { + resFindings, err := mts.manager.Read( + ctx, + instanceID, + ) + require.ErrorIs(t, err, storer.ErrNoFindingsFound) + require.Len(mts.t, resFindings, 0) + }) + + mts.t.Run("given an empty database, I should be able to create two findings", func(t *testing.T) { + require.NoError( + mts.t, + mts.manager.Write( + ctx, + instanceID, + findings, + ), + ) + }) + + mts.t.Run("given findings for an existing instance exist, a second write will fail", func(t *testing.T) { + require.Error( + mts.t, + mts.manager.Write( + ctx, + instanceID, + findings, + ), + ) + }) + + mts.t.Run("given two findings are present in the database, I should be able to retrieve them", func(t *testing.T) { + resFindings, err := mts.manager.Read(ctx, instanceID) + require.NoError(mts.t, err) + require.Len(mts.t, resFindings, 2) + assert.EqualValues(mts.t, findings, resFindings) + }) + + mts.t.Run("given a non existing instance id in the database, updating should fail", func(t *testing.T) { + require.ErrorIs( + mts.t, + mts.manager.Update(ctx, uuid.New(), findings), + storer.ErrNoFindingsFound, + ) + }) + + mts.t.Run( + "given the previous instance id, when I change metadata in the findings, I can update them correctly", + func(t *testing.T) { + const newVersion = "v1.1.0" + + copyFindings := append([]*ocsf.VulnerabilityFinding(nil), findings...) + require.Len(mts.t, copyFindings, 2) + copyFindings[0].Metadata.Version = newVersion + copyFindings[1].Metadata.Version = newVersion + + require.NoError( + mts.t, + mts.manager.Update(ctx, instanceID, copyFindings), + ) + + resFindings, err := mts.manager.Read(ctx, instanceID) + require.NoError(mts.t, err) + require.Len(mts.t, resFindings, 2) + assert.EqualValues(mts.t, copyFindings, resFindings) + }) +} + +func TestManagerTestSuite(t *testing.T) { + suite.Run(t, new(ManagerTestSuite)) +} + +func ptr[T any](v T) *T { + return &v +} diff --git a/sdk/component/internal/storer/storer.go b/sdk/component/internal/storer/storer.go new file mode 100644 index 000000000..82bb810a3 --- /dev/null +++ b/sdk/component/internal/storer/storer.go @@ -0,0 +1,5 @@ +package storer + +import "errors" + +var ErrNoFindingsFound = errors.New("no findings found") diff --git a/sdk/component/logger.go b/sdk/component/logger.go index 5c8804a42..008cf913b 100644 --- a/sdk/component/logger.go +++ b/sdk/component/logger.go @@ -12,7 +12,7 @@ const ( logKeyPanicStackTrace = "panic_stack_trace" logKeySDKVersion = "sdk_version" logKeyComponentName = "component_name" - logKeyWorkflowID = "workflow_id" + logKeyInstanceID = "instance_id" logKeyComponentType = "component_type" logKeyNumRawFindings = "num_raw_findings" logKeyRawFinding = "raw_finding" diff --git a/sdk/component/reporter.go b/sdk/component/reporter.go index 78da2f5c0..58f001eb5 100644 --- a/sdk/component/reporter.go +++ b/sdk/component/reporter.go @@ -11,7 +11,7 @@ func RunReporter(ctx context.Context, reporter Reporter, opts ...RunnerOption) e ctx, func(ctx context.Context, cfg *RunnerConfig) error { var ( - workflowID = cfg.WorkflowID + instanceID = cfg.InstanceID logger = LoggerFromContext(ctx).With(logKeyComponentType, "reporter") store = cfg.storerConfig.store ) @@ -25,7 +25,7 @@ func RunReporter(ctx context.Context, reporter Reporter, opts ...RunnerOption) e logger.Debug("preparing to execute component...") logger.Debug("preparing to execute read step...") - res, err := store.Read(ctx, workflowID) + res, err := store.Read(ctx, instanceID) if err != nil { logger. With(logKeyError, err.Error()). diff --git a/sdk/component/reporter_test.go b/sdk/component/reporter_test.go index b72c00221..8de3e7a9e 100644 --- a/sdk/component/reporter_test.go +++ b/sdk/component/reporter_test.go @@ -5,7 +5,6 @@ import ( "errors" "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" @@ -18,7 +17,7 @@ import ( func runReporterHelper( t *testing.T, ctx context.Context, - workflowID uuid.UUID, + instanceID uuid.UUID, reporter component.Reporter, store component.Storer, ) error { @@ -29,7 +28,7 @@ func runReporterHelper( reporter, component.RunnerWithLogger(component.NewNoopLogger()), component.RunnerWithComponentName("sample-reporter"), - component.RunnerWithWorkflowID(workflowID), + component.RunnerWithInstanceID(instanceID), component.RunnerWithStorer("local", store), ) } @@ -37,7 +36,7 @@ func runReporterHelper( func TestRunReporter(t *testing.T) { var ( ctrl, ctx = gomock.WithContext(context.Background(), t) - workflowID = uuid.New() + instanceID = uuid.New() mockCtx = gomock.AssignableToTypeOf(ctx) mockStore = mocks.NewMockStorer(ctrl) mockReporter = mocks.NewMockReporter(ctrl) @@ -48,7 +47,7 @@ func TestRunReporter(t *testing.T) { gomock.InOrder( mockStore. EXPECT(). - Read(mockCtx, workflowID). + Read(mockCtx, instanceID). Return(vulns, nil), mockReporter. EXPECT(). @@ -60,7 +59,7 @@ func TestRunReporter(t *testing.T) { Return(nil), ) - require.NoError(t, runReporterHelper(t, ctx, workflowID, mockReporter, mockStore)) + require.NoError(t, runReporterHelper(t, ctx, instanceID, mockReporter, mockStore)) }) t.Run("it should return early when the context is cancelled", func(t *testing.T) { @@ -69,8 +68,8 @@ func TestRunReporter(t *testing.T) { gomock.InOrder( mockStore. EXPECT(). - Read(mockCtx, workflowID). - DoAndReturn(func(ctx context.Context, workflowID uuid.UUID) ([]*ocsf.VulnerabilityFinding, error) { + Read(mockCtx, instanceID). + DoAndReturn(func(ctx context.Context, instanceID uuid.UUID) ([]*ocsf.VulnerabilityFinding, error) { cancel() return vulns, nil }), @@ -87,7 +86,7 @@ func TestRunReporter(t *testing.T) { Return(nil), ) - require.NoError(t, runReporterHelper(t, ctx, workflowID, mockReporter, mockStore)) + require.NoError(t, runReporterHelper(t, ctx, instanceID, mockReporter, mockStore)) }) t.Run("it should return early when reading errors", func(t *testing.T) { @@ -96,7 +95,7 @@ func TestRunReporter(t *testing.T) { gomock.InOrder( mockStore. EXPECT(). - Read(mockCtx, workflowID). + Read(mockCtx, instanceID). Return(nil, errRead), mockStore. EXPECT(). @@ -104,9 +103,7 @@ func TestRunReporter(t *testing.T) { Return(nil), ) - err := runReporterHelper(t, ctx, workflowID, mockReporter, mockStore) - require.Error(t, err) - assert.ErrorIs(t, err, errRead) + require.ErrorIs(t, runReporterHelper(t, ctx, instanceID, mockReporter, mockStore), errRead) }) t.Run("it should return early when reporting errors", func(t *testing.T) { @@ -115,7 +112,7 @@ func TestRunReporter(t *testing.T) { gomock.InOrder( mockStore. EXPECT(). - Read(mockCtx, workflowID). + Read(mockCtx, instanceID). Return(vulns, nil), mockReporter. EXPECT(). @@ -127,9 +124,7 @@ func TestRunReporter(t *testing.T) { Return(nil), ) - err := runReporterHelper(t, ctx, workflowID, mockReporter, mockStore) - require.Error(t, err) - assert.ErrorIs(t, err, errReporting) + require.ErrorIs(t, runReporterHelper(t, ctx, instanceID, mockReporter, mockStore), errReporting) }) t.Run("it should return early when a panic is detected on reporting", func(t *testing.T) { @@ -138,7 +133,7 @@ func TestRunReporter(t *testing.T) { gomock.InOrder( mockStore. EXPECT(). - Read(mockCtx, workflowID). + Read(mockCtx, instanceID). Return(vulns, nil), mockReporter. EXPECT(). @@ -153,8 +148,6 @@ func TestRunReporter(t *testing.T) { Return(nil), ) - err := runReporterHelper(t, ctx, workflowID, mockReporter, mockStore) - require.Error(t, err) - assert.ErrorIs(t, err, errReporting) + require.ErrorIs(t, runReporterHelper(t, ctx, instanceID, mockReporter, mockStore), errReporting) }) } diff --git a/sdk/component/runner.go b/sdk/component/runner.go index 1c7029690..39d50d9b4 100644 --- a/sdk/component/runner.go +++ b/sdk/component/runner.go @@ -69,7 +69,7 @@ func run( logger = r. config.Logging.Logger. With(logKeySDKVersion, conf.SDKVersion). - With(logKeyWorkflowID, conf.WorkflowID.String()). + With(logKeyInstanceID, conf.InstanceID.String()). With(logKeyComponentName, conf.ComponentName) syncErrs = make(chan error, 1) ) diff --git a/sdk/component/scanner.go b/sdk/component/scanner.go index 6834676fa..f85d34998 100644 --- a/sdk/component/scanner.go +++ b/sdk/component/scanner.go @@ -11,7 +11,7 @@ func RunScanner(ctx context.Context, scanner Scanner, opts ...RunnerOption) erro ctx, func(ctx context.Context, cfg *RunnerConfig) error { var ( - workflowID = cfg.WorkflowID + instanceID = cfg.InstanceID logger = LoggerFromContext(ctx).With(logKeyComponentType, "scanner") store = cfg.storerConfig.store ) @@ -51,7 +51,7 @@ func RunScanner(ctx context.Context, scanner Scanner, opts ...RunnerOption) erro logger.Debug("validate step completed!") logger.Debug("preparing to execute store step...") - if err := store.Write(ctx, workflowID, rawFindings); err != nil { + if err := store.Write(ctx, instanceID, rawFindings); err != nil { logger. With(logKeyError, err.Error()). Debug("could not execute store step") diff --git a/sdk/component/scanner_test.go b/sdk/component/scanner_test.go index 491137217..a40f29fb8 100644 --- a/sdk/component/scanner_test.go +++ b/sdk/component/scanner_test.go @@ -5,7 +5,6 @@ import ( "errors" "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" @@ -18,7 +17,7 @@ import ( func runScannerHelper( t *testing.T, ctx context.Context, - workflowID uuid.UUID, + instanceID uuid.UUID, reporter component.Scanner, storer component.Storer, ) error { @@ -29,7 +28,7 @@ func runScannerHelper( reporter, component.RunnerWithLogger(component.NewNoopLogger()), component.RunnerWithComponentName("sample-scanner"), - component.RunnerWithWorkflowID(workflowID), + component.RunnerWithInstanceID(instanceID), component.RunnerWithStorer("local", storer), ) } @@ -37,7 +36,7 @@ func runScannerHelper( func TestRunScanner(t *testing.T) { var ( ctrl, ctx = gomock.WithContext(context.Background(), t) - workflowID = uuid.New() + instanceID = uuid.New() mockCtx = gomock.AssignableToTypeOf(ctx) mockStore = mocks.NewMockStorer(ctrl) mockScanner = mocks.NewMockScanner(ctrl) @@ -60,7 +59,7 @@ func TestRunScanner(t *testing.T) { Return(nil), mockStore. EXPECT(). - Write(mockCtx, workflowID, vulns). + Write(mockCtx, instanceID, vulns). Return(nil), mockStore. EXPECT(). @@ -68,7 +67,7 @@ func TestRunScanner(t *testing.T) { Return(nil), ) - require.NoError(t, runScannerHelper(t, ctx, workflowID, mockScanner, mockStore)) + require.NoError(t, runScannerHelper(t, ctx, instanceID, mockScanner, mockStore)) }) t.Run("it should return early when the context is cancelled", func(t *testing.T) { @@ -92,11 +91,11 @@ func TestRunScanner(t *testing.T) { Return(nil), mockStore. EXPECT(). - Write(mockCtx, workflowID, vulns). + Write(mockCtx, instanceID, vulns). DoAndReturn( func( ctx context.Context, - workflowID uuid.UUID, + instanceID uuid.UUID, vulns []*ocsf.VulnerabilityFinding, ) error { <-ctx.Done() @@ -108,7 +107,7 @@ func TestRunScanner(t *testing.T) { Return(nil), ) - require.NoError(t, runScannerHelper(t, ctx, workflowID, mockScanner, mockStore)) + require.NoError(t, runScannerHelper(t, ctx, instanceID, mockScanner, mockStore)) }) t.Run("it should return early when transforming errors", func(t *testing.T) { @@ -125,9 +124,7 @@ func TestRunScanner(t *testing.T) { Return(nil), ) - err := runScannerHelper(t, ctx, workflowID, mockScanner, mockStore) - require.Error(t, err) - assert.ErrorIs(t, err, errTransform) + require.ErrorIs(t, runScannerHelper(t, ctx, instanceID, mockScanner, mockStore), errTransform) }) t.Run("it should return early when validation errors", func(t *testing.T) { @@ -148,9 +145,7 @@ func TestRunScanner(t *testing.T) { Return(nil), ) - err := runScannerHelper(t, ctx, workflowID, mockScanner, mockStore) - require.Error(t, err) - assert.ErrorIs(t, err, errValidate) + require.ErrorIs(t, runScannerHelper(t, ctx, instanceID, mockScanner, mockStore), errValidate) }) t.Run("it should return early when store errors", func(t *testing.T) { @@ -171,7 +166,7 @@ func TestRunScanner(t *testing.T) { Return(nil), mockStore. EXPECT(). - Write(mockCtx, workflowID, vulns). + Write(mockCtx, instanceID, vulns). Return(errStore), mockStore. EXPECT(). @@ -179,9 +174,7 @@ func TestRunScanner(t *testing.T) { Return(nil), ) - err := runScannerHelper(t, ctx, workflowID, mockScanner, mockStore) - require.Error(t, err) - assert.ErrorIs(t, err, errStore) + require.ErrorIs(t, runScannerHelper(t, ctx, instanceID, mockScanner, mockStore), errStore) }) t.Run("it should return early when a panic is detected on storing", func(t *testing.T) { @@ -202,11 +195,11 @@ func TestRunScanner(t *testing.T) { Return(nil), mockStore. EXPECT(). - Write(mockCtx, workflowID, vulns). + Write(mockCtx, instanceID, vulns). DoAndReturn( func( ctx context.Context, - workflowID uuid.UUID, + instanceID uuid.UUID, vulns []*ocsf.VulnerabilityFinding, ) error { panic(errStore) @@ -218,8 +211,6 @@ func TestRunScanner(t *testing.T) { Return(nil), ) - err := runScannerHelper(t, ctx, workflowID, mockScanner, mockStore) - require.Error(t, err) - assert.ErrorIs(t, err, errStore) + require.ErrorIs(t, runScannerHelper(t, ctx, instanceID, mockScanner, mockStore), errStore) }) } diff --git a/sdk/component/storer.go b/sdk/component/storer.go index ba2d821c4..8a4924c08 100644 --- a/sdk/component/storer.go +++ b/sdk/component/storer.go @@ -1,20 +1,26 @@ package component -import "github.com/smithy-security/smithy/sdk/component/internal/storer/local" +import ( + "fmt" + + "github.com/smithy-security/smithy/sdk/component/internal/storer/local/sqlite" +) type storeType string -const ( - storeTypeTest storeType = "test" - storeTypeLocal storeType = "local" -) +const storeTypeLocal storeType = "local" func isAllowedStoreType(st storeType) bool { return st == storeTypeLocal } -// newStore - TODO - implement in another PR. -func newStorer(storeType storeType) (Storer, error) { - localMgr, _ := local.NewStoreManager() - return localMgr, nil +func newStorer(conf runnerConfigStorer) (Storer, error) { + if conf.storeType == storeTypeLocal { + localMgr, err := sqlite.NewManager(conf.dbDSN) + if err != nil { + return nil, fmt.Errorf("unable to initialize local sqlite manager: %w", err) + } + return localMgr, nil + } + return nil, fmt.Errorf("curently unsupported store type: %s", conf.storeType) } diff --git a/sdk/component/target_test.go b/sdk/component/target_test.go index 160d2b965..b782a8a33 100644 --- a/sdk/component/target_test.go +++ b/sdk/component/target_test.go @@ -6,7 +6,6 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" @@ -16,7 +15,7 @@ import ( "github.com/smithy-security/smithy/sdk/component" ) -func runTargetHelper(t *testing.T, ctx context.Context, target component.Target) error { +func runTargetHelper(t *testing.T, ctx context.Context, target component.Target, store component.Storer) error { t.Helper() return component.RunTarget( @@ -24,7 +23,8 @@ func runTargetHelper(t *testing.T, ctx context.Context, target component.Target) target, component.RunnerWithLogger(component.NewNoopLogger()), component.RunnerWithComponentName("sample-target"), - component.RunnerWithWorkflowID(uuid.New()), + component.RunnerWithInstanceID(uuid.New()), + component.RunnerWithStorer("local", store), ) } @@ -33,6 +33,7 @@ func TestRunTarget(t *testing.T) { ctrl, ctx = gomock.WithContext(context.Background(), t) mockCtx = gomock.AssignableToTypeOf(ctx) mockTarget = mocks.NewMockTarget(ctrl) + mockStore = mocks.NewMockStorer(ctrl) ) t.Run("it should run a target correctly", func(t *testing.T) { @@ -40,7 +41,7 @@ func TestRunTarget(t *testing.T) { EXPECT(). Prepare(mockCtx). Return(nil) - require.NoError(t, runTargetHelper(t, ctx, mockTarget)) + require.NoError(t, runTargetHelper(t, ctx, mockTarget, mockStore)) }) t.Run("it should return early when the context is cancelled", func(t *testing.T) { @@ -55,7 +56,7 @@ func TestRunTarget(t *testing.T) { return nil }) - require.NoError(t, runTargetHelper(t, ctx, mockTarget)) + require.NoError(t, runTargetHelper(t, ctx, mockTarget, mockStore)) }) t.Run("it should return an error when prepare errors", func(t *testing.T) { @@ -66,9 +67,8 @@ func TestRunTarget(t *testing.T) { Prepare(mockCtx). Return(errPrepare) - err := runTargetHelper(t, ctx, mockTarget) - require.Error(t, err) - assert.ErrorIs(t, err, errPrepare) + err := runTargetHelper(t, ctx, mockTarget, mockStore) + require.ErrorIs(t, err, errPrepare) }) t.Run("it should return early an error when a panic is detected on prepare", func(t *testing.T) { @@ -82,8 +82,7 @@ func TestRunTarget(t *testing.T) { return nil }) - err := runTargetHelper(t, ctx, mockTarget) - require.Error(t, err) - assert.ErrorIs(t, err, errPrepare) + err := runTargetHelper(t, ctx, mockTarget, mockStore) + require.ErrorIs(t, err, errPrepare) }) } diff --git a/sdk/go.mod b/sdk/go.mod index 863ac098b..80e617399 100644 --- a/sdk/go.mod +++ b/sdk/go.mod @@ -4,6 +4,7 @@ go 1.23.0 require ( github.com/google/uuid v1.6.0 + github.com/jonboulle/clockwork v0.4.0 github.com/mattn/go-sqlite3 v1.14.24 github.com/stretchr/testify v1.9.0 go.uber.org/mock v0.5.0 diff --git a/sdk/go.sum b/sdk/go.sum index 39fa5d88f..39692e0c2 100644 --- a/sdk/go.sum +++ b/sdk/go.sum @@ -4,6 +4,8 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jonboulle/clockwork v0.4.0 h1:p4Cf1aMWXnXAUh8lVfewRBx1zaTSYKrKMF2g3ST4RZ4= +github.com/jonboulle/clockwork v0.4.0/go.mod h1:xgRqUGwRcjKCO1vbZUEtSLrqKoPSsUpK7fnezOII0kc= github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/sdk/vendor/github.com/jonboulle/clockwork/.editorconfig b/sdk/vendor/github.com/jonboulle/clockwork/.editorconfig new file mode 100644 index 000000000..4492e9f9f --- /dev/null +++ b/sdk/vendor/github.com/jonboulle/clockwork/.editorconfig @@ -0,0 +1,12 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +indent_size = 4 +indent_style = space +insert_final_newline = true +trim_trailing_whitespace = true + +[*.go] +indent_style = tab diff --git a/sdk/vendor/github.com/jonboulle/clockwork/.gitignore b/sdk/vendor/github.com/jonboulle/clockwork/.gitignore new file mode 100644 index 000000000..00852bd94 --- /dev/null +++ b/sdk/vendor/github.com/jonboulle/clockwork/.gitignore @@ -0,0 +1,27 @@ +/.idea/ + +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test + +*.swp diff --git a/sdk/vendor/github.com/jonboulle/clockwork/LICENSE b/sdk/vendor/github.com/jonboulle/clockwork/LICENSE new file mode 100644 index 000000000..5c304d1a4 --- /dev/null +++ b/sdk/vendor/github.com/jonboulle/clockwork/LICENSE @@ -0,0 +1,201 @@ +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/sdk/vendor/github.com/jonboulle/clockwork/README.md b/sdk/vendor/github.com/jonboulle/clockwork/README.md new file mode 100644 index 000000000..42970da80 --- /dev/null +++ b/sdk/vendor/github.com/jonboulle/clockwork/README.md @@ -0,0 +1,80 @@ +# clockwork + +[![Mentioned in Awesome Go](https://awesome.re/mentioned-badge-flat.svg)](https://github.com/avelino/awesome-go#utilities) + +[![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/jonboulle/clockwork/ci.yaml?style=flat-square)](https://github.com/jonboulle/clockwork/actions?query=workflow%3ACI) +[![Go Report Card](https://goreportcard.com/badge/github.com/jonboulle/clockwork?style=flat-square)](https://goreportcard.com/report/github.com/jonboulle/clockwork) +![Go Version](https://img.shields.io/badge/go%20version-%3E=1.15-61CFDD.svg?style=flat-square) +[![go.dev reference](https://img.shields.io/badge/go.dev-reference-007d9c?logo=go&logoColor=white&style=flat-square)](https://pkg.go.dev/mod/github.com/jonboulle/clockwork) + +**A simple fake clock for Go.** + + +## Usage + +Replace uses of the `time` package with the `clockwork.Clock` interface instead. + +For example, instead of using `time.Sleep` directly: + +```go +func myFunc() { + time.Sleep(3 * time.Second) + doSomething() +} +``` + +Inject a clock and use its `Sleep` method instead: + +```go +func myFunc(clock clockwork.Clock) { + clock.Sleep(3 * time.Second) + doSomething() +} +``` + +Now you can easily test `myFunc` with a `FakeClock`: + +```go +func TestMyFunc(t *testing.T) { + c := clockwork.NewFakeClock() + + // Start our sleepy function + var wg sync.WaitGroup + wg.Add(1) + go func() { + myFunc(c) + wg.Done() + }() + + // Ensure we wait until myFunc is sleeping + c.BlockUntil(1) + + assertState() + + // Advance the FakeClock forward in time + c.Advance(3 * time.Second) + + // Wait until the function completes + wg.Wait() + + assertState() +} +``` + +and in production builds, simply inject the real clock instead: + +```go +myFunc(clockwork.NewRealClock()) +``` + +See [example_test.go](example_test.go) for a full example. + + +# Credits + +clockwork is inspired by @wickman's [threaded fake clock](https://gist.github.com/wickman/3840816), and the [Golang playground](https://blog.golang.org/playground#TOC_3.1.) + + +## License + +Apache License, Version 2.0. Please see [License File](LICENSE) for more information. diff --git a/sdk/vendor/github.com/jonboulle/clockwork/clockwork.go b/sdk/vendor/github.com/jonboulle/clockwork/clockwork.go new file mode 100644 index 000000000..90dc9c82f --- /dev/null +++ b/sdk/vendor/github.com/jonboulle/clockwork/clockwork.go @@ -0,0 +1,349 @@ +package clockwork + +import ( + "context" + "sort" + "sync" + "time" +) + +// Clock provides an interface that packages can use instead of directly using +// the [time] module, so that chronology-related behavior can be tested. +type Clock interface { + After(d time.Duration) <-chan time.Time + Sleep(d time.Duration) + Now() time.Time + Since(t time.Time) time.Duration + NewTicker(d time.Duration) Ticker + NewTimer(d time.Duration) Timer + AfterFunc(d time.Duration, f func()) Timer +} + +// FakeClock provides an interface for a clock which can be manually advanced +// through time. +// +// FakeClock maintains a list of "waiters," which consists of all callers +// waiting on the underlying clock (i.e. Tickers and Timers including callers of +// Sleep or After). Users can call BlockUntil to block until the clock has an +// expected number of waiters. +type FakeClock interface { + Clock + // Advance advances the FakeClock to a new point in time, ensuring any existing + // waiters are notified appropriately before returning. + Advance(d time.Duration) + // BlockUntil blocks until the FakeClock has the given number of waiters. + BlockUntil(waiters int) +} + +// NewRealClock returns a Clock which simply delegates calls to the actual time +// package; it should be used by packages in production. +func NewRealClock() Clock { + return &realClock{} +} + +// NewFakeClock returns a FakeClock implementation which can be +// manually advanced through time for testing. The initial time of the +// FakeClock will be the current system time. +// +// Tests that require a deterministic time must use NewFakeClockAt. +func NewFakeClock() FakeClock { + return NewFakeClockAt(time.Now()) +} + +// NewFakeClockAt returns a FakeClock initialised at the given time.Time. +func NewFakeClockAt(t time.Time) FakeClock { + return &fakeClock{ + time: t, + } +} + +type realClock struct{} + +func (rc *realClock) After(d time.Duration) <-chan time.Time { + return time.After(d) +} + +func (rc *realClock) Sleep(d time.Duration) { + time.Sleep(d) +} + +func (rc *realClock) Now() time.Time { + return time.Now() +} + +func (rc *realClock) Since(t time.Time) time.Duration { + return rc.Now().Sub(t) +} + +func (rc *realClock) NewTicker(d time.Duration) Ticker { + return realTicker{time.NewTicker(d)} +} + +func (rc *realClock) NewTimer(d time.Duration) Timer { + return realTimer{time.NewTimer(d)} +} + +func (rc *realClock) AfterFunc(d time.Duration, f func()) Timer { + return realTimer{time.AfterFunc(d, f)} +} + +type fakeClock struct { + // l protects all attributes of the clock, including all attributes of all + // waiters and blockers. + l sync.RWMutex + waiters []expirer + blockers []*blocker + time time.Time +} + +// blocker is a caller of BlockUntil. +type blocker struct { + count int + + // ch is closed when the underlying clock has the specificed number of blockers. + ch chan struct{} +} + +// expirer is a timer or ticker that expires at some point in the future. +type expirer interface { + // expire the expirer at the given time, returning the desired duration until + // the next expiration, if any. + expire(now time.Time) (next *time.Duration) + + // Get and set the expiration time. + expiry() time.Time + setExpiry(time.Time) +} + +// After mimics [time.After]; it waits for the given duration to elapse on the +// fakeClock, then sends the current time on the returned channel. +func (fc *fakeClock) After(d time.Duration) <-chan time.Time { + return fc.NewTimer(d).Chan() +} + +// Sleep blocks until the given duration has passed on the fakeClock. +func (fc *fakeClock) Sleep(d time.Duration) { + <-fc.After(d) +} + +// Now returns the current time of the fakeClock +func (fc *fakeClock) Now() time.Time { + fc.l.RLock() + defer fc.l.RUnlock() + return fc.time +} + +// Since returns the duration that has passed since the given time on the +// fakeClock. +func (fc *fakeClock) Since(t time.Time) time.Duration { + return fc.Now().Sub(t) +} + +// NewTicker returns a Ticker that will expire only after calls to +// fakeClock.Advance() have moved the clock past the given duration. +func (fc *fakeClock) NewTicker(d time.Duration) Ticker { + var ft *fakeTicker + ft = &fakeTicker{ + firer: newFirer(), + d: d, + reset: func(d time.Duration) { fc.set(ft, d) }, + stop: func() { fc.stop(ft) }, + } + fc.set(ft, d) + return ft +} + +// NewTimer returns a Timer that will fire only after calls to +// fakeClock.Advance() have moved the clock past the given duration. +func (fc *fakeClock) NewTimer(d time.Duration) Timer { + return fc.newTimer(d, nil) +} + +// AfterFunc mimics [time.AfterFunc]; it returns a Timer that will invoke the +// given function only after calls to fakeClock.Advance() have moved the clock +// past the given duration. +func (fc *fakeClock) AfterFunc(d time.Duration, f func()) Timer { + return fc.newTimer(d, f) +} + +// newTimer returns a new timer, using an optional afterFunc. +func (fc *fakeClock) newTimer(d time.Duration, afterfunc func()) *fakeTimer { + var ft *fakeTimer + ft = &fakeTimer{ + firer: newFirer(), + reset: func(d time.Duration) bool { + fc.l.Lock() + defer fc.l.Unlock() + // fc.l must be held across the calls to stopExpirer & setExpirer. + stopped := fc.stopExpirer(ft) + fc.setExpirer(ft, d) + return stopped + }, + stop: func() bool { return fc.stop(ft) }, + + afterFunc: afterfunc, + } + fc.set(ft, d) + return ft +} + +// Advance advances fakeClock to a new point in time, ensuring waiters and +// blockers are notified appropriately before returning. +func (fc *fakeClock) Advance(d time.Duration) { + fc.l.Lock() + defer fc.l.Unlock() + end := fc.time.Add(d) + // Expire the earliest waiter until the earliest waiter's expiration is after + // end. + // + // We don't iterate because the callback of the waiter might register a new + // waiter, so the list of waiters might change as we execute this. + for len(fc.waiters) > 0 && !end.Before(fc.waiters[0].expiry()) { + w := fc.waiters[0] + fc.waiters = fc.waiters[1:] + + // Use the waiter's expriation as the current time for this expiration. + now := w.expiry() + fc.time = now + if d := w.expire(now); d != nil { + // Set the new exipration if needed. + fc.setExpirer(w, *d) + } + } + fc.time = end +} + +// BlockUntil blocks until the fakeClock has the given number of waiters. +// +// Prefer BlockUntilContext, which offers context cancellation to prevent +// deadlock. +// +// Deprecation warning: This function might be deprecated in later versions. +func (fc *fakeClock) BlockUntil(n int) { + b := fc.newBlocker(n) + if b == nil { + return + } + <-b.ch +} + +// BlockUntilContext blocks until the fakeClock has the given number of waiters +// or the context is cancelled. +func (fc *fakeClock) BlockUntilContext(ctx context.Context, n int) error { + b := fc.newBlocker(n) + if b == nil { + return nil + } + + select { + case <-b.ch: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (fc *fakeClock) newBlocker(n int) *blocker { + fc.l.Lock() + defer fc.l.Unlock() + // Fast path: we already have >= n waiters. + if len(fc.waiters) >= n { + return nil + } + // Set up a new blocker to wait for more waiters. + b := &blocker{ + count: n, + ch: make(chan struct{}), + } + fc.blockers = append(fc.blockers, b) + return b +} + +// stop stops an expirer, returning true if the expirer was stopped. +func (fc *fakeClock) stop(e expirer) bool { + fc.l.Lock() + defer fc.l.Unlock() + return fc.stopExpirer(e) +} + +// stopExpirer stops an expirer, returning true if the expirer was stopped. +// +// The caller must hold fc.l. +func (fc *fakeClock) stopExpirer(e expirer) bool { + for i, t := range fc.waiters { + if t == e { + // Remove element, maintaining order. + copy(fc.waiters[i:], fc.waiters[i+1:]) + fc.waiters[len(fc.waiters)-1] = nil + fc.waiters = fc.waiters[:len(fc.waiters)-1] + return true + } + } + return false +} + +// set sets an expirer to expire at a future point in time. +func (fc *fakeClock) set(e expirer, d time.Duration) { + fc.l.Lock() + defer fc.l.Unlock() + fc.setExpirer(e, d) +} + +// setExpirer sets an expirer to expire at a future point in time. +// +// The caller must hold fc.l. +func (fc *fakeClock) setExpirer(e expirer, d time.Duration) { + if d.Nanoseconds() <= 0 { + // special case - trigger immediately, never reset. + // + // TODO: Explain what cases this covers. + e.expire(fc.time) + return + } + // Add the expirer to the set of waiters and notify any blockers. + e.setExpiry(fc.time.Add(d)) + fc.waiters = append(fc.waiters, e) + sort.Slice(fc.waiters, func(i int, j int) bool { + return fc.waiters[i].expiry().Before(fc.waiters[j].expiry()) + }) + + // Notify blockers of our new waiter. + var blocked []*blocker + count := len(fc.waiters) + for _, b := range fc.blockers { + if b.count <= count { + close(b.ch) + continue + } + blocked = append(blocked, b) + } + fc.blockers = blocked +} + +// firer is used by fakeTimer and fakeTicker used to help implement expirer. +type firer struct { + // The channel associated with the firer, used to send expriation times. + c chan time.Time + + // The time when the firer expires. Only meaningful if the firer is currently + // one of a fakeClock's waiters. + exp time.Time +} + +func newFirer() firer { + return firer{c: make(chan time.Time, 1)} +} + +func (f *firer) Chan() <-chan time.Time { + return f.c +} + +// expiry implements expirer. +func (f *firer) expiry() time.Time { + return f.exp +} + +// setExpiry implements expirer. +func (f *firer) setExpiry(t time.Time) { + f.exp = t +} diff --git a/sdk/vendor/github.com/jonboulle/clockwork/context.go b/sdk/vendor/github.com/jonboulle/clockwork/context.go new file mode 100644 index 000000000..edbb368f0 --- /dev/null +++ b/sdk/vendor/github.com/jonboulle/clockwork/context.go @@ -0,0 +1,25 @@ +package clockwork + +import ( + "context" +) + +// contextKey is private to this package so we can ensure uniqueness here. This +// type identifies context values provided by this package. +type contextKey string + +// keyClock provides a clock for injecting during tests. If absent, a real clock should be used. +var keyClock = contextKey("clock") // clockwork.Clock + +// AddToContext creates a derived context that references the specified clock. +func AddToContext(ctx context.Context, clock Clock) context.Context { + return context.WithValue(ctx, keyClock, clock) +} + +// FromContext extracts a clock from the context. If not present, a real clock is returned. +func FromContext(ctx context.Context) Clock { + if clock, ok := ctx.Value(keyClock).(Clock); ok { + return clock + } + return NewRealClock() +} diff --git a/sdk/vendor/github.com/jonboulle/clockwork/ticker.go b/sdk/vendor/github.com/jonboulle/clockwork/ticker.go new file mode 100644 index 000000000..b68e4d777 --- /dev/null +++ b/sdk/vendor/github.com/jonboulle/clockwork/ticker.go @@ -0,0 +1,48 @@ +package clockwork + +import "time" + +// Ticker provides an interface which can be used instead of directly using +// [time.Ticker]. The real-time ticker t provides ticks through t.C which +// becomes t.Chan() to make this channel requirement definable in this +// interface. +type Ticker interface { + Chan() <-chan time.Time + Reset(d time.Duration) + Stop() +} + +type realTicker struct{ *time.Ticker } + +func (r realTicker) Chan() <-chan time.Time { + return r.C +} + +type fakeTicker struct { + firer + + // reset and stop provide the implementation of the respective exported + // functions. + reset func(d time.Duration) + stop func() + + // The duration of the ticker. + d time.Duration +} + +func (f *fakeTicker) Reset(d time.Duration) { + f.reset(d) +} + +func (f *fakeTicker) Stop() { + f.stop() +} + +func (f *fakeTicker) expire(now time.Time) *time.Duration { + // Never block on expiration. + select { + case f.c <- now: + default: + } + return &f.d +} diff --git a/sdk/vendor/github.com/jonboulle/clockwork/timer.go b/sdk/vendor/github.com/jonboulle/clockwork/timer.go new file mode 100644 index 000000000..6f928b3dd --- /dev/null +++ b/sdk/vendor/github.com/jonboulle/clockwork/timer.go @@ -0,0 +1,53 @@ +package clockwork + +import "time" + +// Timer provides an interface which can be used instead of directly using +// [time.Timer]. The real-time timer t provides events through t.C which becomes +// t.Chan() to make this channel requirement definable in this interface. +type Timer interface { + Chan() <-chan time.Time + Reset(d time.Duration) bool + Stop() bool +} + +type realTimer struct{ *time.Timer } + +func (r realTimer) Chan() <-chan time.Time { + return r.C +} + +type fakeTimer struct { + firer + + // reset and stop provide the implmenetation of the respective exported + // functions. + reset func(d time.Duration) bool + stop func() bool + + // If present when the timer fires, the timer calls afterFunc in its own + // goroutine rather than sending the time on Chan(). + afterFunc func() +} + +func (f *fakeTimer) Reset(d time.Duration) bool { + return f.reset(d) +} + +func (f *fakeTimer) Stop() bool { + return f.stop() +} + +func (f *fakeTimer) expire(now time.Time) *time.Duration { + if f.afterFunc != nil { + go f.afterFunc() + return nil + } + + // Never block on expiration. + select { + case f.c <- now: + default: + } + return nil +} diff --git a/sdk/vendor/github.com/stretchr/testify/suite/doc.go b/sdk/vendor/github.com/stretchr/testify/suite/doc.go new file mode 100644 index 000000000..8d55a3aa8 --- /dev/null +++ b/sdk/vendor/github.com/stretchr/testify/suite/doc.go @@ -0,0 +1,66 @@ +// Package suite contains logic for creating testing suite structs +// and running the methods on those structs as tests. The most useful +// piece of this package is that you can create setup/teardown methods +// on your testing suites, which will run before/after the whole suite +// or individual tests (depending on which interface(s) you +// implement). +// +// A testing suite is usually built by first extending the built-in +// suite functionality from suite.Suite in testify. Alternatively, +// you could reproduce that logic on your own if you wanted (you +// just need to implement the TestingSuite interface from +// suite/interfaces.go). +// +// After that, you can implement any of the interfaces in +// suite/interfaces.go to add setup/teardown functionality to your +// suite, and add any methods that start with "Test" to add tests. +// Methods that do not match any suite interfaces and do not begin +// with "Test" will not be run by testify, and can safely be used as +// helper methods. +// +// Once you've built your testing suite, you need to run the suite +// (using suite.Run from testify) inside any function that matches the +// identity that "go test" is already looking for (i.e. +// func(*testing.T)). +// +// Regular expression to select test suites specified command-line +// argument "-run". Regular expression to select the methods +// of test suites specified command-line argument "-m". +// Suite object has assertion methods. +// +// A crude example: +// +// // Basic imports +// import ( +// "testing" +// "github.com/stretchr/testify/assert" +// "github.com/stretchr/testify/suite" +// ) +// +// // Define the suite, and absorb the built-in basic suite +// // functionality from testify - including a T() method which +// // returns the current testing context +// type ExampleTestSuite struct { +// suite.Suite +// VariableThatShouldStartAtFive int +// } +// +// // Make sure that VariableThatShouldStartAtFive is set to five +// // before each test +// func (suite *ExampleTestSuite) SetupTest() { +// suite.VariableThatShouldStartAtFive = 5 +// } +// +// // All methods that begin with "Test" are run as tests within a +// // suite. +// func (suite *ExampleTestSuite) TestExample() { +// assert.Equal(suite.T(), 5, suite.VariableThatShouldStartAtFive) +// suite.Equal(5, suite.VariableThatShouldStartAtFive) +// } +// +// // In order for 'go test' to run this suite, we need to create +// // a normal test function and pass our suite to suite.Run +// func TestExampleTestSuite(t *testing.T) { +// suite.Run(t, new(ExampleTestSuite)) +// } +package suite diff --git a/sdk/vendor/github.com/stretchr/testify/suite/interfaces.go b/sdk/vendor/github.com/stretchr/testify/suite/interfaces.go new file mode 100644 index 000000000..fed037d7f --- /dev/null +++ b/sdk/vendor/github.com/stretchr/testify/suite/interfaces.go @@ -0,0 +1,66 @@ +package suite + +import "testing" + +// TestingSuite can store and return the current *testing.T context +// generated by 'go test'. +type TestingSuite interface { + T() *testing.T + SetT(*testing.T) + SetS(suite TestingSuite) +} + +// SetupAllSuite has a SetupSuite method, which will run before the +// tests in the suite are run. +type SetupAllSuite interface { + SetupSuite() +} + +// SetupTestSuite has a SetupTest method, which will run before each +// test in the suite. +type SetupTestSuite interface { + SetupTest() +} + +// TearDownAllSuite has a TearDownSuite method, which will run after +// all the tests in the suite have been run. +type TearDownAllSuite interface { + TearDownSuite() +} + +// TearDownTestSuite has a TearDownTest method, which will run after +// each test in the suite. +type TearDownTestSuite interface { + TearDownTest() +} + +// BeforeTest has a function to be executed right before the test +// starts and receives the suite and test names as input +type BeforeTest interface { + BeforeTest(suiteName, testName string) +} + +// AfterTest has a function to be executed right after the test +// finishes and receives the suite and test names as input +type AfterTest interface { + AfterTest(suiteName, testName string) +} + +// WithStats implements HandleStats, a function that will be executed +// when a test suite is finished. The stats contain information about +// the execution of that suite and its tests. +type WithStats interface { + HandleStats(suiteName string, stats *SuiteInformation) +} + +// SetupSubTest has a SetupSubTest method, which will run before each +// subtest in the suite. +type SetupSubTest interface { + SetupSubTest() +} + +// TearDownSubTest has a TearDownSubTest method, which will run after +// each subtest in the suite have been run. +type TearDownSubTest interface { + TearDownSubTest() +} diff --git a/sdk/vendor/github.com/stretchr/testify/suite/stats.go b/sdk/vendor/github.com/stretchr/testify/suite/stats.go new file mode 100644 index 000000000..261da37f7 --- /dev/null +++ b/sdk/vendor/github.com/stretchr/testify/suite/stats.go @@ -0,0 +1,46 @@ +package suite + +import "time" + +// SuiteInformation stats stores stats for the whole suite execution. +type SuiteInformation struct { + Start, End time.Time + TestStats map[string]*TestInformation +} + +// TestInformation stores information about the execution of each test. +type TestInformation struct { + TestName string + Start, End time.Time + Passed bool +} + +func newSuiteInformation() *SuiteInformation { + testStats := make(map[string]*TestInformation) + + return &SuiteInformation{ + TestStats: testStats, + } +} + +func (s SuiteInformation) start(testName string) { + s.TestStats[testName] = &TestInformation{ + TestName: testName, + Start: time.Now(), + } +} + +func (s SuiteInformation) end(testName string, passed bool) { + s.TestStats[testName].End = time.Now() + s.TestStats[testName].Passed = passed +} + +func (s SuiteInformation) Passed() bool { + for _, stats := range s.TestStats { + if !stats.Passed { + return false + } + } + + return true +} diff --git a/sdk/vendor/github.com/stretchr/testify/suite/suite.go b/sdk/vendor/github.com/stretchr/testify/suite/suite.go new file mode 100644 index 000000000..18443a91c --- /dev/null +++ b/sdk/vendor/github.com/stretchr/testify/suite/suite.go @@ -0,0 +1,253 @@ +package suite + +import ( + "flag" + "fmt" + "os" + "reflect" + "regexp" + "runtime/debug" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var allTestsFilter = func(_, _ string) (bool, error) { return true, nil } +var matchMethod = flag.String("testify.m", "", "regular expression to select tests of the testify suite to run") + +// Suite is a basic testing suite with methods for storing and +// retrieving the current *testing.T context. +type Suite struct { + *assert.Assertions + + mu sync.RWMutex + require *require.Assertions + t *testing.T + + // Parent suite to have access to the implemented methods of parent struct + s TestingSuite +} + +// T retrieves the current *testing.T context. +func (suite *Suite) T() *testing.T { + suite.mu.RLock() + defer suite.mu.RUnlock() + return suite.t +} + +// SetT sets the current *testing.T context. +func (suite *Suite) SetT(t *testing.T) { + suite.mu.Lock() + defer suite.mu.Unlock() + suite.t = t + suite.Assertions = assert.New(t) + suite.require = require.New(t) +} + +// SetS needs to set the current test suite as parent +// to get access to the parent methods +func (suite *Suite) SetS(s TestingSuite) { + suite.s = s +} + +// Require returns a require context for suite. +func (suite *Suite) Require() *require.Assertions { + suite.mu.Lock() + defer suite.mu.Unlock() + if suite.require == nil { + panic("'Require' must not be called before 'Run' or 'SetT'") + } + return suite.require +} + +// Assert returns an assert context for suite. Normally, you can call +// `suite.NoError(expected, actual)`, but for situations where the embedded +// methods are overridden (for example, you might want to override +// assert.Assertions with require.Assertions), this method is provided so you +// can call `suite.Assert().NoError()`. +func (suite *Suite) Assert() *assert.Assertions { + suite.mu.Lock() + defer suite.mu.Unlock() + if suite.Assertions == nil { + panic("'Assert' must not be called before 'Run' or 'SetT'") + } + return suite.Assertions +} + +func recoverAndFailOnPanic(t *testing.T) { + t.Helper() + r := recover() + failOnPanic(t, r) +} + +func failOnPanic(t *testing.T, r interface{}) { + t.Helper() + if r != nil { + t.Errorf("test panicked: %v\n%s", r, debug.Stack()) + t.FailNow() + } +} + +// Run provides suite functionality around golang subtests. It should be +// called in place of t.Run(name, func(t *testing.T)) in test suite code. +// The passed-in func will be executed as a subtest with a fresh instance of t. +// Provides compatibility with go test pkg -run TestSuite/TestName/SubTestName. +func (suite *Suite) Run(name string, subtest func()) bool { + oldT := suite.T() + + return oldT.Run(name, func(t *testing.T) { + suite.SetT(t) + defer suite.SetT(oldT) + + defer recoverAndFailOnPanic(t) + + if setupSubTest, ok := suite.s.(SetupSubTest); ok { + setupSubTest.SetupSubTest() + } + + if tearDownSubTest, ok := suite.s.(TearDownSubTest); ok { + defer tearDownSubTest.TearDownSubTest() + } + + subtest() + }) +} + +// Run takes a testing suite and runs all of the tests attached +// to it. +func Run(t *testing.T, suite TestingSuite) { + defer recoverAndFailOnPanic(t) + + suite.SetT(t) + suite.SetS(suite) + + var suiteSetupDone bool + + var stats *SuiteInformation + if _, ok := suite.(WithStats); ok { + stats = newSuiteInformation() + } + + tests := []testing.InternalTest{} + methodFinder := reflect.TypeOf(suite) + suiteName := methodFinder.Elem().Name() + + for i := 0; i < methodFinder.NumMethod(); i++ { + method := methodFinder.Method(i) + + ok, err := methodFilter(method.Name) + if err != nil { + fmt.Fprintf(os.Stderr, "testify: invalid regexp for -m: %s\n", err) + os.Exit(1) + } + + if !ok { + continue + } + + if !suiteSetupDone { + if stats != nil { + stats.Start = time.Now() + } + + if setupAllSuite, ok := suite.(SetupAllSuite); ok { + setupAllSuite.SetupSuite() + } + + suiteSetupDone = true + } + + test := testing.InternalTest{ + Name: method.Name, + F: func(t *testing.T) { + parentT := suite.T() + suite.SetT(t) + defer recoverAndFailOnPanic(t) + defer func() { + t.Helper() + + r := recover() + + if stats != nil { + passed := !t.Failed() && r == nil + stats.end(method.Name, passed) + } + + if afterTestSuite, ok := suite.(AfterTest); ok { + afterTestSuite.AfterTest(suiteName, method.Name) + } + + if tearDownTestSuite, ok := suite.(TearDownTestSuite); ok { + tearDownTestSuite.TearDownTest() + } + + suite.SetT(parentT) + failOnPanic(t, r) + }() + + if setupTestSuite, ok := suite.(SetupTestSuite); ok { + setupTestSuite.SetupTest() + } + if beforeTestSuite, ok := suite.(BeforeTest); ok { + beforeTestSuite.BeforeTest(methodFinder.Elem().Name(), method.Name) + } + + if stats != nil { + stats.start(method.Name) + } + + method.Func.Call([]reflect.Value{reflect.ValueOf(suite)}) + }, + } + tests = append(tests, test) + } + if suiteSetupDone { + defer func() { + if tearDownAllSuite, ok := suite.(TearDownAllSuite); ok { + tearDownAllSuite.TearDownSuite() + } + + if suiteWithStats, measureStats := suite.(WithStats); measureStats { + stats.End = time.Now() + suiteWithStats.HandleStats(suiteName, stats) + } + }() + } + + runTests(t, tests) +} + +// Filtering method according to set regular expression +// specified command-line argument -m +func methodFilter(name string) (bool, error) { + if ok, _ := regexp.MatchString("^Test", name); !ok { + return false, nil + } + return regexp.MatchString(*matchMethod, name) +} + +func runTests(t testing.TB, tests []testing.InternalTest) { + if len(tests) == 0 { + t.Log("warning: no tests to run") + return + } + + r, ok := t.(runner) + if !ok { // backwards compatibility with Go 1.6 and below + if !testing.RunTests(allTestsFilter, tests) { + t.Fail() + } + return + } + + for _, test := range tests { + r.Run(test.Name, test.F) + } +} + +type runner interface { + Run(name string, f func(t *testing.T)) bool +} diff --git a/sdk/vendor/modules.txt b/sdk/vendor/modules.txt index d24a562a4..7a724a430 100644 --- a/sdk/vendor/modules.txt +++ b/sdk/vendor/modules.txt @@ -4,6 +4,9 @@ github.com/davecgh/go-spew/spew # github.com/google/uuid v1.6.0 ## explicit github.com/google/uuid +# github.com/jonboulle/clockwork v0.4.0 +## explicit; go 1.15 +github.com/jonboulle/clockwork # github.com/mattn/go-sqlite3 v1.14.24 ## explicit; go 1.19 github.com/mattn/go-sqlite3 @@ -14,6 +17,7 @@ github.com/pmezard/go-difflib/difflib ## explicit; go 1.17 github.com/stretchr/testify/assert github.com/stretchr/testify/require +github.com/stretchr/testify/suite # go.uber.org/mock v0.5.0 ## explicit; go 1.22 go.uber.org/mock/gomock