From 654967d9bba879c513b2470faab2ead8e45085a4 Mon Sep 17 00:00:00 2001
From: Berezhnoy Pavel <pberejnoy2005@gmail.com>
Date: Thu, 30 Jan 2025 18:43:58 +0300
Subject: [PATCH] sourceWithEncoding middleware works impropertly in
 combination with sourceWithBatch (#248)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* ReadN method implementation added to sourceWithEncoding middleware for correct integration with sourceWithBatch middleware

* gofumpt -w source_middleware_test.go

---------

Co-authored-by: Pavel Berezhnoy <p.berezhnoy@corp.mail.ru>
Co-authored-by: Lovro Mažgon <lovro.mazgon@gmail.com>
---
 source_middleware.go      |  31 ++++++++--
 source_middleware_test.go | 124 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 151 insertions(+), 4 deletions(-)

diff --git a/source_middleware.go b/source_middleware.go
index 1aca5c7c..046c435b 100644
--- a/source_middleware.go
+++ b/source_middleware.go
@@ -433,16 +433,39 @@ func (s *sourceWithEncoding) Read(ctx context.Context) (opencdc.Record, error) {
 		return rec, err
 	}
 
-	if err := s.encodeKey(ctx, &rec); err != nil {
-		return rec, err
-	}
-	if err := s.encodePayload(ctx, &rec); err != nil {
+	if err := s.encode(ctx, &rec); err != nil {
 		return rec, err
 	}
 
 	return rec, nil
 }
 
+func (s *sourceWithEncoding) ReadN(ctx context.Context, n int) ([]opencdc.Record, error) {
+	recs, err := s.Source.ReadN(ctx, n)
+	if err != nil {
+		return recs, err
+	}
+
+	for i := range recs {
+		if err := s.encode(ctx, &recs[i]); err != nil {
+			return recs, fmt.Errorf("unable to encode record %d: %w", i, err)
+		}
+	}
+
+	return recs, nil
+}
+
+func (s *sourceWithEncoding) encode(ctx context.Context, rec *opencdc.Record) error {
+	if err := s.encodeKey(ctx, rec); err != nil {
+		return err
+	}
+	if err := s.encodePayload(ctx, rec); err != nil {
+		return err
+	}
+
+	return nil
+}
+
 func (s *sourceWithEncoding) encodeKey(ctx context.Context, rec *opencdc.Record) error {
 	if _, ok := rec.Key.(opencdc.StructuredData); !ok {
 		// log warning once, to avoid spamming the logs
diff --git a/source_middleware_test.go b/source_middleware_test.go
index 8bdd6505..129ff8ac 100644
--- a/source_middleware_test.go
+++ b/source_middleware_test.go
@@ -637,6 +637,130 @@ func TestSourceWithEncoding_Read(t *testing.T) {
 	}
 }
 
+func TestSourceWithEncoding_ReadN(t *testing.T) {
+	is := is.New(t)
+	ctx := context.Background()
+
+	testDataStruct := opencdc.StructuredData{
+		"foo":   "bar",
+		"long":  int64(1),
+		"float": 2.34,
+		"time":  time.Now().UTC().Truncate(time.Microsecond), // avro precision is microseconds
+	}
+	wantSchema := `{"name":"record","type":"record","fields":[{"name":"float","type":"double"},{"name":"foo","type":"string"},{"name":"long","type":"long"},{"name":"time","type":{"type":"long","logicalType":"timestamp-micros"}}]}`
+
+	customTestSchema, err := schema.Create(ctx, schema.TypeAvro, "custom-test-schema", []byte(wantSchema))
+	is.NoErr(err)
+
+	bytes, err := customTestSchema.Marshal(testDataStruct)
+	is.NoErr(err)
+	testDataRaw := opencdc.RawData(bytes)
+
+	testCases := []struct {
+		name      string
+		inputRecs []opencdc.Record
+		wantRecs  []opencdc.Record
+	}{{
+		name: "no records returned",
+	}, {
+		name: "single record returned",
+		inputRecs: []opencdc.Record{
+			{
+				Key: testDataStruct.Clone(),
+				Payload: opencdc.Change{
+					Before: testDataStruct.Clone(),
+					After:  testDataStruct.Clone(),
+				},
+			},
+		},
+		wantRecs: []opencdc.Record{
+			{
+				Key: testDataRaw,
+				Payload: opencdc.Change{
+					Before: testDataRaw,
+					After:  testDataRaw,
+				},
+			},
+		},
+	}, {
+		name: "multiple records returned",
+		inputRecs: []opencdc.Record{{
+			Key: testDataStruct.Clone(),
+			Payload: opencdc.Change{
+				Before: testDataStruct.Clone(),
+				After:  testDataStruct.Clone(),
+			},
+		}, {
+			Key: testDataStruct.Clone(),
+			Payload: opencdc.Change{
+				Before: testDataStruct.Clone(),
+				After:  testDataStruct.Clone(),
+			},
+		}, {
+			Key: testDataStruct.Clone(),
+			Payload: opencdc.Change{
+				Before: testDataStruct.Clone(),
+				After:  testDataStruct.Clone(),
+			},
+		}},
+		wantRecs: []opencdc.Record{{
+			Key: testDataRaw,
+			Payload: opencdc.Change{
+				Before: testDataRaw,
+				After:  testDataRaw,
+			},
+		}, {
+			Key: testDataRaw,
+			Payload: opencdc.Change{
+				Before: testDataRaw,
+				After:  testDataRaw,
+			},
+		}, {
+			Key: testDataRaw,
+			Payload: opencdc.Change{
+				Before: testDataRaw,
+				After:  testDataRaw,
+			},
+		}},
+	}}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			is := is.New(t)
+			src := NewMockSource(gomock.NewController(t))
+
+			underTest := (&SourceWithEncoding{}).Wrap(src)
+
+			for i := range tc.inputRecs {
+				tc.inputRecs[i].Metadata = map[string]string{
+					opencdc.MetadataCollection:           "foo",
+					opencdc.MetadataKeySchemaSubject:     customTestSchema.Subject,
+					opencdc.MetadataKeySchemaVersion:     strconv.Itoa(customTestSchema.Version),
+					opencdc.MetadataPayloadSchemaSubject: customTestSchema.Subject,
+					opencdc.MetadataPayloadSchemaVersion: strconv.Itoa(customTestSchema.Version),
+				}
+			}
+
+			src.EXPECT().ReadN(ctx, 100).Return(tc.inputRecs, nil)
+
+			got, err := underTest.ReadN(ctx, 100)
+			is.NoErr(err)
+
+			is.Equal(len(got), len(tc.wantRecs))
+
+			for i := range got {
+				gotKey := got[i].Key
+				gotPayloadBefore := got[i].Payload.Before
+				gotPayloadAfter := got[i].Payload.After
+
+				is.Equal("", cmp.Diff(tc.wantRecs[i].Key, gotKey))
+				is.Equal("", cmp.Diff(tc.wantRecs[i].Payload.Before, gotPayloadBefore))
+				is.Equal("", cmp.Diff(tc.wantRecs[i].Payload.After, gotPayloadAfter))
+			}
+		})
+	}
+}
+
 // -- SourceWithBatch --------------------------------------------------
 
 func TestSourceWithBatch_ReadN(t *testing.T) {