diff --git a/go.mod b/go.mod index 52cd93f6a..a88230ee8 100755 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module go.keploy.io/server/v2 go 1.21.0 -replace github.com/jackc/pgproto3/v2 => github.com/keploy/pgproto3/v2 v2.0.2 +replace github.com/jackc/pgproto3/v2 => github.com/keploy/pgproto3/v2 v2.0.5 require ( github.com/Microsoft/go-winio v0.6.1 // indirect diff --git a/go.sum b/go.sum index a52cb48cd..0cf0a5f12 100755 --- a/go.sum +++ b/go.sum @@ -127,8 +127,8 @@ github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8Hm github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/k0kubun/pp/v3 v3.2.0 h1:h33hNTZ9nVFNP3u2Fsgz8JXiF5JINoZfFq4SvKJwNcs= github.com/k0kubun/pp/v3 v3.2.0/go.mod h1:ODtJQbQcIRfAD3N+theGCV1m/CBxweERz2dapdz1EwA= -github.com/keploy/pgproto3/v2 v2.0.2 h1:exp+WlBBWucEmiYsjXezGrhzShdyHWkvQoIXzdj7Vj8= -github.com/keploy/pgproto3/v2 v2.0.2/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/keploy/pgproto3/v2 v2.0.5 h1:8spdNKZ+nOnHVxiimDsqulBRN6viPXPghkA7xppnzJ8= +github.com/keploy/pgproto3/v2 v2.0.5/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= diff --git a/pkg/core/proxy/integrations/generic/generic.go b/pkg/core/proxy/integrations/generic/generic.go index b76be9c06..13a9b74d5 100755 --- a/pkg/core/proxy/integrations/generic/generic.go +++ b/pkg/core/proxy/integrations/generic/generic.go @@ -31,7 +31,7 @@ func (g *Generic) MatchType(_ context.Context, _ []byte) bool { } func (g *Generic) RecordOutgoing(ctx context.Context, src net.Conn, dst net.Conn, mocks chan<- *models.Mock, opts models.OutgoingOptions) error { - logger := g.logger.With(zap.Any("Client IP Address", src.RemoteAddr().String()), zap.Any("Client ConnectionID", util.GetNextID()), zap.Any("Destination ConnectionID", util.GetNextID())) + logger := g.logger.With(zap.Any("Client IP Address", src.RemoteAddr().String()), zap.Any("Client ConnectionID", ctx.Value(models.ClientConnectionIDKey).(string)), zap.Any("Destination ConnectionID", ctx.Value(models.DestConnectionIDKey).(string))) reqBuf, err := util.ReadInitialBuf(ctx, logger, src) if err != nil { diff --git a/pkg/core/proxy/integrations/grpc/grpc.go b/pkg/core/proxy/integrations/grpc/grpc.go index ac4c89acc..af16eb867 100644 --- a/pkg/core/proxy/integrations/grpc/grpc.go +++ b/pkg/core/proxy/integrations/grpc/grpc.go @@ -34,7 +34,7 @@ func (g *Grpc) MatchType(_ context.Context, reqBuf []byte) bool { } func (g *Grpc) RecordOutgoing(ctx context.Context, src net.Conn, dst net.Conn, mocks chan<- *models.Mock, opts models.OutgoingOptions) error { - logger := g.logger.With(zap.Any("Client IP Address", src.RemoteAddr().String()), zap.Any("Client ConnectionID", util.GetNextID()), zap.Any("Destination ConnectionID", util.GetNextID())) + logger := g.logger.With(zap.Any("Client IP Address", src.RemoteAddr().String()), zap.Any("Client ConnectionID", ctx.Value(models.ClientConnectionIDKey).(string)), zap.Any("Destination ConnectionID", ctx.Value(models.DestConnectionIDKey).(string))) reqBuf, err := util.ReadInitialBuf(ctx, logger, src) if err != nil { diff --git a/pkg/core/proxy/integrations/http/http.go b/pkg/core/proxy/integrations/http/http.go index 8464a4745..e7801efe6 100755 --- a/pkg/core/proxy/integrations/http/http.go +++ b/pkg/core/proxy/integrations/http/http.go @@ -59,7 +59,7 @@ func (h *HTTP) MatchType(_ context.Context, buf []byte) bool { } func (h *HTTP) RecordOutgoing(ctx context.Context, src net.Conn, dst net.Conn, mocks chan<- *models.Mock, opts models.OutgoingOptions) error { - logger := h.logger.With(zap.Any("Client IP Address", src.RemoteAddr().String()), zap.Any("Client ConnectionID", util.GetNextID()), zap.Any("Destination ConnectionID", util.GetNextID())) + logger := h.logger.With(zap.Any("Client IP Address", src.RemoteAddr().String()), zap.Any("Client ConnectionID", ctx.Value(models.ClientConnectionIDKey).(string)), zap.Any("Destination ConnectionID", ctx.Value(models.DestConnectionIDKey).(string))) h.logger.Debug("Recording the outgoing http call in record mode") diff --git a/pkg/core/proxy/integrations/mongo/mongo.go b/pkg/core/proxy/integrations/mongo/mongo.go index de1f9b377..6b1cddf89 100644 --- a/pkg/core/proxy/integrations/mongo/mongo.go +++ b/pkg/core/proxy/integrations/mongo/mongo.go @@ -44,7 +44,7 @@ func (m *Mongo) MatchType(_ context.Context, buffer []byte) bool { } func (m *Mongo) RecordOutgoing(ctx context.Context, src net.Conn, dst net.Conn, mocks chan<- *models.Mock, opts models.OutgoingOptions) error { - logger := m.logger.With(zap.Any("Client IP Address", src.RemoteAddr().String()), zap.Any("Client ConnectionID", util.GetNextID()), zap.Any("Destination ConnectionID", util.GetNextID())) + logger := m.logger.With(zap.Any("Client IP Address", src.RemoteAddr().String()), zap.Any("Client ConnectionID", ctx.Value(models.ClientConnectionIDKey).(string)), zap.Any("Destination ConnectionID", ctx.Value(models.DestConnectionIDKey).(string))) reqBuf, err := util.ReadInitialBuf(ctx, logger, src) if err != nil { utils.LogError(logger, err, "failed to read the initial mongo message") diff --git a/pkg/core/proxy/integrations/mysql/mysql.go b/pkg/core/proxy/integrations/mysql/mysql.go index 5c93c246e..4e014c9ce 100644 --- a/pkg/core/proxy/integrations/mysql/mysql.go +++ b/pkg/core/proxy/integrations/mysql/mysql.go @@ -36,7 +36,7 @@ func (m *MySQL) MatchType(_ context.Context, _ []byte) bool { } func (m *MySQL) RecordOutgoing(ctx context.Context, src net.Conn, dst net.Conn, mocks chan<- *models.Mock, opts models.OutgoingOptions) error { - logger := m.logger.With(zap.Any("Client IP Address", src.RemoteAddr().String()), zap.Any("Client ConnectionID", util.GetNextID()), zap.Any("Destination ConnectionID", util.GetNextID())) + logger := m.logger.With(zap.Any("Client IP Address", src.RemoteAddr().String()), zap.Any("Client ConnectionID", ctx.Value(models.ClientConnectionIDKey).(string)), zap.Any("Destination ConnectionID", ctx.Value(models.DestConnectionIDKey).(string))) err := encodeMySQL(ctx, logger, src, dst, mocks, opts) if err != nil { diff --git a/pkg/core/proxy/integrations/postgres/v1/decode.go b/pkg/core/proxy/integrations/postgres/v1/decode.go index 48b37a35f..2650b1928 100644 --- a/pkg/core/proxy/integrations/postgres/v1/decode.go +++ b/pkg/core/proxy/integrations/postgres/v1/decode.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net" + "strings" "time" "go.keploy.io/server/v2/pkg/core/proxy/integrations" @@ -19,8 +20,10 @@ import ( func decodePostgres(ctx context.Context, logger *zap.Logger, reqBuf []byte, clientConn net.Conn, dstCfg *integrations.ConditionalDstCfg, mockDb integrations.MockMemDb, _ models.OutgoingOptions) error { pgRequests := [][]byte{reqBuf} errCh := make(chan error, 1) - defer close(errCh) + go func(errCh chan error, pgRequests [][]byte) { + // close should be called from the producer of the channel + defer close(errCh) for { // Since protocol packets have to be parsed for checking stream end, // clientConnection have deadline for read to determine the end of stream. @@ -34,12 +37,11 @@ func decodePostgres(ctx context.Context, logger *zap.Logger, reqBuf []byte, clie for { buffer, err := pUtil.ReadBytes(ctx, logger, clientConn) if err != nil { - if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) && err != nil { + if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { if err == io.EOF { logger.Debug("EOF error received from client. Closing conn in postgres !!") errCh <- err } - //TODO: why debug log sarthak? logger.Debug("failed to read the request message in proxy for postgres dependency") errCh <- err } @@ -55,7 +57,6 @@ func decodePostgres(ctx context.Context, logger *zap.Logger, reqBuf []byte, clie logger.Debug("the postgres request buffer is empty") continue } - matched, pgResponses, err := matchingReadablePG(ctx, logger, pgRequests, mockDb) if err != nil { errCh <- fmt.Errorf("error while matching tcs mocks %v", err) @@ -97,3 +98,46 @@ func decodePostgres(ctx context.Context, logger *zap.Logger, reqBuf []byte, clie return err } } + +type QueryData struct { + PrepIdentifier string `json:"PrepIdentifier" yaml:"PrepIdentifier"` + Query string `json:"Query" yaml:"Query"` +} + +type PrepMap map[string][]QueryData + +type TestPrepMap map[string][]QueryData + +func getRecordPrepStatement(allMocks []*models.Mock) PrepMap { + preparedstatement := make(PrepMap) + for _, v := range allMocks { + if v.Kind != "Postgres" { + continue + } + for _, req := range v.Spec.PostgresRequests { + var querydata []QueryData + psMap := make(map[string]string) + if len(req.PacketTypes) > 0 && req.PacketTypes[0] != "p" && req.Identfier != "StartupRequest" { + p := 0 + for _, header := range req.PacketTypes { + if header == "P" { + if strings.Contains(req.Parses[p].Name, "S_") { + psMap[req.Parses[p].Query] = req.Parses[p].Name + querydata = append(querydata, QueryData{PrepIdentifier: req.Parses[p].Name, + Query: req.Parses[p].Query, + }) + + } + p++ + } + } + } + // also append the query data for the prepared statement + if len(querydata) > 0 { + preparedstatement[v.ConnectionID] = append(preparedstatement[v.ConnectionID], querydata...) + } + } + + } + return preparedstatement +} diff --git a/pkg/core/proxy/integrations/postgres/v1/encode.go b/pkg/core/proxy/integrations/postgres/v1/encode.go index e18059aa3..0f59dfbc8 100755 --- a/pkg/core/proxy/integrations/postgres/v1/encode.go +++ b/pkg/core/proxy/integrations/postgres/v1/encode.go @@ -17,13 +17,6 @@ import ( ) func encodePostgres(ctx context.Context, logger *zap.Logger, reqBuf []byte, clientConn, destConn net.Conn, mocks chan<- *models.Mock, _ models.OutgoingOptions) error { - //closing the destination conn - defer func(destConn net.Conn) { - err := destConn.Close() - if err != nil { - utils.LogError(logger, err, "failed to close the destination connection") - } - }(destConn) logger.Debug("Inside the encodePostgresOutgoing function") var pgRequests []models.Backend @@ -79,10 +72,7 @@ func encodePostgres(ctx context.Context, logger *zap.Logger, reqBuf []byte, clie clientBuffChan := make(chan []byte) destBuffChan := make(chan []byte) - errChan := make(chan error) - defer close(clientBuffChan) - defer close(destBuffChan) - defer close(errChan) + errChan := make(chan error, 1) //get the error group from the context g := ctx.Value(models.ErrGroupKey).(*errgroup.Group) @@ -90,6 +80,7 @@ func encodePostgres(ctx context.Context, logger *zap.Logger, reqBuf []byte, clie // read requests from client g.Go(func() error { defer utils.Recover(logger) + defer close(clientBuffChan) pUtil.ReadBuffConn(ctx, logger, clientConn, clientBuffChan, errChan) return nil }) @@ -97,10 +88,19 @@ func encodePostgres(ctx context.Context, logger *zap.Logger, reqBuf []byte, clie // read responses from destination g.Go(func() error { defer utils.Recover(logger) + defer close(destBuffChan) pUtil.ReadBuffConn(ctx, logger, destConn, destBuffChan, errChan) return nil }) + go func() { + err := g.Wait() + if err != nil { + logger.Info("error group is returning an error", zap.Error(err)) + } + close(errChan) + }() + prevChunkWasReq := false logger.Debug("the iteration for the pg request starts", zap.Any("pgReqs", len(pgRequests)), zap.Any("pgResps", len(pgResponses))) @@ -125,6 +125,7 @@ func encodePostgres(ctx context.Context, logger *zap.Logger, reqBuf []byte, clie ResTimestampMock: resTimestampMock, Metadata: metadata, }, + ConnectionID: ctx.Value(models.ClientConnectionIDKey).(string), } return ctx.Err() } @@ -153,6 +154,7 @@ func encodePostgres(ctx context.Context, logger *zap.Logger, reqBuf []byte, clie ResTimestampMock: resTimestampMock, Metadata: metadata, }, + ConnectionID: ctx.Value(models.ClientConnectionIDKey).(string), } pgRequests = []models.Backend{} pgResponses = []models.Frontend{} @@ -295,17 +297,19 @@ func encodePostgres(ctx context.Context, logger *zap.Logger, reqBuf []byte, clie } if pg.FrontendWrapper.MsgType == 'C' { pg.FrontendWrapper.CommandComplete = *msg.(*pgproto3.CommandComplete) + // empty the command tag + pg.FrontendWrapper.CommandComplete.CommandTag = []byte{} pg.FrontendWrapper.CommandCompletes = append(pg.FrontendWrapper.CommandCompletes, pg.FrontendWrapper.CommandComplete) } - if pg.FrontendWrapper.DataRow.RowValues != nil { + if pg.FrontendWrapper.MsgType == 'D' && pg.FrontendWrapper.DataRow.RowValues != nil { // Create a new slice for each DataRow valuesCopy := make([]string, len(pg.FrontendWrapper.DataRow.RowValues)) copy(valuesCopy, pg.FrontendWrapper.DataRow.RowValues) row := pgproto3.DataRow{ RowValues: valuesCopy, // Use the copy of the values + Values: pg.FrontendWrapper.DataRow.Values, } - // fmt.Println("row is ", row) dataRows = append(dataRows, row) } } @@ -362,7 +366,7 @@ func encodePostgres(ctx context.Context, logger *zap.Logger, reqBuf []byte, clie if err != nil { logger.Debug("failed to decode the response message in proxy for postgres dependency", zap.Error(err)) } - if (len(afterEncoded) != len(buffer) && pgMock.PacketTypes[0] != "R") || len(pgMock.DataRows) > 0 { + if len(afterEncoded) != len(buffer) && pgMock.PacketTypes[0] != "R" { logger.Debug("the length of the encoded buffer is not equal to the length of the original buffer", zap.Any("after_encoded", len(afterEncoded)), zap.Any("buffer", len(buffer))) pgMock.Payload = bufStr } diff --git a/pkg/core/proxy/integrations/postgres/v1/match.go b/pkg/core/proxy/integrations/postgres/v1/match.go index 9f95164fe..beb487207 100644 --- a/pkg/core/proxy/integrations/postgres/v1/match.go +++ b/pkg/core/proxy/integrations/postgres/v1/match.go @@ -2,8 +2,11 @@ package v1 import ( "context" + "encoding/base64" "fmt" "math" + "reflect" + "strings" "github.com/jackc/pgproto3/v2" "go.keploy.io/server/v2/pkg/core/proxy/integrations" @@ -12,6 +15,53 @@ import ( "go.uber.org/zap" ) +var testmap TestPrepMap + +func getTestPS(reqBuff [][]byte, logger *zap.Logger, ConnectionID string) { + // maintain a map of current prepared statements and their corresponding connection id + // if it's the prepared statement match the query with the recorded prepared statement and return the response of that matched prepared statement at that connection + // so if parse is coming save to a same map + actualPgReq := decodePgRequest(reqBuff[0], logger) + if actualPgReq == nil { + return + } + testmap2 := make(TestPrepMap) + if testmap != nil { + testmap2 = testmap + } + querydata := make([]QueryData, 0) + if len(actualPgReq.PacketTypes) > 0 && actualPgReq.PacketTypes[0] != "p" && actualPgReq.Identfier != "StartupRequest" { + p := 0 + for _, header := range actualPgReq.PacketTypes { + if header == "P" { + if strings.Contains(actualPgReq.Parses[p].Name, "S_") && !IsValuePresent(ConnectionID, actualPgReq.Parses[p].Name) { + querydata = append(querydata, QueryData{PrepIdentifier: actualPgReq.Parses[p].Name, Query: actualPgReq.Parses[p].Query}) + } + p++ + } + } + } + + // also append the query data for the prepared statement + if len(querydata) > 0 { + testmap2[ConnectionID] = append(testmap2[ConnectionID], querydata...) + // fmt.Println("Test Prepared statement Map", testmap2) + testmap = testmap2 + } + +} + +func IsValuePresent(connectionid string, value string) bool { + if testmap != nil { + for _, v := range testmap[connectionid] { + if v.PrepIdentifier == value { + return true + } + } + } + return false +} + func matchingReadablePG(ctx context.Context, logger *zap.Logger, requestBuffers [][]byte, mockDb integrations.MockMemDb) (bool, []models.Frontend, error) { for { select { @@ -24,6 +74,20 @@ func matchingReadablePG(ctx context.Context, logger *zap.Logger, requestBuffers return false, nil, fmt.Errorf("error while getting tcs mocks %v", err) } + ConnectionID := ctx.Value(models.ClientConnectionIDKey).(string) + + recordedPrep := getRecordPrepStatement(tcsMocks) + reqGoingOn := decodePgRequest(requestBuffers[0], logger) + if reqGoingOn != nil { + logger.Debug("PacketTypes", zap.Any("PacketTypes", reqGoingOn.PacketTypes)) + // fmt.Println("REQUEST GOING ON - ", reqGoingOn) + logger.Debug("ConnectionId-", zap.String("ConnectionId", ConnectionID)) + logger.Debug("TestMap*****", zap.Any("TestMap", testmap)) + } + // if recordedPrep != nil { + // fmt.Println("PREPARED STATEMENT", recordedPrep) + // } + var sortFlag = true var sortedTcsMocks []*models.Mock var matchedMock *models.Mock @@ -44,35 +108,38 @@ func matchingReadablePG(ctx context.Context, logger *zap.Logger, requestBuffers } } + initMock := *mock if len(mock.Spec.PostgresRequests) == len(requestBuffers) { - for requestIndex, reqBuf := range requestBuffers { - if ctx.Err() != nil { - return false, nil, ctx.Err() - } - bufStr := util.EncodeBase64(reqBuf) - encoded, err := postgresDecoderBackend(mock.Spec.PostgresRequests[requestIndex]) + for requestIndex, reqBuff := range requestBuffers { + bufStr := base64.StdEncoding.EncodeToString(reqBuff) + encodedMock, err := postgresDecoderBackend(mock.Spec.PostgresRequests[requestIndex]) if err != nil { logger.Debug("Error while decoding postgres request", zap.Error(err)) } - if mock.Spec.PostgresRequests[requestIndex].Identfier == "StartupRequest" { - logger.Debug("CHANGING TO MD5 for Response") - mock.Spec.PostgresResponses[requestIndex].AuthType = 5 - continue - } + switch { + case bufStr == "AAAACATSFi8=": + ssl := models.Frontend{ + Payload: "Tg==", + } + return true, []models.Frontend{ssl}, nil + case mock.Spec.PostgresRequests[requestIndex].Identfier == "StartupRequest" && isStartupPacket(reqBuff) && mock.Spec.PostgresRequests[requestIndex].Payload != "AAAACATSFi8=" && mock.Spec.PostgresResponses[requestIndex].AuthType == 10: + logger.Debug("CHANGING TO MD5 for Response", zap.String("mock", mock.Name), zap.String("Req", bufStr)) + initMock.Spec.PostgresResponses[requestIndex].AuthType = 5 + return true, initMock.Spec.PostgresResponses, nil + case len(encodedMock) > 0 && encodedMock[0] == 'p' && mock.Spec.PostgresRequests[requestIndex].PacketTypes[0] == "p" && reqBuff[0] == 'p': + logger.Debug("CHANGING TO MD5 for Request and Response", zap.String("mock", mock.Name), zap.String("Req", bufStr)) - if len(encoded) > 0 && encoded[0] == 'p' { - logger.Debug("CHANGING TO MD5 for Request and Response") - mock.Spec.PostgresRequests[requestIndex].PasswordMessage.Password = "md5fe4f2f657f01fa1dd9d111d5391e7c07" + initMock.Spec.PostgresRequests[requestIndex].PasswordMessage.Password = "md5fe4f2f657f01fa1dd9d111d5391e7c07" - mock.Spec.PostgresResponses[requestIndex].PacketTypes = []string{"R", "S", "S", "S", "S", "S", "S", "S", "S", "S", "S", "S", "K", "Z"} - mock.Spec.PostgresResponses[requestIndex].AuthType = 0 - mock.Spec.PostgresResponses[requestIndex].BackendKeyData = pgproto3.BackendKeyData{ + initMock.Spec.PostgresResponses[requestIndex].PacketTypes = []string{"R", "S", "S", "S", "S", "S", "S", "S", "S", "S", "S", "S", "K", "Z"} + initMock.Spec.PostgresResponses[requestIndex].AuthType = 0 + initMock.Spec.PostgresResponses[requestIndex].BackendKeyData = pgproto3.BackendKeyData{ ProcessID: 2613, SecretKey: 824670820, } - mock.Spec.PostgresResponses[requestIndex].ReadyForQuery.TxStatus = 73 - mock.Spec.PostgresResponses[requestIndex].ParameterStatusCombined = []pgproto3.ParameterStatus{ + initMock.Spec.PostgresResponses[requestIndex].ReadyForQuery.TxStatus = 73 + initMock.Spec.PostgresResponses[requestIndex].ParameterStatusCombined = []pgproto3.ParameterStatus{ { Name: "application_name", Value: "", @@ -118,16 +185,13 @@ func matchingReadablePG(ctx context.Context, logger *zap.Logger, requestBuffers Value: "Etc/UTC", }, } + return true, initMock.Spec.PostgresResponses, nil } - if bufStr == "AAAACATSFi8=" { - ssl := models.Frontend{ - Payload: "Tg==", - } - return true, []models.Frontend{ssl}, nil - } } } + // maintain test prepare statement map for each connection id + getTestPS(requestBuffers, logger, ConnectionID) } logger.Debug("Sorted Mocks: ", zap.Any("Len of sortedTcsMocks", len(sortedTcsMocks))) @@ -138,19 +202,49 @@ func matchingReadablePG(ctx context.Context, logger *zap.Logger, requestBuffers // give more priority to sorted like if you find more than 0.5 in sorted then return that if len(sortedTcsMocks) > 0 { sorted = true + idx1, newMock := findPGStreamMatch(sortedTcsMocks, requestBuffers, logger, sorted, ConnectionID, recordedPrep) + if idx1 != -1 { + matched = true + matchedMock = tcsMocks[idx1] + if newMock != nil { + matchedMock = newMock + } + // fmt.Println("Matched In Absolute Custom Matching for sorted!!!", matchedMock.Name) + } idx = findBinaryStreamMatch(logger, sortedTcsMocks, requestBuffers, sorted) - if idx != -1 { + if idx != -1 && !matched { matched = true matchedMock = tcsMocks[idx] + // fmt.Println("Matched In Binary Matching for Sorted", matchedMock.Name) } } if !matched { sorted = false + idx1, newMock := findPGStreamMatch(sortedTcsMocks, requestBuffers, logger, sorted, ConnectionID, recordedPrep) + if idx1 != -1 { + matched = true + matchedMock = tcsMocks[idx1] + if newMock != nil { + matchedMock = newMock + } + // fmt.Println("Matched In Absolute Custom Matching for Unsorted", matchedMock.Name) + } idx = findBinaryStreamMatch(logger, tcsMocks, requestBuffers, sorted) - if idx != -1 { + // check if the validate the query with the matched mock + // if the query is same then return the response of that mock + var isValid = true + if idx != -1 && len(sortedTcsMocks) != 0 { + isValid, newMock = validateMock(tcsMocks, idx, requestBuffers, logger) + logger.Debug("Is Valid", zap.Bool("Is Valid", isValid)) + } + if idx != -1 && !matched { matched = true matchedMock = tcsMocks[idx] + if newMock != nil && !isValid { + matchedMock = newMock + } + // fmt.Println("Matched In Binary Matching for Unsorted", matchedMock.Name) } } @@ -237,3 +331,413 @@ func fuzzyCheck(encoded, reqBuf []byte) float64 { similarity := util.JaccardSimilarity(shingles1, shingles2) return similarity } + +func findPGStreamMatch(tcsMocks []*models.Mock, requestBuffers [][]byte, logger *zap.Logger, isSorted bool, connectionID string, recordedPrep PrepMap) (int, *models.Mock) { + + mxIdx := -1 + + match := false + // loop for the exact match of the request + for idx, mock := range tcsMocks { + if len(mock.Spec.PostgresRequests) == len(requestBuffers) { + for _, reqBuff := range requestBuffers { + actualPgReq := decodePgRequest(reqBuff, logger) + if actualPgReq == nil { + return -1, nil + } + + // here handle cases of prepared statement very carefully + match, err := compareExactMatch(mock, actualPgReq, logger) + if err != nil { + logger.Error("Error while matching exact match", zap.Error(err)) + continue + } + if match { + return idx, nil + } + } + } + } + if !isSorted { + return mxIdx, nil + } + // loop for the ps match of the request + if !match { + for idx, mock := range tcsMocks { + if len(mock.Spec.PostgresRequests) == len(requestBuffers) { + for _, reqBuff := range requestBuffers { + actualPgReq := decodePgRequest(reqBuff, logger) + if actualPgReq == nil { + return -1, nil + } + // just matching the corresponding PS in this case there is no need to edit the mock + match, newBindPs, err := PreparedStatementMatch(mock, actualPgReq, logger, connectionID, recordedPrep) + if err != nil { + logger.Error("Error while matching prepared statements", zap.Error(err)) + } + + if match { + logger.Debug("New Bind Prepared Statement", zap.Any("New Bind Prepared Statement", newBindPs), zap.String("ConnectionId", connectionID), zap.String("Mock Name", mock.Name)) + return idx, nil + } + // just check the query + if reflect.DeepEqual(actualPgReq.PacketTypes, []string{"P", "B", "D", "E"}) && reflect.DeepEqual(mock.Spec.PostgresRequests[0].PacketTypes, []string{"P", "B", "D", "E"}) { + if mock.Spec.PostgresRequests[0].Parses[0].Query == actualPgReq.Parses[0].Query { + return idx, nil + } + } + } + } + } + } + + if !match { + for idx, mock := range tcsMocks { + if len(mock.Spec.PostgresRequests) == len(requestBuffers) { + for _, reqBuff := range requestBuffers { + actualPgReq := decodePgRequest(reqBuff, logger) + if actualPgReq == nil { + return -1, nil + } + + // have to ignore first parse message of begin read only + // should compare only query in the parse message + if len(actualPgReq.PacketTypes) != len(mock.Spec.PostgresRequests[0].PacketTypes) { + //check for begin read only + if len(actualPgReq.PacketTypes) > 0 && len(mock.Spec.PostgresRequests[0].PacketTypes) > 0 { + + ischanged, newMock := changeResToPS(mock, actualPgReq, logger, connectionID) + + if ischanged { + return idx, newMock + } + continue + + } + + } + } + } + } + } + + return mxIdx, nil +} + +// check what are the queries for the given ps of actualPgReq +// check if the execute query is present for that or not +// mark that mock true and return the response by changing the res format like +// postgres data types acc to result set format +func changeResToPS(mock *models.Mock, actualPgReq *models.Backend, logger *zap.Logger, connectionID string) (bool, *models.Mock) { + actualpackets := actualPgReq.PacketTypes + mockPackets := mock.Spec.PostgresRequests[0].PacketTypes + + // [P, B, E, P, B, D, E] => [B, E, B, E] + // write code that of packet is ["B", "E"] and mockPackets ["P", "B", "D", "E"] handle it in case1 + // and if packet is [B, E, B, E] and mockPackets [P, B, E, P, B, D, E] handle it in case2 + + ischanged := false + var newMock *models.Mock + // [B E P D B E] + // [P, B, E, P, B, D, E] -> [B, E, P, B, D, E] + if (reflect.DeepEqual(actualpackets, []string{"B", "E", "P", "D", "B", "E"}) || reflect.DeepEqual(actualpackets, []string{"B", "E", "P", "B", "D", "E"})) && reflect.DeepEqual(mockPackets, []string{"P", "B", "E", "P", "B", "D", "E"}) { + // fmt.Println("Handling Case 1 for mock", mock.Name) + // handleCase1(packets, mockPackets) + // also check if the second query is same or not + // fmt.Println("ActualPgReq", actualPgReq.Parses[0].Query, "MOCK REQ 1", mock.Spec.PostgresRequests[0].Parses[0].Query, "MOCK REQ 2", mock.Spec.PostgresRequests[0].Parses[1].Query) + if actualPgReq.Parses[0].Query != mock.Spec.PostgresRequests[0].Parses[1].Query { + return false, nil + } + newMock = sliceCommandTag(mock, logger, testmap[connectionID], actualPgReq, 1) + return true, newMock + } + + // case 2 + var ps string + if reflect.DeepEqual(actualpackets, []string{"B", "E"}) && reflect.DeepEqual(mockPackets, []string{"P", "B", "D", "E"}) { + // fmt.Println("Handling Case 2 for mock", mock.Name) + ps = actualPgReq.Binds[0].PreparedStatement + for _, v := range testmap[connectionID] { + if v.Query == mock.Spec.PostgresRequests[0].Parses[0].Query && v.PrepIdentifier == ps { + ischanged = true + break + } + } + } + + if ischanged { + // if strings.Contains(ps, "S_") { + // fmt.Println("Inside Prepared Statement") + newMock = sliceCommandTag(mock, logger, testmap[connectionID], actualPgReq, 2) + // } + return true, newMock + } + + // packets = []string{"B", "E", "B", "E"} + // mockPackets = []string{"P", "B", "E", "P", "B", "D", "E"} + + // Case 3 + if reflect.DeepEqual(actualpackets, []string{"B", "E", "B", "E"}) && reflect.DeepEqual(mockPackets, []string{"P", "B", "E", "P", "B", "D", "E"}) { + // fmt.Println("Handling Case 3 for mock", mock.Name) + ischanged1 := false + ps1 := actualPgReq.Binds[0].PreparedStatement + for _, v := range testmap[connectionID] { + if v.Query == mock.Spec.PostgresRequests[0].Parses[0].Query && v.PrepIdentifier == ps1 { + ischanged1 = true + break + } + } + //Matched In Binary Matching for Unsorted mock-222 + ischanged2 := false + ps2 := actualPgReq.Binds[1].PreparedStatement + for _, v := range testmap[connectionID] { + if v.Query == mock.Spec.PostgresRequests[0].Parses[1].Query && v.PrepIdentifier == ps2 { + ischanged2 = true + break + } + } + if ischanged1 && ischanged2 { + newMock = sliceCommandTag(mock, logger, testmap[connectionID], actualPgReq, 2) + return true, newMock + } + } + + // Case 4 + if reflect.DeepEqual(actualpackets, []string{"B", "E", "B", "E"}) && reflect.DeepEqual(mockPackets, []string{"B", "E", "P", "B", "D", "E"}) { + // fmt.Println("Handling Case 4 for mock", mock.Name) + // get the query for the prepared statement of test mode + ischanged := false + ps := actualPgReq.Binds[1].PreparedStatement + for _, v := range testmap[connectionID] { + if v.Query == mock.Spec.PostgresRequests[0].Parses[0].Query && v.PrepIdentifier == ps { + ischanged = true + break + } + } + if ischanged { + newMock = sliceCommandTag(mock, logger, testmap[connectionID], actualPgReq, 2) + return true, newMock + } + + } + + return false, nil + +} + +func PreparedStatementMatch(mock *models.Mock, actualPgReq *models.Backend, logger *zap.Logger, ConnectionID string, recordedPrep PrepMap) (bool, []string, error) { + // fmt.Println("Inside PreparedStatementMatch") + // check the current Query associated with the connection id and Identifier + ifps := checkIfps(actualPgReq.PacketTypes) + if !ifps { + return false, nil, nil + } + // check if given mock is a prepared statement + ifpsMock := checkIfps(mock.Spec.PostgresRequests[0].PacketTypes) + if !ifpsMock { + return false, nil, nil + } + + if len(mock.Spec.PostgresRequests[0].PacketTypes) != len(actualPgReq.PacketTypes) { + return false, nil, nil + } + + // get all the binds from the actualPgReq + binds := actualPgReq.Binds + newBinPreparedStatement := make([]string, 0) + mockBinds := mock.Spec.PostgresRequests[0].Binds + mockConn := mock.ConnectionID + var foo = false + for idx, bind := range binds { + currentPs := bind.PreparedStatement + currentQuerydata := testmap[ConnectionID] + currentQuery := "" + // check in the map that what's the current query for this preparedstatement + // then will check what is the recorded prepared statement for this query + for _, v := range currentQuerydata { + if v.PrepIdentifier == currentPs { + // fmt.Println("Current query for this identifier is ", v.Query) + currentQuery = v.Query + break + } + } + logger.Debug("Current Query for this prepared statement", zap.String("Query", currentQuery), zap.String("Identifier", currentPs)) + foo = false + + // for _, mb := range mockBinds { + // check if the query for mock ps (v.PreparedStatement) is same as the current query + for _, querydata := range recordedPrep[mockConn] { + if querydata.Query == currentQuery && mockBinds[idx].PreparedStatement == querydata.PrepIdentifier { + logger.Debug("Matched with the recorded prepared statement with Identifier and connectionID is", zap.String("Identifier", querydata.PrepIdentifier), zap.String("ConnectionId", mockConn), zap.String("Current Identifier", currentPs), zap.String("Query", currentQuery)) + foo = true + break + } + // } + } + } + if foo { + return true, newBinPreparedStatement, nil + } + + return false, nil, nil +} + +func compareExactMatch(mock *models.Mock, actualPgReq *models.Backend, logger *zap.Logger) (bool, error) { + logger.Debug("Inside CompareExactMatch") + // have to ignore first parse message of begin read only + // should compare only query in the parse message + if len(actualPgReq.PacketTypes) != len(mock.Spec.PostgresRequests[0].PacketTypes) { + return false, nil + } + + // call a separate function for matching prepared statements + for idx, v := range actualPgReq.PacketTypes { + if v != mock.Spec.PostgresRequests[0].PacketTypes[idx] { + return false, nil + } + } + // IsPreparedStatement(mock, actualPgReq, logger, ConnectionId) + + // this will give me the + var ( + p, b, e int = 0, 0, 0 + ) + for i := 0; i < len(actualPgReq.PacketTypes); i++ { + switch actualPgReq.PacketTypes[i] { + case "P": + // fmt.Println("Inside P") + p++ + if actualPgReq.Parses[p-1].Query != mock.Spec.PostgresRequests[0].Parses[p-1].Query { + return false, nil + } + + if actualPgReq.Parses[p-1].Name != mock.Spec.PostgresRequests[0].Parses[p-1].Name { + return false, nil + } + + if len(actualPgReq.Parses[p-1].ParameterOIDs) != len(mock.Spec.PostgresRequests[0].Parses[p-1].ParameterOIDs) { + return false, nil + } + for j := 0; j < len(actualPgReq.Parses[p-1].ParameterOIDs); j++ { + if actualPgReq.Parses[p-1].ParameterOIDs[j] != mock.Spec.PostgresRequests[0].Parses[p-1].ParameterOIDs[j] { + return false, nil + } + } + + case "B": + // fmt.Println("Inside B") + b++ + if actualPgReq.Binds[b-1].DestinationPortal != mock.Spec.PostgresRequests[0].Binds[b-1].DestinationPortal { + return false, nil + } + + if actualPgReq.Binds[b-1].PreparedStatement != mock.Spec.PostgresRequests[0].Binds[b-1].PreparedStatement { + return false, nil + } + + if len(actualPgReq.Binds[b-1].ParameterFormatCodes) != len(mock.Spec.PostgresRequests[0].Binds[b-1].ParameterFormatCodes) { + return false, nil + } + for j := 0; j < len(actualPgReq.Binds[b-1].ParameterFormatCodes); j++ { + if actualPgReq.Binds[b-1].ParameterFormatCodes[j] != mock.Spec.PostgresRequests[0].Binds[b-1].ParameterFormatCodes[j] { + return false, nil + } + } + if len(actualPgReq.Binds[b-1].Parameters) != len(mock.Spec.PostgresRequests[0].Binds[b-1].Parameters) { + return false, nil + } + for j := 0; j < len(actualPgReq.Binds[b-1].Parameters); j++ { + for _, v := range actualPgReq.Binds[b-1].Parameters[j] { + if v != mock.Spec.PostgresRequests[0].Binds[b-1].Parameters[j][0] { + return false, nil + } + } + } + if len(actualPgReq.Binds[b-1].ResultFormatCodes) != len(mock.Spec.PostgresRequests[0].Binds[b-1].ResultFormatCodes) { + return false, nil + } + for j := 0; j < len(actualPgReq.Binds[b-1].ResultFormatCodes); j++ { + if actualPgReq.Binds[b-1].ResultFormatCodes[j] != mock.Spec.PostgresRequests[0].Binds[b-1].ResultFormatCodes[j] { + return false, nil + } + } + + case "E": + // fmt.Println("Inside E") + e++ + if actualPgReq.Executes[e-1].Portal != mock.Spec.PostgresRequests[0].Executes[e-1].Portal { + return false, nil + } + if actualPgReq.Executes[e-1].MaxRows != mock.Spec.PostgresRequests[0].Executes[e-1].MaxRows { + return false, nil + } + + case "c": + if actualPgReq.CopyDone != mock.Spec.PostgresRequests[0].CopyDone { + return false, nil + } + case "H": + if actualPgReq.CopyFail.Message != mock.Spec.PostgresRequests[0].CopyFail.Message { + return false, nil + } + case "Q": + if actualPgReq.Query != mock.Spec.PostgresRequests[0].Query { + return false, nil + } + default: + return false, nil + } + } + return true, nil +} + +// make this in such a way if it returns -1 then we will continue with the original mock +func validateMock(tcsMocks []*models.Mock, idx int, requestBuffers [][]byte, logger *zap.Logger) (bool, *models.Mock) { + + actualPgReq := decodePgRequest(requestBuffers[0], logger) + if actualPgReq == nil { + return true, nil + } + mock := tcsMocks[idx].Spec.PostgresRequests[0] + if len(mock.PacketTypes) == len(actualPgReq.PacketTypes) { + if reflect.DeepEqual(tcsMocks[idx].Spec.PostgresRequests[0].PacketTypes, []string{"B", "E", "P", "B", "D", "E"}) { + if mock.Parses[0].Query == actualPgReq.Parses[0].Query { + return true, nil + } + } + if reflect.DeepEqual(mock.PacketTypes, []string{"B", "E", "B", "E"}) { + // fmt.Println("Inside Validate Mock for B, E, B, E") + return true, nil + } + if reflect.DeepEqual(mock.PacketTypes, []string{"B", "E"}) { + // fmt.Println("Inside Validate Mock for B, E") + copyMock := *tcsMocks[idx] + copyMock.Spec.PostgresResponses[0].PacketTypes = []string{"2", "C", "Z"} + copyMock.Spec.PostgresResponses[0].Payload = "" + return false, ©Mock + } + if reflect.DeepEqual(mock.PacketTypes, []string{"P", "B", "D", "E"}) { + // fmt.Println("Inside Validate Mock for P, B, D, E") + copyMock := *tcsMocks[idx] + copyMock.Spec.PostgresResponses[0].PacketTypes = []string{"1", "2", "T", "C", "Z"} + copyMock.Spec.PostgresResponses[0].Payload = "" + return false, ©Mock + } + } else { + // [B, E, P, B, D, E] => [ P, B, D, E] + if reflect.DeepEqual(mock.PacketTypes, []string{"B", "E", "P", "B", "D", "E"}) && reflect.DeepEqual(actualPgReq.PacketTypes, []string{"P", "B", "D", "E"}) { + // fmt.Println("Inside Validate Mock for B, E, B, E") + if mock.Parses[0].Query == actualPgReq.Parses[0].Query { + // no need to do anything + // fmt.Println("Matched with the query AHHAHAHAHAH", mock.Parses[0].Query) + copyMock := *tcsMocks[idx] + copyMock.Spec.PostgresResponses[0].PacketTypes = []string{"1", "2", "T", "C", "Z"} + copyMock.Spec.PostgresResponses[0].Payload = "" + copyMock.Spec.PostgresResponses[0].CommandCompletes = copyMock.Spec.PostgresResponses[0].CommandCompletes[1:] + // fmt.Println("Matched with the query AHHAHAHAHAH", copyMock) + return false, ©Mock + } + } + } + return true, nil +} diff --git a/pkg/core/proxy/integrations/postgres/v1/postgres.go b/pkg/core/proxy/integrations/postgres/v1/postgres.go index 612ff25c5..76e2c8bd9 100755 --- a/pkg/core/proxy/integrations/postgres/v1/postgres.go +++ b/pkg/core/proxy/integrations/postgres/v1/postgres.go @@ -48,7 +48,7 @@ func (p *PostgresV1) MatchType(_ context.Context, reqBuf []byte) bool { } func (p *PostgresV1) RecordOutgoing(ctx context.Context, src net.Conn, dst net.Conn, mocks chan<- *models.Mock, opts models.OutgoingOptions) error { - logger := p.logger.With(zap.Any("Client IP Address", src.RemoteAddr().String()), zap.Any("Client ConnectionID", util.GetNextID()), zap.Any("Destination ConnectionID", util.GetNextID())) + logger := p.logger.With(zap.Any("Client IP Address", src.RemoteAddr().String()), zap.Any("Client ConnectionID", ctx.Value(models.ClientConnectionIDKey).(string)), zap.Any("Destination ConnectionID", ctx.Value(models.DestConnectionIDKey).(string))) reqBuf, err := util.ReadInitialBuf(ctx, logger, src) if err != nil { @@ -57,7 +57,7 @@ func (p *PostgresV1) RecordOutgoing(ctx context.Context, src net.Conn, dst net.C } err = encodePostgres(ctx, logger, reqBuf, src, dst, mocks, opts) if err != nil { - // TODO: why debug log sarthak? + // TODO: why debug log? logger.Debug("failed to encode the postgres message into the yaml") return err } @@ -67,7 +67,6 @@ func (p *PostgresV1) RecordOutgoing(ctx context.Context, src net.Conn, dst net.C func (p *PostgresV1) MockOutgoing(ctx context.Context, src net.Conn, dstCfg *integrations.ConditionalDstCfg, mockDb integrations.MockMemDb, opts models.OutgoingOptions) error { logger := p.logger.With(zap.Any("Client IP Address", src.RemoteAddr().String()), zap.Any("Client ConnectionID", util.GetNextID()), zap.Any("Destination ConnectionID", util.GetNextID())) - reqBuf, err := util.ReadInitialBuf(ctx, logger, src) if err != nil { utils.LogError(logger, err, "failed to read the initial postgres message") @@ -76,7 +75,6 @@ func (p *PostgresV1) MockOutgoing(ctx context.Context, src net.Conn, dstCfg *int err = decodePostgres(ctx, logger, reqBuf, src, dstCfg, mockDb, opts) if err != nil { - //TODO: why debug log sarthak? logger.Debug("failed to decode the postgres message from the yaml") return err } diff --git a/pkg/core/proxy/integrations/postgres/v1/util.go b/pkg/core/proxy/integrations/postgres/v1/util.go index e7a9b40cd..54a299ccf 100755 --- a/pkg/core/proxy/integrations/postgres/v1/util.go +++ b/pkg/core/proxy/integrations/postgres/v1/util.go @@ -1,11 +1,16 @@ package v1 import ( + "encoding/binary" "errors" "fmt" + "strconv" + "time" "github.com/jackc/pgproto3/v2" + "go.keploy.io/server/v2/pkg/core/proxy/integrations/util" "go.keploy.io/server/v2/pkg/models" + "go.uber.org/zap" ) func postgresDecoderFrontend(response models.Frontend) ([]byte, error) { @@ -13,7 +18,7 @@ func postgresDecoderFrontend(response models.Frontend) ([]byte, error) { var resbuffer []byte // list of packets available in the buffer packets := response.PacketTypes - var cc, dtr, ps = 0, 0, 0 + var cc, dtr, ps int = 0, 0, 0 for _, packet := range packets { var msg pgproto3.BackendMessage @@ -34,7 +39,8 @@ func postgresDecoderFrontend(response models.Frontend) ([]byte, error) { msg = &pgproto3.CopyDone{} case string('C'): msg = &pgproto3.CommandComplete{ - CommandTag: response.CommandCompletes[cc].CommandTag, + CommandTag: response.CommandCompletes[cc].CommandTag, + CommandTagType: response.CommandCompletes[cc].CommandTagType, } cc++ case string('d'): @@ -168,7 +174,6 @@ func postgresDecoderFrontend(response models.Frontend) ([]byte, error) { } encoded := msg.Encode([]byte{}) - // fmt.Println("Encoded packet ", packet, " is ", i, "-----", encoded) resbuffer = append(resbuffer, encoded...) } return resbuffer, nil @@ -285,3 +290,168 @@ func postgresDecoderBackend(request models.Backend) ([]byte, error) { } return reqbuffer, nil } + +func checkIfps(array []string) bool { + n := len(array) + if n%2 != 0 { + // If the array length is odd, it cannot match the pattern + return false + } + + for i := 0; i < n; i += 2 { + // Check if consecutive elements are "B" and "E" + if array[i] != "B" || array[i+1] != "E" { + return false + } + } + + return true +} + +func sliceCommandTag(mock *models.Mock, logger *zap.Logger, prep []QueryData, actualPgReq *models.Backend, psCase int) *models.Mock { + + logger.Debug("Inside Slice Command Tag for ", zap.Int("psCase", psCase)) + logger.Debug("Prep Query Data", zap.Any("prep", prep)) + switch psCase { + case 1: + + copyMock := *mock + // fmt.Println("Inside Slice Command Tag for ", psCase) + mockPackets := copyMock.Spec.PostgresResponses[0].PacketTypes + for idx, v := range mockPackets { + if v == "1" { + mockPackets = append(mockPackets[:idx], mockPackets[idx+1:]...) + } + } + copyMock.Spec.PostgresResponses[0].Payload = "" + copyMock.Spec.PostgresResponses[0].PacketTypes = mockPackets + + return ©Mock + case 2: + // ["2", D, C, Z] + copyMock := *mock + // fmt.Println("Inside Slice Command Tag for ", psCase) + mockPackets := copyMock.Spec.PostgresResponses[0].PacketTypes + for idx, v := range mockPackets { + if v == "1" || v == "T" { + mockPackets = append(mockPackets[:idx], mockPackets[idx+1:]...) + } + } + copyMock.Spec.PostgresResponses[0].Payload = "" + copyMock.Spec.PostgresResponses[0].PacketTypes = mockPackets + rsFormat := actualPgReq.Bind.ResultFormatCodes + + for idx, datarow := range copyMock.Spec.PostgresResponses[0].DataRows { + for column, rowVal := range datarow.RowValues { + // fmt.Println("datarow.RowValues", len(datarow.RowValues)) + if rsFormat[column] == 1 { + // datarows := make([]byte, 4) + newRow, _ := getChandedDataRow(rowVal) + // logger.Info("New Row Value", zap.String("newRow", newRow)) + copyMock.Spec.PostgresResponses[0].DataRows[idx].RowValues[column] = newRow + } + } + } + return ©Mock + default: + } + return nil +} + +func getChandedDataRow(input string) (string, error) { + // Convert input1 (integer input as string) to integer + buffer := make([]byte, 4) + if intValue, err := strconv.Atoi(input); err == nil { + + binary.BigEndian.PutUint32(buffer, uint32(intValue)) + return ("b64:" + util.EncodeBase64(buffer)), nil + } else if dateValue, err := time.Parse("2006-01-02", input); err == nil { + // Perform additional operations on the date + epoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + difference := dateValue.Sub(epoch).Hours() / 24 + // fmt.Printf("Difference in days from epoch: %.2f days\n", difference) + binary.BigEndian.PutUint32(buffer, uint32(difference)) + return ("b64:" + util.EncodeBase64(buffer)), nil + } + return "b64:AAAAAA==", errors.New("Invalid input") + +} + +func decodePgRequest(buffer []byte, logger *zap.Logger) *models.Backend { + + pg := NewBackend() + + if !isStartupPacket(buffer) && len(buffer) > 5 { + bufferCopy := buffer + for i := 0; i < len(bufferCopy)-5; { + logger.Debug("Inside the if condition") + pg.BackendWrapper.MsgType = buffer[i] + pg.BackendWrapper.BodyLen = int(binary.BigEndian.Uint32(buffer[i+1:])) - 4 + if len(buffer) < (i + pg.BackendWrapper.BodyLen + 5) { + logger.Debug("failed to translate the postgres request message due to shorter network packet buffer") + break + } + msg, err := pg.translateToReadableBackend(buffer[i:(i + pg.BackendWrapper.BodyLen + 5)]) + if err != nil && buffer[i] != 112 { + logger.Debug("failed to translate the request message to readable", zap.Error(err)) + } + if pg.BackendWrapper.MsgType == 'p' { + pg.BackendWrapper.PasswordMessage = *msg.(*pgproto3.PasswordMessage) + } + + if pg.BackendWrapper.MsgType == 'P' { + pg.BackendWrapper.Parse = *msg.(*pgproto3.Parse) + pg.BackendWrapper.Parses = append(pg.BackendWrapper.Parses, pg.BackendWrapper.Parse) + } + + if pg.BackendWrapper.MsgType == 'B' { + pg.BackendWrapper.Bind = *msg.(*pgproto3.Bind) + pg.BackendWrapper.Binds = append(pg.BackendWrapper.Binds, pg.BackendWrapper.Bind) + } + + if pg.BackendWrapper.MsgType == 'E' { + pg.BackendWrapper.Execute = *msg.(*pgproto3.Execute) + pg.BackendWrapper.Executes = append(pg.BackendWrapper.Executes, pg.BackendWrapper.Execute) + } + + pg.BackendWrapper.PacketTypes = append(pg.BackendWrapper.PacketTypes, string(pg.BackendWrapper.MsgType)) + + i += (5 + pg.BackendWrapper.BodyLen) + } + + pgMock := &models.Backend{ + PacketTypes: pg.BackendWrapper.PacketTypes, + Identfier: "ClientRequest", + Length: uint32(len(buffer)), + // Payload: bufStr, + Bind: pg.BackendWrapper.Bind, + Binds: pg.BackendWrapper.Binds, + PasswordMessage: pg.BackendWrapper.PasswordMessage, + CancelRequest: pg.BackendWrapper.CancelRequest, + Close: pg.BackendWrapper.Close, + CopyData: pg.BackendWrapper.CopyData, + CopyDone: pg.BackendWrapper.CopyDone, + CopyFail: pg.BackendWrapper.CopyFail, + Describe: pg.BackendWrapper.Describe, + Execute: pg.BackendWrapper.Execute, + Executes: pg.BackendWrapper.Executes, + Flush: pg.BackendWrapper.Flush, + FunctionCall: pg.BackendWrapper.FunctionCall, + GssEncRequest: pg.BackendWrapper.GssEncRequest, + Parse: pg.BackendWrapper.Parse, + Parses: pg.BackendWrapper.Parses, + Query: pg.BackendWrapper.Query, + SSlRequest: pg.BackendWrapper.SSlRequest, + StartupMessage: pg.BackendWrapper.StartupMessage, + SASLInitialResponse: pg.BackendWrapper.SASLInitialResponse, + SASLResponse: pg.BackendWrapper.SASLResponse, + Sync: pg.BackendWrapper.Sync, + Terminate: pg.BackendWrapper.Terminate, + MsgType: pg.BackendWrapper.MsgType, + AuthType: pg.BackendWrapper.AuthType, + } + return pgMock + } + + return nil +} diff --git a/pkg/core/proxy/proxy.go b/pkg/core/proxy/proxy.go index 3d74b698b..f2d13c43d 100755 --- a/pkg/core/proxy/proxy.go +++ b/pkg/core/proxy/proxy.go @@ -193,8 +193,10 @@ func (p *Proxy) start(ctx context.Context) error { p.logger.Info("proxy stopped...") }(listener) - clientConnErrGrp, clientConnCtx := errgroup.WithContext(ctx) + clientConnCtx, clientConnCancel := context.WithCancel(ctx) + clientConnErrGrp, _ := errgroup.WithContext(clientConnCtx) defer func() { + clientConnCancel() err := clientConnErrGrp.Wait() if err != nil { p.logger.Debug("failed to handle the client connection", zap.Error(err)) @@ -231,9 +233,9 @@ func (p *Proxy) start(ctx context.Context) error { // handle the client connection case clientConn := <-clientConnCh: clientConnErrGrp.Go(func() error { - utils.Recover(p.logger) + defer utils.Recover(p.logger) err := p.handleConnection(clientConnCtx, clientConn) - if err != nil { + if err != nil && err != io.EOF { utils.LogError(p.logger, err, "failed to handle the client connection") } return nil @@ -254,7 +256,13 @@ func (p *Proxy) handleConnection(ctx context.Context, srcConn net.Conn) error { // making a new client connection id for each client connection clientConnID := util.GetNextID() - p.logger.Debug("New client connection", zap.Any("connectionID", clientConnID)) + p.logger.Info("New client connection", zap.Any("connectionID", clientConnID)) + + // dstConn stores conn with actual destination for the outgoing network call + var dstConn net.Conn + + //Dialing for tls conn + destConnID := util.GetNextID() remoteAddr := srcConn.RemoteAddr().(*net.TCPAddr) sourcePort := remoteAddr.Port @@ -294,13 +302,26 @@ func (p *Proxy) handleConnection(ctx context.Context, srcConn net.Conn) error { // This is used to handle the parser errors parserErrGrp, parserCtx := errgroup.WithContext(ctx) parserCtx = context.WithValue(parserCtx, models.ErrGroupKey, parserErrGrp) + parserCtx = context.WithValue(parserCtx, models.ClientConnectionIDKey, fmt.Sprint(clientConnID)) + parserCtx = context.WithValue(parserCtx, models.DestConnectionIDKey, fmt.Sprint(destConnID)) + parserCtx, parserCtxCancel := context.WithCancel(parserCtx) defer func() { + parserCtxCancel() + err := srcConn.Close() if err != nil { utils.LogError(p.logger, err, "failed to close the source connection") return } + if dstConn != nil { + err = dstConn.Close() + if err != nil { + utils.LogError(p.logger, err, "failed to close the destination connection") + return + } + } + err = parserErrGrp.Wait() if err != nil { utils.LogError(p.logger, err, "failed to handle the parser cleanUp") @@ -383,12 +404,6 @@ func (p *Proxy) handleConnection(ctx context.Context, srcConn net.Conn) error { logger: p.logger, } - // dstConn stores conn with actual destination for the outgoing network call - var dstConn net.Conn - - //Dialing for tls conn - destConnID := util.GetNextID() - logger := p.logger.With(zap.Any("Client IP Address", srcConn.RemoteAddr().String()), zap.Any("Client ConnectionID", clientConnID), zap.Any("Destination IP Address", dstAddr), zap.Any("Destination ConnectionID", destConnID)) dstCfg := &integrations.ConditionalDstCfg{ @@ -437,6 +452,7 @@ func (p *Proxy) handleConnection(ctx context.Context, srcConn net.Conn) error { } generic := true + //Checking for all the parsers. for _, parser := range p.Integrations { if parser.MatchType(parserCtx, initialBuf) { @@ -448,7 +464,7 @@ func (p *Proxy) handleConnection(ctx context.Context, srcConn net.Conn) error { } } else { err := parser.MockOutgoing(parserCtx, srcConn, dstCfg, m.(*MockManager), rule.OutgoingOptions) - if err != nil { + if err != nil && err != io.EOF { utils.LogError(logger, err, "failed to mock the outgoing message") return err } diff --git a/pkg/core/proxy/util/util.go b/pkg/core/proxy/util/util.go index c81d32783..543978a0a 100755 --- a/pkg/core/proxy/util/util.go +++ b/pkg/core/proxy/util/util.go @@ -39,7 +39,7 @@ func ReadBuffConn(ctx context.Context, logger *zap.Logger, conn net.Conn, buffer for { select { case <-ctx.Done(): - errChannel <- ctx.Err() + // errChannel <- ctx.Err() return default: if conn == nil { @@ -47,6 +47,9 @@ func ReadBuffConn(ctx context.Context, logger *zap.Logger, conn net.Conn, buffer } buffer, err := ReadBytes(ctx, logger, conn) if err != nil { + if ctx.Err() != nil { // to avoid sending buffer to closed channel if the context is cancelled + return + } utils.LogError(logger, err, "failed to read the packet message in proxy") errChannel <- err return diff --git a/pkg/models/const.go b/pkg/models/const.go index faf4d3d85..f70ee547a 100755 --- a/pkg/models/const.go +++ b/pkg/models/const.go @@ -189,3 +189,5 @@ const ( type contextKey string const ErrGroupKey contextKey = "errGroup" +const ClientConnectionIDKey contextKey = "clientConnectionId" +const DestConnectionIDKey contextKey = "destConnectionId" diff --git a/pkg/models/mock.go b/pkg/models/mock.go index 8f1ada191..7de8bdd84 100755 --- a/pkg/models/mock.go +++ b/pkg/models/mock.go @@ -8,6 +8,7 @@ type Mock struct { Kind Kind `json:"Kind,omitempty" bson:"Kind,omitempty"` Spec MockSpec `json:"Spec,omitempty" bson:"Spec,omitempty"` TestModeInfo TestModeInfo `json:"TestModeInfo,omitempty" bson:"TestModeInfo,omitempty"` // Map for additional test mode information + ConnectionID string `json:"ConnectionId,omitempty" bson:"ConnectionId,omitempty"` } type TestModeInfo struct { diff --git a/pkg/platform/yaml/mockdb/db.go b/pkg/platform/yaml/mockdb/db.go index ff3ed8248..28afc6848 100644 --- a/pkg/platform/yaml/mockdb/db.go +++ b/pkg/platform/yaml/mockdb/db.go @@ -254,9 +254,9 @@ func (ys *MockYaml) GetUnFilteredMocks(ctx context.Context, testSetID string, af return unfilteredMocks[i].Spec.ReqTimestampMock.Before(unfilteredMocks[j].Spec.ReqTimestampMock) }) - if len(unfilteredMocks) > 10 { - unfilteredMocks = unfilteredMocks[:10] - } + // if len(unfilteredMocks) > 10 { + // unfilteredMocks = unfilteredMocks[:10] + // } mocks := append(filteredMocks, unfilteredMocks...) diff --git a/pkg/platform/yaml/mockdb/util.go b/pkg/platform/yaml/mockdb/util.go index 33dce3af1..90212a81a 100644 --- a/pkg/platform/yaml/mockdb/util.go +++ b/pkg/platform/yaml/mockdb/util.go @@ -13,9 +13,10 @@ import ( func EncodeMock(mock *models.Mock, logger *zap.Logger) (*yaml.NetworkTrafficDoc, error) { yamlDoc := yaml.NetworkTrafficDoc{ - Version: mock.Version, - Kind: mock.Kind, - Name: mock.Name, + Version: mock.Version, + Kind: mock.Kind, + Name: mock.Name, + ConnectionID: mock.ConnectionID, } switch mock.Kind { case models.Mongo: @@ -88,6 +89,7 @@ func EncodeMock(mock *models.Mock, logger *zap.Logger) (*yaml.NetworkTrafficDoc, return nil, err } case models.Postgres: + // case models.PostgresV2: postgresSpec := models.PostgresSpec{ Metadata: mock.Spec.Metadata, @@ -167,9 +169,10 @@ func decodeMocks(yamlMocks []*yaml.NetworkTrafficDoc, logger *zap.Logger) ([]*mo for _, m := range yamlMocks { mock := models.Mock{ - Version: m.Version, - Name: m.Name, - Kind: m.Kind, + Version: m.Version, + Name: m.Name, + Kind: m.Kind, + ConnectionID: m.ConnectionID, } mockCheck := strings.Split(string(m.Kind), "-") if len(mockCheck) > 1 { @@ -235,6 +238,7 @@ func decodeMocks(yamlMocks []*yaml.NetworkTrafficDoc, logger *zap.Logger) ([]*mo } case models.Postgres: + // case models.PostgresV2: PostSpec := models.PostgresSpec{} err := m.Spec.Decode(&PostSpec) diff --git a/pkg/platform/yaml/yaml.go b/pkg/platform/yaml/yaml.go index fba5a2eb4..ad6d7059f 100755 --- a/pkg/platform/yaml/yaml.go +++ b/pkg/platform/yaml/yaml.go @@ -16,11 +16,12 @@ import ( // NetworkTrafficDoc stores the request-response data of a network call (ingress or egress) type NetworkTrafficDoc struct { - Version models.Version `json:"version" yaml:"version"` - Kind models.Kind `json:"kind" yaml:"kind"` - Name string `json:"name" yaml:"name"` - Spec yamlLib.Node `json:"spec" yaml:"spec"` - Curl string `json:"curl" yaml:"curl,omitempty"` + Version models.Version `json:"version" yaml:"version"` + Kind models.Kind `json:"kind" yaml:"kind"` + Name string `json:"name" yaml:"name"` + Spec yamlLib.Node `json:"spec" yaml:"spec"` + Curl string `json:"curl" yaml:"curl,omitempty"` + ConnectionID string `json:"connectionId" yaml:"connectionId,omitempty"` } // ctxReader wraps an io.Reader with a context for cancellation support @@ -46,18 +47,13 @@ type ctxWriter struct { func (cw *ctxWriter) Write(p []byte) (n int, err error) { for len(p) > 0 { - select { - case <-cw.ctx.Done(): - return n, cw.ctx.Err() - default: - var written int - written, err = cw.writer.Write(p) - n += written - if err != nil { - return n, err - } - p = p[written:] + var written int + written, err = cw.writer.Write(p) + n += written + if err != nil { + return n, err } + p = p[written:] } return n, nil }