Skip to content

Commit

Permalink
refactor: sync mysql changes acc to latest code (keploy#1809)
Browse files Browse the repository at this point in the history
* refactor: sync mysql changes acc to latest code

Signed-off-by: Sarthak160 <[email protected]>

* Fixed lint issues

Signed-off-by: Chinmay <[email protected]>

* fix: Added an argument to DecodeMySQL function to find out the mode

Signed-off-by: Chinmay <[email protected]>

* refactor: remove set mocks function form the memdb interface

Signed-off-by: Sarthak160 <[email protected]>

---------

Signed-off-by: Sarthak160 <[email protected]>
Signed-off-by: Chinmay <[email protected]>
Co-authored-by: Chinmay <[email protected]>
  • Loading branch information
Sarthak160 and ChinmayaSharma-hue authored May 13, 2024
1 parent b9673e3 commit b1efc78
Show file tree
Hide file tree
Showing 30 changed files with 665 additions and 480 deletions.
10 changes: 9 additions & 1 deletion pkg/core/proxy/integrations/mysql/authMoreData.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
)

type NextAuthPacket struct {
PluginData byte `json:"plugin_data,omitempty" yaml:"plugin_data,omitempty"`
PluginData byte `yaml:"plugin_data"`
}

func decodeAuthMoreData(data []byte) (*NextAuthPacket, error) {
Expand All @@ -17,3 +17,11 @@ func decodeAuthMoreData(data []byte) (*NextAuthPacket, error) {
PluginData: data[0],
}, nil
}

// Encode function for Next Authentication method Packet
//func encodeAuthMoreData(packet *NextAuthPacket) ([]byte, error) {
// if packet.PluginData != 0x02 {
// return nil, errors.New("invalid PluginData value for NextAuthPacket")
// }
// return []byte{packet.PluginData}, nil
//}
6 changes: 3 additions & 3 deletions pkg/core/proxy/integrations/mysql/authSwitchRequestPacket.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ import (
)

type AuthSwitchRequestPacket struct {
StatusTag byte `json:"status_tag,omitempty" yaml:"status_tag,omitempty"`
PluginName string `json:"plugin_name,omitempty" yaml:"plugin_name,omitempty"`
PluginAuthData string `json:"plugin_authdata,omitempty" yaml:"plugin_authdata,omitempty"`
StatusTag byte `yaml:"status_tag"`
PluginName string `yaml:"plugin_name"`
PluginAuthData string `yaml:"plugin_authdata"`
}

func decodeAuthSwitchRequest(data []byte) (*AuthSwitchRequestPacket, error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
)

type AuthSwitchResponsePacket struct {
AuthResponseData string `json:"auth_response_data,omitempty" yaml:"auth_response_data,omitempty"`
AuthResponseData string `yaml:"auth_response_data"`
}

func decodeAuthSwitchResponse(data []byte) (*AuthSwitchResponsePacket, error) {
Expand Down
13 changes: 6 additions & 7 deletions pkg/core/proxy/integrations/mysql/comChangeUserPacket.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
package mysql

import (
"encoding/base64"
"errors"
"strings"
)

type ComChangeUserPacket struct {
User string `json:"user,omitempty" yaml:"user,omitempty,flow"`
Auth string `json:"auth,omitempty" yaml:"auth,omitempty,flow"`
Db string `json:"db,omitempty" yaml:"db,omitempty,flow"`
CharacterSet uint8 `json:"character_set,omitempty" yaml:"character_set,omitempty,flow"`
AuthPlugin string `json:"auth_plugin,omitempty" yaml:"auth_plugin,omitempty,flow"`
User string `yaml:"user"`
Auth []byte `yaml:"auth,omitempty,flow"`
Db string `yaml:"db"`
CharacterSet uint8 `yaml:"character_set"`
AuthPlugin string `yaml:"auth_plugin"`
}

func decodeComChangeUser(data []byte) (ComChangeUserPacket, error) {
Expand All @@ -33,7 +32,7 @@ func decodeComChangeUser(data []byte) (ComChangeUserPacket, error) {

return ComChangeUserPacket{
User: user,
Auth: base64.StdEncoding.EncodeToString(auth),
Auth: auth,
Db: db,
CharacterSet: characterSet,
AuthPlugin: authPlugin,
Expand Down
6 changes: 3 additions & 3 deletions pkg/core/proxy/integrations/mysql/comFetchPacket.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import (
)

type ComStmtFetchPacket struct {
StatementID uint32 `json:"statement_id,omitempty" yaml:"statement_id,omitempty"`
RowCount uint32 `json:"row_count,omitempty" yaml:"row_count,omitempty"`
Info string `json:"info,omitempty" yaml:"info,omitempty"`
StatementID uint32 `yaml:"statement_id"`
RowCount uint32 `yaml:"row_count"`
Info string `yaml:"info"`
}

func decodeComStmtFetch(data []byte) (ComStmtFetchPacket, error) {
Expand Down
18 changes: 16 additions & 2 deletions pkg/core/proxy/integrations/mysql/comInitDb.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
package mysql

type ComInitDbPacket struct {
Status byte `json:"status,omitempty" yaml:"status,omitempty"`
DbName string `json:"db_name,omitempty" yaml:"db_name,omitempty"`
Status byte
DbName string
}

//func decodeComInitDb(data []byte) (*ComInitDbPacket, error) {
// if len(data) < 2 {
// return nil, errors.New("data too short for COM_INIT_DB")
// }
// status := data[0]
//
// // The rest of the packet after the command byte is the database name
// dbName := string(data[1:])
// return &ComInitDbPacket{
// Status: status,
// DbName: dbName,
// }, nil
//}
8 changes: 4 additions & 4 deletions pkg/core/proxy/integrations/mysql/comStmtCloseMoreData.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ import (
)

type ComStmtPreparePacket1 struct {
Header []byte `json:"header,omitempty" yaml:"header,omitempty"`
Query string `json:"query,omitempty" yaml:"query,omitempty"`
Header []byte
Query string
}

type ComStmtCloseAndPrepare struct {
StmtClose ComStmtClosePacket `json:"stmt_close,omitempty" yaml:"stmt_close,omitempty"`
StmtPrepare ComStmtPreparePacket1 `json:"stmt_prepare,omitempty" yaml:"stmt_prepare,omitempty"`
StmtClose ComStmtClosePacket
StmtPrepare ComStmtPreparePacket1
}

func decodeComStmtCloseMoreData(data []byte) (*ComStmtCloseAndPrepare, error) {
Expand Down
4 changes: 2 additions & 2 deletions pkg/core/proxy/integrations/mysql/comStmtClosePacket.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import (
)

type ComStmtClosePacket struct {
Status byte `json:"status,omitempty" yaml:"status,omitempty"`
StatementID uint32 `json:"statement_id,omitempty" yaml:"statement_id,omitempty"`
Status byte
StatementID uint32
}

func decodeComStmtClose(data []byte) (*ComStmtClosePacket, error) {
Expand Down
38 changes: 13 additions & 25 deletions pkg/core/proxy/integrations/mysql/comStmtPrepareOk.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ import (
)

type StmtPrepareOk struct {
Status byte `json:"status,omitempty" yaml:"status,omitempty,flow"`
StatementID uint32 `json:"statement_id,omitempty" yaml:"statement_id,omitempty,flow"`
NumColumns uint16 `json:"num_columns,omitempty" yaml:"num_columns,omitempty,flow"`
NumParams uint16 `json:"num_params,omitempty" yaml:"num_params,omitempty,flow"`
WarningCount uint16 `json:"warning_count,omitempty" yaml:"warning_count,omitempty,flow"`
ColumnDefs []ColumnDefinition `json:"column_definitions,omitempty" yaml:"column_definitions,omitempty,flow"`
ParamDefs []ColumnDefinition `json:"param_definitions,omitempty" yaml:"param_definitions,omitempty,flow"`
Status byte `yaml:"status"`
StatementID uint32 `yaml:"statement_id"`
NumColumns uint16 `yaml:"num_columns"`
NumParams uint16 `yaml:"num_params"`
WarningCount uint16 `yaml:"warning_count"`
ColumnDefs []ColumnDefinition `yaml:"column_definitions,omitempty,flow"`
ParamDefs []ColumnDefinition `yaml:"param_definitions,omitempty,flow"`
}

func decodeComStmtPrepareOk(data []byte) (*StmtPrepareOk, error) {
Expand Down Expand Up @@ -191,24 +191,12 @@ func encodeStmtPrepareOk(packet *models.MySQLStmtPrepareOk) ([]byte, error) {

func encodeColumnDefinition(buf *bytes.Buffer, column *models.ColumnDefinition, seqNum *byte) error {
tmpBuf := &bytes.Buffer{}
if err := writeLengthEncodedString(tmpBuf, column.Catalog); err != nil {
return err
}
if err := writeLengthEncodedString(tmpBuf, column.Schema); err != nil {
return err
}
if err := writeLengthEncodedString(tmpBuf, column.Table); err != nil {
return err
}
if err := writeLengthEncodedString(tmpBuf, column.OrgTable); err != nil {
return err
}
if err := writeLengthEncodedString(tmpBuf, column.Name); err != nil {
return err
}
if err := writeLengthEncodedString(tmpBuf, column.OrgName); err != nil {
return err
}
writeLengthEncodedString(tmpBuf, column.Catalog)
writeLengthEncodedString(tmpBuf, column.Schema)
writeLengthEncodedString(tmpBuf, column.Table)
writeLengthEncodedString(tmpBuf, column.OrgTable)
writeLengthEncodedString(tmpBuf, column.Name)
writeLengthEncodedString(tmpBuf, column.OrgName)
tmpBuf.WriteByte(0x0C)
if err := binary.Write(tmpBuf, binary.LittleEndian, column.CharacterSet); err != nil {
return err
Expand Down
2 changes: 1 addition & 1 deletion pkg/core/proxy/integrations/mysql/comStmtPreparePacket.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
)

type ComStmtPreparePacket struct {
Query string `json:"query,omitempty" yaml:"query,omitempty,flow"`
Query string
}

func decodeComStmtPrepare(data []byte) (*ComStmtPreparePacket, error) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/core/proxy/integrations/mysql/comStmtResetPacket.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
)

type COM_STMT_RESET struct {
StatementID uint32 `json:"statement_id,omitempty" yaml:"statement_id,omitempty,flow"`
StatementID uint32 `yaml:"statement_id"`
}

func decodeComStmtReset(packet []byte) (*COM_STMT_RESET, error) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
package mysql

import (
"encoding/base64"
"encoding/binary"
"fmt"
)

type COM_STMT_SEND_LONG_DATA struct {
StatementID uint32 `json:"statement_id,omitempty" yaml:"statement_id,omitempty,flow"`
ParameterID uint16 `json:"parameter_id,omitempty" yaml:"parameter_id,omitempty,flow"`
Data string `json:"data,omitempty" yaml:"data,omitempty,flow"`
StatementID uint32 `yaml:"statement_id"`
ParameterID uint16 `yaml:"parameter_id"`
Data []byte `yaml:"data,omitempty,flow"`
}

func decodeComStmtSendLongData(packet []byte) (COM_STMT_SEND_LONG_DATA, error) {
Expand All @@ -22,6 +21,6 @@ func decodeComStmtSendLongData(packet []byte) (COM_STMT_SEND_LONG_DATA, error) {
return COM_STMT_SEND_LONG_DATA{
StatementID: stmtID,
ParameterID: paramID,
Data: base64.StdEncoding.EncodeToString(data),
Data: data,
}, nil
}
30 changes: 14 additions & 16 deletions pkg/core/proxy/integrations/mysql/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func decodeMySQL(ctx context.Context, logger *zap.Logger, clientConn net.Conn, d
errCh <- err
return
}
sqlMock, found := getFirstSQLMock(configMocks)
sqlMock, matchedIndex, found := getFirstSQLMock(configMocks)
if !found {
logger.Debug("No SQL mock found")
errCh <- err
Expand All @@ -59,7 +59,6 @@ func decodeMySQL(ctx context.Context, logger *zap.Logger, clientConn net.Conn, d
errCh <- err
return
}

_, err = clientConn.Write(binaryPacket)
if err != nil {
if ctx.Err() != nil {
Expand All @@ -69,17 +68,17 @@ func decodeMySQL(ctx context.Context, logger *zap.Logger, clientConn net.Conn, d
errCh <- err
return
}
matchedIndex := 0
matchedReqIndex := 0
configMocks[matchedIndex].Spec.MySQLResponses = append(configMocks[matchedIndex].Spec.MySQLResponses[:matchedReqIndex], configMocks[matchedIndex].Spec.MySQLResponses[matchedReqIndex+1:]...)
if len(configMocks[matchedIndex].Spec.MySQLResponses) == 0 {
configMocks = append(configMocks[:matchedIndex], configMocks[matchedIndex+1:]...)
err = mockDb.FlagMockAsUsed(configMocks[matchedIndex])
if err != nil {
utils.LogError(logger, err, "Failed to flag mock as used")
errCh <- err
return
}
//configMocks = append(configMocks[:matchedIndex], configMocks[matchedIndex+1:]...)
//err = mockDb.FlagMockAsUsed(configMocks[matchedIndex])
//if err != nil {
// utils.LogError(logger, err, "Failed to flag mock as used")
// errCh <- err
// return
//}
mockDb.DeleteUnFilteredMock(configMocks[matchedIndex])
}
//h.SetConfigMocks(configMocks)
firstLoop = false
Expand Down Expand Up @@ -131,7 +130,7 @@ func decodeMySQL(ctx context.Context, logger *zap.Logger, clientConn net.Conn, d
expectingHandshakeResponseTest = true
}

oprRequest, requestHeader, decodedRequest, err := DecodeMySQLPacket(logger, bytesToMySQLPacket(requestBuffer))
oprRequest, requestHeader, decodedRequest, err := DecodeMySQLPacket(logger, bytesToMySQLPacket(requestBuffer), clientConn, models.MODE_TEST)
if err != nil {
utils.LogError(logger, err, "Failed to decode MySQL packet")
errCh <- err
Expand Down Expand Up @@ -206,7 +205,6 @@ func decodeMySQL(ctx context.Context, logger *zap.Logger, clientConn net.Conn, d
errCh <- err
return
}

_, err = clientConn.Write(responseBinary)
if err != nil {
if ctx.Err() != nil {
Expand All @@ -228,11 +226,11 @@ func decodeMySQL(ctx context.Context, logger *zap.Logger, clientConn net.Conn, d
}
}

func getFirstSQLMock(configMocks []*models.Mock) (*models.Mock, bool) {
for _, mock := range configMocks {
func getFirstSQLMock(configMocks []*models.Mock) (*models.Mock, int, bool) {
for index, mock := range configMocks {
if len(mock.Spec.MySQLResponses) > 0 && mock.Kind == "SQL" && mock.Spec.MySQLResponses[0].Header.PacketType == "MySQLHandshakeV10" {
return mock, true
return mock, index, true
}
}
return nil, false
return nil, 0, false
}
Loading

0 comments on commit b1efc78

Please sign in to comment.