feat: warehouse transformations
achettyiitr committed Jan 15, 2025
1 parent 5fd9fba commit ca8ec5a
Showing 45 changed files with 12,959 additions and 11 deletions.
2 changes: 1 addition & 1 deletion go.mod
@@ -42,6 +42,7 @@ require (
github.com/databricks/databricks-sql-go v1.6.1
github.com/denisenkom/go-mssqldb v0.12.3
github.com/dgraph-io/badger/v4 v4.5.0
github.com/dlclark/regexp2 v1.11.4
github.com/docker/docker v27.5.0+incompatible
github.com/go-chi/chi/v5 v5.2.0
github.com/go-redis/redis v6.15.9+incompatible
@@ -192,7 +193,6 @@ require (
github.com/danieljoos/wincred v1.2.2 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.11.4 // indirect
github.com/dnephin/pflag v1.0.7 // indirect
github.com/docker/cli v27.2.1+incompatible // indirect
github.com/docker/cli-docs-tool v0.8.0 // indirect
80 changes: 76 additions & 4 deletions processor/processor.go
@@ -5,16 +5,20 @@ import (
"encoding/json"
"errors"
"fmt"
"reflect"
"runtime/trace"
"slices"
"strconv"
"strings"
"sync"
"time"

obskit "github.com/rudderlabs/rudder-observability-kit/go/labels"

"github.com/google/uuid"

"github.com/rudderlabs/rudder-server/enterprise/trackedusers"
warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils"

"golang.org/x/sync/errgroup"

@@ -57,6 +61,7 @@ import (
. "github.com/rudderlabs/rudder-server/utils/tx" //nolint:staticcheck
"github.com/rudderlabs/rudder-server/utils/types"
"github.com/rudderlabs/rudder-server/utils/workerpool"
wtrans "github.com/rudderlabs/rudder-server/warehouse/transformer"
)

const (
@@ -86,10 +91,12 @@ type trackedUsersReporter interface {

// Handle is a handle to the processor module
type Handle struct {
conf *config.Config
tracer stats.Tracer
backendConfig backendconfig.BackendConfig
transformer transformer.Transformer
conf *config.Config
tracer stats.Tracer
backendConfig backendconfig.BackendConfig
transformer transformer.Transformer
warehouseTransformer transformer.DestinationTransformer
warehouseDebugLogger *wtrans.DebugLogger

gatewayDB jobsdb.JobsDB
routerDB jobsdb.JobsDB
@@ -159,6 +166,7 @@ type Handle struct {
eventAuditEnabled map[string]bool
credentialsMap map[string][]transformer.Credential
nonEventStreamSources map[string]bool
enableWarehouseTransformations config.ValueLoader[bool]
}

drainConfig struct {
@@ -618,6 +626,9 @@ func (proc *Handle) Setup(
"partition": partition,
})
}
proc.warehouseTransformer = wtrans.New(proc.conf, proc.logger, proc.statsFactory)
proc.warehouseDebugLogger = wtrans.NewDebugLogger(proc.conf, proc.logger)

if proc.config.enableDedup {
var err error
proc.dedup, err = dedup.New(proc.conf, proc.statsFactory)
@@ -819,6 +830,7 @@ func (proc *Handle) loadReloadableConfig(defaultPayloadLimit int64, defaultMaxEv
proc.config.archivalEnabled = config.GetReloadableBoolVar(true, "archival.Enabled")
// Capture event name as a tag in event level stats
proc.config.captureEventNameStats = config.GetReloadableBoolVar(false, "Processor.Stats.captureEventName")
proc.config.enableWarehouseTransformations = config.GetReloadableBoolVar(false, "Processor.enableWarehouseTransformations")
}

type connection struct {
@@ -3215,6 +3227,7 @@ func (proc *Handle) transformSrcDest(
proc.logger.Debug("Dest Transform input size", len(eventsToTransform))
s := time.Now()
response = proc.transformer.Transform(ctx, eventsToTransform, proc.config.transformBatchSize.Load())
proc.handleResponseForWarehouseTransformation(ctx, eventsToTransform, response, commonMetaData, eventsByMessageID)

destTransformationStat := proc.newDestinationTransformationStat(sourceID, workspaceID, transformAt, destination)
destTransformationStat.transformTime.Since(s)
@@ -3373,6 +3386,65 @@ func (proc *Handle) transformSrcDest(
}
}

func (proc *Handle) handleResponseForWarehouseTransformation(
ctx context.Context,
eventsToTransform []transformer.TransformerEvent,
pResponse transformer.Response,
commonMetaData *transformer.Metadata,
eventsByMessageID map[string]types.SingularEventWithReceivedAt,
) {
if _, ok := warehouseutils.WarehouseDestinationMap[commonMetaData.DestinationType]; !ok {
return
}
if len(eventsToTransform) == 0 || !proc.config.enableWarehouseTransformations.Load() {
return
}
defer proc.statsFactory.NewStat("proc_warehouse_transformations_time", stats.TimerType).RecordDuration()()

wResponse := proc.warehouseTransformer.Transform(ctx, eventsToTransform, proc.config.transformBatchSize.Load())
differingEvents := proc.responsesDiffer(eventsToTransform, pResponse, wResponse, eventsByMessageID)
if err := proc.warehouseDebugLogger.LogEvents(differingEvents, commonMetaData); err != nil {
proc.logger.Warnn("Failed to log events for warehouse transformation debugging", obskit.Error(err))
}
}

func (proc *Handle) responsesDiffer(
eventsToTransform []transformer.TransformerEvent,
pResponse, wResponse transformer.Response,
eventsByMessageID map[string]types.SingularEventWithReceivedAt,
) []types.SingularEventT {
// If the event counts differ, return all input events for this transformation as differing
if len(pResponse.Events) != len(wResponse.Events) || len(pResponse.FailedEvents) != len(wResponse.FailedEvents) {
events := lo.Map(eventsToTransform, func(e transformer.TransformerEvent, _ int) types.SingularEventT {
return eventsByMessageID[e.Metadata.MessageID].SingularEvent
})
proc.statsFactory.NewStat("proc_warehouse_transformations_mismatches", stats.CountType).Count(len(events))
return events
}

var (
differedSampleEvents []types.SingularEventT
differedEventsCount int
collectedSampleEvent bool
)

for i := range pResponse.Events {
if !reflect.DeepEqual(pResponse.Events[i], wResponse.Events[i]) {
differedEventsCount++
if !collectedSampleEvent {
// Collect sample messages from the first mismatched event only; later mismatches are just counted
differedSampleEvents = append(differedSampleEvents, lo.Map(pResponse.Events[i].Metadata.GetMessagesIDs(), func(msgID string, _ int) types.SingularEventT {
return eventsByMessageID[msgID].SingularEvent
})...)
collectedSampleEvent = true
}
}
}
proc.statsFactory.NewStat("proc_warehouse_transformations_mismatches", stats.CountType).Count(differedEventsCount)

return differedSampleEvents
}

func (proc *Handle) saveDroppedJobs(ctx context.Context, droppedJobs []*jobsdb.JobT, tx *Tx) error {
if len(droppedJobs) > 0 {
for i := range droppedJobs { // each dropped job should have a unique jobID in the scope of the batch
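The shadow-comparison flow above runs the new embedded warehouse transformer alongside the legacy destination transformer on the same input, counts positions where the two responses disagree, and samples only the first mismatch for debugging. A minimal sketch of that pattern, using simplified stand-ins for the real transformer.Response and event types (all names here are hypothetical, not the committed API):

package main

import (
	"fmt"
	"reflect"
)

// Simplified stand-in for transformer.Response; the real type also carries
// metadata, status codes, and validation errors.
type Response struct {
	Events       []map[string]any
	FailedEvents []map[string]any
}

// countMismatches mirrors the shape of responsesDiffer: a count mismatch
// flags every input event; otherwise events are compared pairwise by position.
func countMismatches(legacy, shadow Response, totalInputs int) int {
	if len(legacy.Events) != len(shadow.Events) ||
		len(legacy.FailedEvents) != len(shadow.FailedEvents) {
		return totalInputs // wholesale disagreement: flag all input events
	}
	mismatches := 0
	for i := range legacy.Events {
		if !reflect.DeepEqual(legacy.Events[i], shadow.Events[i]) {
			mismatches++
		}
	}
	return mismatches
}

func main() {
	legacy := Response{Events: []map[string]any{{"id": "1"}, {"id": "2"}}}
	shadow := Response{Events: []map[string]any{{"id": "1"}, {"id": "x"}}}
	fmt.Println(countMismatches(legacy, shadow, 2)) // 1
}

Because the comparison is gated on Processor.enableWarehouseTransformations and only emits stats and log samples, the shadow run adds latency but cannot change what is delivered; the legacy response remains the one that flows onward.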
20 changes: 16 additions & 4 deletions processor/transformer/transformer.go
@@ -146,13 +146,25 @@ func WithClient(client HTTPDoer) Opt {
}
}

// Transformer provides methods to transform events
type Transformer interface {
Transform(ctx context.Context, clientEvents []TransformerEvent, batchSize int) Response
type UserTransformer interface {
UserTransform(ctx context.Context, clientEvents []TransformerEvent, batchSize int) Response
}

type DestinationTransformer interface {
Transform(ctx context.Context, clientEvents []TransformerEvent, batchSize int) Response
}

type TrackingPlanValidator interface {
Validate(ctx context.Context, clientEvents []TransformerEvent, batchSize int) Response
}

// Transformer provides methods to transform events
type Transformer interface {
UserTransformer
DestinationTransformer
TrackingPlanValidator
}

type HTTPDoer interface {
Do(req *http.Request) (*http.Response, error)
}
@@ -568,7 +580,7 @@ func (trans *handle) destTransformURL(destType string) string {
destinationEndPoint := fmt.Sprintf("%s/v0/destinations/%s", trans.config.destTransformationURL, strings.ToLower(destType))

if _, ok := warehouseutils.WarehouseDestinationMap[destType]; ok {
whSchemaVersionQueryParam := fmt.Sprintf("whSchemaVersion=%s&whIDResolve=%v", trans.conf.GetString("Warehouse.schemaVersion", "v1"), warehouseutils.IDResolutionEnabled())
whSchemaVersionQueryParam := fmt.Sprintf("whIDResolve=%t", trans.conf.GetBool("Warehouse.enableIDResolution", false))
switch destType {
case warehouseutils.RS:
return destinationEndPoint + "?" + whSchemaVersionQueryParam
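This change splits the monolithic Transformer interface into three role interfaces (UserTransformer, DestinationTransformer, TrackingPlanValidator) that compose back into Transformer, so a consumer such as the warehouse shadow path can depend only on the capability it uses. A sketch of that narrowing, with trimmed-down stand-in types (hypothetical, not the full committed signatures):

package main

import (
	"context"
	"fmt"
)

// Trimmed stand-ins for transformer.TransformerEvent and transformer.Response.
type TransformerEvent struct{ Message map[string]any }
type Response struct{ Events []TransformerEvent }

// DestinationTransformer matches the narrow interface introduced above.
type DestinationTransformer interface {
	Transform(ctx context.Context, clientEvents []TransformerEvent, batchSize int) Response
}

// echoTransformer is a hypothetical implementation standing in for the
// embedded warehouse transformer returned by wtrans.New.
type echoTransformer struct{}

func (echoTransformer) Transform(_ context.Context, evs []TransformerEvent, _ int) Response {
	return Response{Events: evs}
}

// shadowRun accepts the narrow interface, so any DestinationTransformer
// (service-backed or embedded) can be swapped in without touching callers.
func shadowRun(ctx context.Context, t DestinationTransformer, evs []TransformerEvent) Response {
	return t.Transform(ctx, evs, len(evs))
}

func main() {
	out := shadowRun(context.Background(), echoTransformer{},
		[]TransformerEvent{{Message: map[string]any{"type": "track"}}})
	fmt.Println(len(out.Events)) // 1
}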
3 changes: 2 additions & 1 deletion warehouse/internal/model/schema.go
@@ -17,7 +17,8 @@ const (
JSONDataType SchemaType = "json"
TextDataType SchemaType = "text"
DateTimeDataType SchemaType = "datetime"
ArrayOfBooleanDatatype SchemaType = "array(boolean)"
ArrayDataType SchemaType = "array"
ArrayOfBooleanDataType SchemaType = "array(boolean)"
)

type WHSchema struct {
2 changes: 1 addition & 1 deletion warehouse/slave/worker.go
@@ -319,7 +319,7 @@ func (w *worker) processStagingFile(ctx context.Context, job payload) ([]uploadR
}

columnVal = newColumnVal
case model.ArrayOfBooleanDatatype:
case model.ArrayOfBooleanDataType:
if boolValue, ok := columnVal.([]interface{}); ok {
newColumnVal := make([]interface{}, len(boolValue))

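The worker switches on the renamed model.ArrayOfBooleanDataType when normalizing staged column values; the body of the case is elided in the diff above. A sketch of what such an element-wise normalization can look like, under the assumption that non-boolean members are stringified as a fallback (that fallback is a guess for illustration, not the committed behavior):

package main

import "fmt"

// normalizeBoolArray illustrates a possible coercion step for the
// "array(boolean)" schema type: boolean members are kept as-is, while
// other members fall back to their string form (hypothetical fallback).
func normalizeBoolArray(columnVal interface{}) interface{} {
	boolValue, ok := columnVal.([]interface{})
	if !ok {
		return columnVal
	}
	newColumnVal := make([]interface{}, len(boolValue))
	for i, v := range boolValue {
		if b, ok := v.(bool); ok {
			newColumnVal[i] = b
			continue
		}
		newColumnVal[i] = fmt.Sprintf("%v", v) // hypothetical fallback
	}
	return newColumnVal
}

func main() {
	fmt.Println(normalizeBoolArray([]interface{}{true, "false", 1}))
}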
81 changes: 81 additions & 0 deletions warehouse/transformer/debuglogger.go
@@ -0,0 +1,81 @@
package transformer

import (
"fmt"
"sync"

"github.com/google/uuid"
"github.com/samber/lo"

"github.com/rudderlabs/rudder-go-kit/config"
"github.com/rudderlabs/rudder-go-kit/logger"
"github.com/rudderlabs/rudder-go-kit/stringify"

ptrans "github.com/rudderlabs/rudder-server/processor/transformer"
"github.com/rudderlabs/rudder-server/utils/misc"
"github.com/rudderlabs/rudder-server/utils/types"
)

type DebugLogger struct {
logger logger.Logger
maxLoggedEvents config.ValueLoader[int]
eventLogMutex sync.Mutex
currentLogFileName string
loggedEvents int64
}

func NewDebugLogger(conf *config.Config, logger logger.Logger) *DebugLogger {
logFileName := generateLogFileName()

return &DebugLogger{
logger: logger.Child("debugLogger").With("currentLogFileName", logFileName),
maxLoggedEvents: conf.GetReloadableIntVar(10000, 1, "Processor.maxLoggedEvents"),
currentLogFileName: logFileName,
}
}

func generateLogFileName() string {
return fmt.Sprintf("warehouse_transformations_debug_%s.log", uuid.NewString())
}

func (d *DebugLogger) LogEvents(events []types.SingularEventT, commonMetadata *ptrans.Metadata) error {
if len(events) == 0 {
return nil
}
d.eventLogMutex.Lock()
defer d.eventLogMutex.Unlock()

if d.loggedEvents >= int64(d.maxLoggedEvents.Load()) {
return nil
}

logEntries := lo.Map(events, func(item types.SingularEventT, index int) string {
return stringify.Any(ptrans.TransformerEvent{
Message: item,
Metadata: *commonMetadata,
})
})

if err := d.writeLogEntries(logEntries); err != nil {
return fmt.Errorf("logging events: %w", err)
}

d.logger.Infon("Successfully logged events", logger.NewIntField("event_count", int64(len(logEntries))))
d.loggedEvents += int64(len(logEntries))
return nil
}

func (d *DebugLogger) writeLogEntries(entries []string) error {
writer, err := misc.CreateBufferedWriter(d.currentLogFileName)
if err != nil {
return fmt.Errorf("creating buffered writer: %w", err)
}
defer func() { _ = writer.Close() }()

for _, entry := range entries {
if _, err := writer.Write([]byte(entry + "\n")); err != nil {
return fmt.Errorf("writing log entry: %w", err)
}
}
return nil
}
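DebugLogger caps how many differing events are ever written (reloadable Processor.maxLoggedEvents, default 10000) behind a mutex, so a systematic mismatch cannot flood the disk. A pared-down sketch of that capped-writer pattern using the standard library in place of misc.CreateBufferedWriter (the file name and cap here are hardcoded stand-ins):

package main

import (
	"fmt"
	"os"
	"sync"
)

// cappedLogger enforces a ceiling on total sampled entries, mirroring the
// loggedEvents/maxLoggedEvents bookkeeping in DebugLogger above.
type cappedLogger struct {
	mu     sync.Mutex
	logged int64
	max    int64
	path   string
}

func (c *cappedLogger) Log(entries []string) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.logged >= c.max {
		return nil // silently drop once the cap is hit, as LogEvents does
	}
	f, err := os.OpenFile(c.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
	if err != nil {
		return fmt.Errorf("opening log file: %w", err)
	}
	defer f.Close()
	for _, e := range entries {
		if _, err := f.WriteString(e + "\n"); err != nil {
			return fmt.Errorf("writing log entry: %w", err)
		}
	}
	c.logged += int64(len(entries))
	return nil
}

func main() {
	l := &cappedLogger{max: 10000, path: "warehouse_transformations_debug.log"}
	_ = l.Log([]string{`{"messageId":"m1"}`})
}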
@@ -0,0 +1,60 @@
package reservedkeywords

import (
"embed"
"log"
"strings"

jsoniter "github.com/json-iterator/go"
"github.com/samber/lo"
)

var (
//go:embed reservedtablescolumns.json
tablesColumnsFile embed.FS

//go:embed reservednamespaces.json
namespacesFile embed.FS

reservedTablesColumns, reservedNamespaces map[string]map[string]struct{}

json = jsoniter.ConfigCompatibleWithStandardLibrary
)

func init() {
reservedTablesColumns = load(tablesColumnsFile, "reservedtablescolumns.json")
reservedNamespaces = load(namespacesFile, "reservednamespaces.json")
}

func load(file embed.FS, fileName string) map[string]map[string]struct{} {
data, err := file.ReadFile(fileName)
if err != nil {
log.Fatalf("failed to load reserved keywords from %s: %v", fileName, err)
}

var tempKeywords map[string][]string
if err := json.Unmarshal(data, &tempKeywords); err != nil {
log.Fatalf("failed to parse reserved keywords from %s: %v", fileName, err)
}

return lo.MapValues(tempKeywords, func(keywords []string, _ string) map[string]struct{} {
return lo.SliceToMap(keywords, func(k string) (string, struct{}) {
return strings.ToUpper(k), struct{}{}
})
})
}

// IsTableOrColumn checks if the given keyword is a reserved table/column keyword for the destination type.
func IsTableOrColumn(destType, keyword string) bool {
return isKeywordReserved(reservedTablesColumns, destType, keyword)
}

// IsNamespace checks if the given keyword is a reserved namespace keyword for the destination type.
func IsNamespace(destType, keyword string) bool {
return isKeywordReserved(reservedNamespaces, destType, keyword)
}

func isKeywordReserved(keywords map[string]map[string]struct{}, destType, keyword string) bool {
_, exists := keywords[destType][strings.ToUpper(keyword)]
return exists
}
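The embedded JSON files are expected to map a destination type to a list of keywords, which load upper-cases into nested sets for constant-time, case-insensitive membership checks. A sketch of the lookup with an inline map standing in for the embedded files (the destination name and keywords shown are illustrative, not the real file contents):

package main

import (
	"fmt"
	"strings"
)

// Mirrors isKeywordReserved: keywords are stored upper-cased per destination
// type, so lookups are case-insensitive. This literal stands in for the
// parsed contents of reservedtablescolumns.json.
var reserved = map[string]map[string]struct{}{
	"SNOWFLAKE": {"SELECT": {}, "TABLE": {}}, // illustrative entries
}

func isReserved(destType, keyword string) bool {
	_, ok := reserved[destType][strings.ToUpper(keyword)]
	return ok
}

func main() {
	fmt.Println(isReserved("SNOWFLAKE", "select"))  // true
	fmt.Println(isReserved("SNOWFLAKE", "user_id")) // false
}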