Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: replace schema struct with interface to enable flexible implementation #5426

Merged
merged 3 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion warehouse/router/upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ type UploadJob struct {
stagingFileRepo *repo.StagingFiles
loadFilesRepo *repo.LoadFiles
whManager manager.Manager
schemaHandle *schema.Schema
schemaHandle schema.Handler
conf *config.Config
logger logger.Logger
statsFactory stats.Stats
Expand Down
28 changes: 20 additions & 8 deletions warehouse/router/upload_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ import (

backendconfig "github.com/rudderlabs/rudder-server/backend-config"
"github.com/rudderlabs/rudder-server/services/alerta"
"github.com/rudderlabs/rudder-server/warehouse/integrations/manager"
sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper"
"github.com/rudderlabs/rudder-server/warehouse/integrations/redshift"
"github.com/rudderlabs/rudder-server/warehouse/internal/model"
"github.com/rudderlabs/rudder-server/warehouse/schema"
warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils"
)

Expand Down Expand Up @@ -127,17 +127,31 @@ func TestColumnCountStat(t *testing.T) {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()
conf := config.New()
conf.Set(fmt.Sprintf("Warehouse.%s.columnCountLimit", strings.ToLower(warehouseutils.WHDestNameMap[tc.destinationType])), tc.columnCountLimit)

j := UploadJob{
conf: conf,
upload: model.Upload{
pool, err := dockertest.NewPool("")
require.NoError(t, err)

pgResource, err := postgres.Setup(pool, t)
require.NoError(t, err)

uploadJobFactory := &UploadJobFactory{
logger: logger.NOP,
statsFactory: statsStore,
conf: conf,
db: sqlmiddleware.New(pgResource.DB),
}
whManager, err := manager.New(warehouseutils.POSTGRES, conf, logger.NOP, statsStore)
require.NoError(t, err)
j := uploadJobFactory.NewUploadJob(context.Background(), &model.UploadJob{
Upload: model.Upload{
WorkspaceID: workspaceID,
DestinationID: destinationID,
SourceID: sourceID,
},
warehouse: model.Warehouse{
Warehouse: model.Warehouse{
Type: tc.destinationType,
Destination: backendconfig.DestinationT{
ID: destinationID,
Expand All @@ -148,9 +162,7 @@ func TestColumnCountStat(t *testing.T) {
Name: sourceName,
},
},
statsFactory: statsStore,
schemaHandle: &schema.Schema{}, // TODO use constructor
}
}, whManager)
j.schemaHandle.UpdateWarehouseTableSchema(tableName, model.TableSchema{
"test-column-1": "string",
"test-column-2": "string",
Expand Down
50 changes: 32 additions & 18 deletions warehouse/schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,21 @@ type fetchSchemaRepo interface {
FetchSchema(ctx context.Context) (model.Schema, error)
}

type Schema struct {
type Handler interface {
SyncRemoteSchema(ctx context.Context, fetchSchemaRepo fetchSchemaRepo, uploadID int64) (bool, error)
IsWarehouseSchemaEmpty() bool
GetTableSchemaInWarehouse(tableName string) model.TableSchema
GetLocalSchema(ctx context.Context) (model.Schema, error)
UpdateLocalSchema(ctx context.Context, updatedSchema model.Schema) error
UpdateWarehouseTableSchema(tableName string, tableSchema model.TableSchema)
GetColumnsCountInWarehouseSchema(tableName string) int
RanjeetMishra marked this conversation as resolved.
Show resolved Hide resolved
ConsolidateStagingFilesUsingLocalSchema(ctx context.Context, stagingFiles []*model.StagingFile) (model.Schema, error)
UpdateLocalSchemaWithWarehouse(ctx context.Context) error
TableSchemaDiff(tableName string, tableSchema model.TableSchema) whutils.TableSchemaDiff
FetchSchemaFromWarehouse(ctx context.Context, repo fetchSchemaRepo) error
}

type schema struct {
warehouse model.Warehouse
schemaRepo schemaRepo
stagingFileRepo stagingFileRepo
Expand All @@ -69,8 +83,8 @@ func New(
conf *config.Config,
logger logger.Logger,
statsFactory stats.Stats,
) *Schema {
s := &Schema{
) Handler {
s := &schema{
warehouse: warehouse,
schemaRepo: repo.NewWHSchemas(db),
stagingFileRepo: repo.NewStagingFiles(db),
Expand All @@ -95,7 +109,7 @@ func New(
// 4. Enhances the consolidated schema with discards schema
// 5. Enhances the consolidated schema with ID resolution schema
// 6. Returns the consolidated schema
func (sh *Schema) ConsolidateStagingFilesUsingLocalSchema(ctx context.Context, stagingFiles []*model.StagingFile) (model.Schema, error) {
func (sh *schema) ConsolidateStagingFilesUsingLocalSchema(ctx context.Context, stagingFiles []*model.StagingFile) (model.Schema, error) {
consolidatedSchema := model.Schema{}
batches := lo.Chunk(stagingFiles, sh.stagingFilesSchemaPaginationSize)
for _, batch := range batches {
Expand Down Expand Up @@ -244,24 +258,24 @@ func enhanceSchemaWithIDResolution(consolidatedSchema model.Schema, isIDResoluti
return consolidatedSchema
}

func (sh *Schema) isIDResolutionEnabled() bool {
func (sh *schema) isIDResolutionEnabled() bool {
return sh.enableIDResolution && slices.Contains(whutils.IdentityEnabledWarehouses, sh.warehouse.Type)
}

func (sh *Schema) UpdateLocalSchemaWithWarehouse(ctx context.Context) error {
func (sh *schema) UpdateLocalSchemaWithWarehouse(ctx context.Context) error {
sh.schemaInWarehouseMu.RLock()
defer sh.schemaInWarehouseMu.RUnlock()
return sh.updateLocalSchema(ctx, sh.schemaInWarehouse)
}

func (sh *Schema) UpdateLocalSchema(ctx context.Context, updatedSchema model.Schema) error {
func (sh *schema) UpdateLocalSchema(ctx context.Context, updatedSchema model.Schema) error {
return sh.updateLocalSchema(ctx, updatedSchema)
}

// updateLocalSchema
// 1. Inserts the updated schema into the local schema table
// 2. Updates the local schema instance
func (sh *Schema) updateLocalSchema(ctx context.Context, updatedSchema model.Schema) error {
func (sh *schema) updateLocalSchema(ctx context.Context, updatedSchema model.Schema) error {
updatedSchemaInBytes, err := json.Marshal(updatedSchema)
if err != nil {
return fmt.Errorf("marshaling schema: %w", err)
Expand Down Expand Up @@ -292,7 +306,7 @@ func (sh *Schema) updateLocalSchema(ctx context.Context, updatedSchema model.Sch
// 3. Initialize local schema
// 4. Updates local schema with warehouse schema if it has changed
// 5. Returns true if schema has changed
func (sh *Schema) SyncRemoteSchema(ctx context.Context, fetchSchemaRepo fetchSchemaRepo, uploadID int64) (bool, error) {
func (sh *schema) SyncRemoteSchema(ctx context.Context, fetchSchemaRepo fetchSchemaRepo, uploadID int64) (bool, error) {
localSchema, err := sh.GetLocalSchema(ctx)
if err != nil {
return false, fmt.Errorf("fetching schema from local: %w", err)
Expand Down Expand Up @@ -321,7 +335,7 @@ func (sh *Schema) SyncRemoteSchema(ctx context.Context, fetchSchemaRepo fetchSch
}

// GetLocalSchema returns the local schema from wh_schemas table
func (sh *Schema) GetLocalSchema(ctx context.Context) (model.Schema, error) {
func (sh *schema) GetLocalSchema(ctx context.Context) (model.Schema, error) {
whSchema, err := sh.schemaRepo.GetForNamespace(
ctx,
sh.warehouse.Source.ID,
Expand All @@ -341,7 +355,7 @@ func (sh *Schema) GetLocalSchema(ctx context.Context) (model.Schema, error) {
// 1. Fetches schema from warehouse
// 2. Removes deprecated columns from schema
// 3. Updates local warehouse schema and unrecognized schema instance
func (sh *Schema) FetchSchemaFromWarehouse(ctx context.Context, repo fetchSchemaRepo) error {
func (sh *schema) FetchSchemaFromWarehouse(ctx context.Context, repo fetchSchemaRepo) error {
warehouseSchema, err := repo.FetchSchema(ctx)
if err != nil {
return fmt.Errorf("fetching schema: %w", err)
Expand All @@ -356,7 +370,7 @@ func (sh *Schema) FetchSchemaFromWarehouse(ctx context.Context, repo fetchSchema
}

// removeDeprecatedColumns skips deprecated columns from the schema map
func (sh *Schema) removeDeprecatedColumns(schema model.Schema) {
func (sh *schema) removeDeprecatedColumns(schema model.Schema) {
for tableName, columnMap := range schema {
for columnName := range columnMap {
if deprecatedColumnsRegex.MatchString(columnName) {
Expand All @@ -376,12 +390,12 @@ func (sh *Schema) removeDeprecatedColumns(schema model.Schema) {
}

// hasSchemaChanged compares the localSchema with the schemaInWarehouse
func (sh *Schema) hasSchemaChanged(localSchema model.Schema) bool {
func (sh *schema) hasSchemaChanged(localSchema model.Schema) bool {
return !reflect.DeepEqual(localSchema, sh.schemaInWarehouse)
}

// TableSchemaDiff returns the diff between the warehouse schema and the upload schema
func (sh *Schema) TableSchemaDiff(tableName string, tableSchema model.TableSchema) whutils.TableSchemaDiff {
func (sh *schema) TableSchemaDiff(tableName string, tableSchema model.TableSchema) whutils.TableSchemaDiff {
diff := whutils.TableSchemaDiff{
ColumnMap: make(model.TableSchema),
UpdatedSchema: make(model.TableSchema),
Expand Down Expand Up @@ -422,13 +436,13 @@ func (sh *Schema) TableSchemaDiff(tableName string, tableSchema model.TableSchem
return diff
}

func (sh *Schema) GetTableSchemaInWarehouse(tableName string) model.TableSchema {
func (sh *schema) GetTableSchemaInWarehouse(tableName string) model.TableSchema {
sh.schemaInWarehouseMu.RLock()
defer sh.schemaInWarehouseMu.RUnlock()
return sh.schemaInWarehouse[tableName]
}

func (sh *Schema) UpdateWarehouseTableSchema(tableName string, tableSchema model.TableSchema) {
func (sh *schema) UpdateWarehouseTableSchema(tableName string, tableSchema model.TableSchema) {
sh.schemaInWarehouseMu.Lock()
defer sh.schemaInWarehouseMu.Unlock()
if sh.schemaInWarehouse == nil {
Expand All @@ -437,13 +451,13 @@ func (sh *Schema) UpdateWarehouseTableSchema(tableName string, tableSchema model
sh.schemaInWarehouse[tableName] = tableSchema
}

func (sh *Schema) IsWarehouseSchemaEmpty() bool {
func (sh *schema) IsWarehouseSchemaEmpty() bool {
sh.schemaInWarehouseMu.RLock()
defer sh.schemaInWarehouseMu.RUnlock()
return len(sh.schemaInWarehouse) == 0
}

func (sh *Schema) GetColumnsCountInWarehouseSchema(tableName string) int {
func (sh *schema) GetColumnsCountInWarehouseSchema(tableName string) int {
sh.schemaInWarehouseMu.RLock()
defer sh.schemaInWarehouseMu.RUnlock()
return len(sh.schemaInWarehouse[tableName])
Expand Down
18 changes: 9 additions & 9 deletions warehouse/schema/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func TestSchema_UpdateLocalSchema(t *testing.T) {
statsStore, err := memstats.New()
require.NoError(t, err)

s := Schema{
s := schema{
warehouse: model.Warehouse{
WorkspaceID: workspaceID,
Source: backendconfig.SourceT{
Expand Down Expand Up @@ -352,7 +352,7 @@ func TestSchema_FetchSchemaFromWarehouse(t *testing.T) {
err: tc.mockErr,
}

s := &Schema{
s := &schema{
warehouse: model.Warehouse{
Source: backendconfig.SourceT{
ID: sourceID,
Expand Down Expand Up @@ -511,7 +511,7 @@ func TestSchema_TableSchemaDiff(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

s := Schema{
s := schema{
schemaInWarehouse: tc.currentSchema,
}
diff := s.TableSchemaDiff(tc.tableName, tc.uploadTableSchema)
Expand Down Expand Up @@ -592,7 +592,7 @@ func TestSchema_HasLocalSchemaChanged(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

s := &Schema{
s := &schema{
warehouse: model.Warehouse{
Type: warehouseutils.SNOWFLAKE,
},
Expand Down Expand Up @@ -1625,7 +1625,7 @@ func TestSchema_ConsolidateStagingFilesUsingLocalSchema(t *testing.T) {
err: tc.mockErr,
}

s := &Schema{
s := &schema{
warehouse: model.Warehouse{
Source: backendconfig.SourceT{
ID: sourceID,
Expand Down Expand Up @@ -1668,7 +1668,7 @@ func TestSchema_SyncRemoteSchema(t *testing.T) {
tableName := "test_table_name"

t.Run("should return error if unable to fetch local schema", func(t *testing.T) {
s := &Schema{
s := &schema{
warehouse: model.Warehouse{
Source: backendconfig.SourceT{
ID: sourceID,
Expand Down Expand Up @@ -1697,7 +1697,7 @@ func TestSchema_SyncRemoteSchema(t *testing.T) {
require.False(t, schemaChanged)
})
t.Run("should return error if unable to fetch remote schema", func(t *testing.T) {
s := &Schema{
s := &schema{
warehouse: model.Warehouse{
Source: backendconfig.SourceT{
ID: sourceID,
Expand Down Expand Up @@ -1766,7 +1766,7 @@ func TestSchema_SyncRemoteSchema(t *testing.T) {
},
}

s := &Schema{
s := &schema{
warehouse: model.Warehouse{
Source: backendconfig.SourceT{
ID: sourceID,
Expand Down Expand Up @@ -1835,7 +1835,7 @@ func TestSchema_SyncRemoteSchema(t *testing.T) {
},
}

s := &Schema{
s := &schema{
warehouse: model.Warehouse{
Source: backendconfig.SourceT{
ID: sourceID,
Expand Down
Loading