From 654967d9bba879c513b2470faab2ead8e45085a4 Mon Sep 17 00:00:00 2001 From: Berezhnoy Pavel <pberejnoy2005@gmail.com> Date: Thu, 30 Jan 2025 18:43:58 +0300 Subject: [PATCH] sourceWithEncoding middleware works impropertly in combination with sourceWithBatch (#248) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ReadN method implementation added to sourceWithEncoding middleware for correct integration with sourceWithBatch middleware * gofumpt -w source_middleware_test.go --------- Co-authored-by: Pavel Berezhnoy <p.berezhnoy@corp.mail.ru> Co-authored-by: Lovro Mažgon <lovro.mazgon@gmail.com> --- source_middleware.go | 31 ++++++++-- source_middleware_test.go | 124 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+), 4 deletions(-) diff --git a/source_middleware.go b/source_middleware.go index 1aca5c7c..046c435b 100644 --- a/source_middleware.go +++ b/source_middleware.go @@ -433,16 +433,39 @@ func (s *sourceWithEncoding) Read(ctx context.Context) (opencdc.Record, error) { return rec, err } - if err := s.encodeKey(ctx, &rec); err != nil { - return rec, err - } - if err := s.encodePayload(ctx, &rec); err != nil { + if err := s.encode(ctx, &rec); err != nil { return rec, err } return rec, nil } +func (s *sourceWithEncoding) ReadN(ctx context.Context, n int) ([]opencdc.Record, error) { + recs, err := s.Source.ReadN(ctx, n) + if err != nil { + return recs, err + } + + for i := range recs { + if err := s.encode(ctx, &recs[i]); err != nil { + return recs, fmt.Errorf("unable to encode record %d: %w", i, err) + } + } + + return recs, nil +} + +func (s *sourceWithEncoding) encode(ctx context.Context, rec *opencdc.Record) error { + if err := s.encodeKey(ctx, rec); err != nil { + return err + } + if err := s.encodePayload(ctx, rec); err != nil { + return err + } + + return nil +} + func (s *sourceWithEncoding) encodeKey(ctx context.Context, rec *opencdc.Record) error { if _, ok := rec.Key.(opencdc.StructuredData); !ok { // log warning once, to avoid spamming the logs diff --git a/source_middleware_test.go b/source_middleware_test.go index 8bdd6505..129ff8ac 100644 --- a/source_middleware_test.go +++ b/source_middleware_test.go @@ -637,6 +637,130 @@ func TestSourceWithEncoding_Read(t *testing.T) { } } +func TestSourceWithEncoding_ReadN(t *testing.T) { + is := is.New(t) + ctx := context.Background() + + testDataStruct := opencdc.StructuredData{ + "foo": "bar", + "long": int64(1), + "float": 2.34, + "time": time.Now().UTC().Truncate(time.Microsecond), // avro precision is microseconds + } + wantSchema := `{"name":"record","type":"record","fields":[{"name":"float","type":"double"},{"name":"foo","type":"string"},{"name":"long","type":"long"},{"name":"time","type":{"type":"long","logicalType":"timestamp-micros"}}]}` + + customTestSchema, err := schema.Create(ctx, schema.TypeAvro, "custom-test-schema", []byte(wantSchema)) + is.NoErr(err) + + bytes, err := customTestSchema.Marshal(testDataStruct) + is.NoErr(err) + testDataRaw := opencdc.RawData(bytes) + + testCases := []struct { + name string + inputRecs []opencdc.Record + wantRecs []opencdc.Record + }{{ + name: "no records returned", + }, { + name: "single record returned", + inputRecs: []opencdc.Record{ + { + Key: testDataStruct.Clone(), + Payload: opencdc.Change{ + Before: testDataStruct.Clone(), + After: testDataStruct.Clone(), + }, + }, + }, + wantRecs: []opencdc.Record{ + { + Key: testDataRaw, + Payload: opencdc.Change{ + Before: testDataRaw, + After: testDataRaw, + }, + }, + }, + }, { + name: "multiple records returned", + inputRecs: []opencdc.Record{{ + Key: testDataStruct.Clone(), + Payload: opencdc.Change{ + Before: testDataStruct.Clone(), + After: testDataStruct.Clone(), + }, + }, { + Key: testDataStruct.Clone(), + Payload: opencdc.Change{ + Before: testDataStruct.Clone(), + After: testDataStruct.Clone(), + }, + }, { + Key: testDataStruct.Clone(), + Payload: opencdc.Change{ + Before: testDataStruct.Clone(), + After: testDataStruct.Clone(), + }, + }}, + wantRecs: []opencdc.Record{{ + Key: testDataRaw, + Payload: opencdc.Change{ + Before: testDataRaw, + After: testDataRaw, + }, + }, { + Key: testDataRaw, + Payload: opencdc.Change{ + Before: testDataRaw, + After: testDataRaw, + }, + }, { + Key: testDataRaw, + Payload: opencdc.Change{ + Before: testDataRaw, + After: testDataRaw, + }, + }}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + is := is.New(t) + src := NewMockSource(gomock.NewController(t)) + + underTest := (&SourceWithEncoding{}).Wrap(src) + + for i := range tc.inputRecs { + tc.inputRecs[i].Metadata = map[string]string{ + opencdc.MetadataCollection: "foo", + opencdc.MetadataKeySchemaSubject: customTestSchema.Subject, + opencdc.MetadataKeySchemaVersion: strconv.Itoa(customTestSchema.Version), + opencdc.MetadataPayloadSchemaSubject: customTestSchema.Subject, + opencdc.MetadataPayloadSchemaVersion: strconv.Itoa(customTestSchema.Version), + } + } + + src.EXPECT().ReadN(ctx, 100).Return(tc.inputRecs, nil) + + got, err := underTest.ReadN(ctx, 100) + is.NoErr(err) + + is.Equal(len(got), len(tc.wantRecs)) + + for i := range got { + gotKey := got[i].Key + gotPayloadBefore := got[i].Payload.Before + gotPayloadAfter := got[i].Payload.After + + is.Equal("", cmp.Diff(tc.wantRecs[i].Key, gotKey)) + is.Equal("", cmp.Diff(tc.wantRecs[i].Payload.Before, gotPayloadBefore)) + is.Equal("", cmp.Diff(tc.wantRecs[i].Payload.After, gotPayloadAfter)) + } + }) + } +} + // -- SourceWithBatch -------------------------------------------------- func TestSourceWithBatch_ReadN(t *testing.T) {