Skip to content

Commit

Permalink
add: bypass rules filter (keploy#1837)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChinmayaSharma-hue authored May 10, 2024
1 parent 45f7d98 commit f898233
Show file tree
Hide file tree
Showing 11 changed files with 103 additions and 16 deletions.
12 changes: 9 additions & 3 deletions pkg/core/hooks/conn/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func NewFactory(inactivityThreshold time.Duration, logger *zap.Logger) *Factory

// ProcessActiveTrackers iterates over all conn the trackers and checks if they are complete. If so, it captures the ingress call and
// deletes the tracker. If the tracker is inactive for a long time, it deletes it.
func (factory *Factory) ProcessActiveTrackers(ctx context.Context, t chan *models.TestCase) {
func (factory *Factory) ProcessActiveTrackers(ctx context.Context, t chan *models.TestCase, opts models.IncomingOptions) {
factory.mutex.Lock()
defer factory.mutex.Unlock()
var trackersToDelete []ID
Expand All @@ -64,7 +64,7 @@ func (factory *Factory) ProcessActiveTrackers(ctx context.Context, t chan *model
utils.LogError(factory.logger, err, "failed to parse the http response from byte array", zap.Any("responseBuf", responseBuf))
continue
}
capture(ctx, factory.logger, t, parsedHTTPReq, parsedHTTPRes, reqTimestampTest, resTimestampTest)
capture(ctx, factory.logger, t, parsedHTTPReq, parsedHTTPRes, reqTimestampTest, resTimestampTest, opts)

} else if tracker.IsInactive(factory.inactivityThreshold) {
trackersToDelete = append(trackersToDelete, connID)
Expand All @@ -91,7 +91,7 @@ func (factory *Factory) GetOrCreate(connectionID ID) *Tracker {
return tracker
}

func capture(_ context.Context, logger *zap.Logger, t chan *models.TestCase, req *http.Request, resp *http.Response, reqTimeTest time.Time, resTimeTest time.Time) {
func capture(_ context.Context, logger *zap.Logger, t chan *models.TestCase, req *http.Request, resp *http.Response, reqTimeTest time.Time, resTimeTest time.Time, opts models.IncomingOptions) {
reqBody, err := io.ReadAll(req.Body)
if err != nil {
utils.LogError(logger, err, "failed to read the http request body")
Expand All @@ -110,6 +110,12 @@ func capture(_ context.Context, logger *zap.Logger, t chan *models.TestCase, req
utils.LogError(logger, err, "failed to read the http response body")
return
}

if isFiltered(logger, req, opts) {
logger.Debug("The request is a filtered request")
return
}

t <- &models.TestCase{
Version: models.GetVersion(),
Name: pkg.ToYamlHTTPHeader(req.Header)["Keploy-Test-Name"],
Expand Down
4 changes: 2 additions & 2 deletions pkg/core/hooks/conn/socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
var eventAttributesSize = int(unsafe.Sizeof(SocketDataEvent{}))

// ListenSocket starts the socket event listeners
func ListenSocket(ctx context.Context, l *zap.Logger, openMap, dataMap, closeMap *ebpf.Map) (<-chan *models.TestCase, error) {
func ListenSocket(ctx context.Context, l *zap.Logger, openMap, dataMap, closeMap *ebpf.Map, opts models.IncomingOptions) (<-chan *models.TestCase, error) {
t := make(chan *models.TestCase, 500)
err := initRealTimeOffset()
if err != nil {
Expand All @@ -46,7 +46,7 @@ func ListenSocket(ctx context.Context, l *zap.Logger, openMap, dataMap, closeMap
return
default:
// TODO refactor this to directly consume the events from the maps
c.ProcessActiveTrackers(ctx, t)
c.ProcessActiveTrackers(ctx, t, opts)
time.Sleep(100 * time.Millisecond)
}
}
Expand Down
71 changes: 71 additions & 0 deletions pkg/core/hooks/conn/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@ package conn

import (
"fmt"
"go.keploy.io/server/v2/config"
proxyHttp "go.keploy.io/server/v2/pkg/core/proxy/integrations/http"
"go.keploy.io/server/v2/pkg/models"
"go.keploy.io/server/v2/utils"
"go.uber.org/zap"
"net/http"
"regexp"
"strconv"
"strings"
"time"

"golang.org/x/sys/unix"
Expand Down Expand Up @@ -40,6 +49,68 @@ func convertUnixNanoToTime(unixNano uint64) time.Time {
return time.Unix(seconds, nanoRemainder)
}

func isFiltered(logger *zap.Logger, req *http.Request, opts models.IncomingOptions) bool {
destPort, err := strconv.Atoi(strings.Split(req.Host, ":")[1])
if err != nil {
utils.LogError(logger, err, "failed to obtain destination port from request")
return false
}
var bypassRules []config.BypassRule

for _, filter := range opts.Filters {
bypassRules = append(bypassRules, filter.BypassRule)
}

// Host, Path and Port matching
headerOpts := models.OutgoingOptions{
Rules: bypassRules,
MongoPassword: "",
SQLDelay: 0,
FallBackOnMiss: false,
}
passThrough := proxyHttp.IsPassThrough(logger, req, uint(destPort), headerOpts)

for _, filter := range opts.Filters {
if filter.URLMethods != nil && len(filter.URLMethods) != 0 {
urlMethodMatch := false
for _, method := range filter.URLMethods {
if method == req.Method {
urlMethodMatch = true
break
}
}
passThrough = urlMethodMatch
if !passThrough {
continue
}
}
if filter.Headers != nil && len(filter.Headers) != 0 {
headerMatch := false
for filterHeaderKey, filterHeaderValue := range filter.Headers {
regex, err := regexp.Compile(filterHeaderValue)
if err != nil {
utils.LogError(logger, err, "failed to compile the header regex")
continue
}
if req.Header.Get(filterHeaderKey) != "" {
for _, value := range req.Header.Values(filterHeaderKey) {
headerMatch = regex.MatchString(value)
if headerMatch {
break
}
}
}
passThrough = headerMatch
if passThrough {
break
}
}
}
}

return passThrough
}

//// LogAny appends input of any type to a logs.txt file in the current directory
//func LogAny(value string) error {
//
Expand Down
4 changes: 2 additions & 2 deletions pkg/core/hooks/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,11 +492,11 @@ func (h *Hooks) load(_ context.Context, opts core.HookCfg) error {
return nil
}

func (h *Hooks) Record(ctx context.Context, _ uint64) (<-chan *models.TestCase, error) {
func (h *Hooks) Record(ctx context.Context, _ uint64, opts models.IncomingOptions) (<-chan *models.TestCase, error) {
// TODO use the session to get the app id
// and then use the app id to get the test cases chan
// and pass that to eBPF consumers/listeners
return conn.ListenSocket(ctx, h.logger, h.objects.SocketOpenEvents, h.objects.SocketDataEvents, h.objects.SocketCloseEvents)
return conn.ListenSocket(ctx, h.logger, h.objects.SocketOpenEvents, h.objects.SocketDataEvents, h.objects.SocketCloseEvents, opts)
}

func (h *Hooks) unLoad(_ context.Context) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/core/proxy/integrations/http/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func decodeHTTP(ctx context.Context, logger *zap.Logger, reqBuf []byte, clientCo
logger.Debug("after matching the http request", zap.Any("isMatched", ok), zap.Any("stub", stub), zap.Error(err))

if !ok {
if !isPassThrough(logger, request, dstCfg.Port, opts) {
if !IsPassThrough(logger, request, dstCfg.Port, opts) {
utils.LogError(logger, nil, "Didn't match any preExisting http mock", zap.Any("metadata", getReqMeta(request)))
}
if opts.FallBackOnMiss {
Expand Down
2 changes: 1 addition & 1 deletion pkg/core/proxy/integrations/http/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func ParseFinalHTTP(_ context.Context, logger *zap.Logger, mock *finalHTTP, dest
}

// Check if the request is a passThrough request
if isPassThrough(logger, req, destPort, opts) {
if IsPassThrough(logger, req, destPort, opts) {
logger.Debug("The request is a passThrough request", zap.Any("metadata", getReqMeta(req)))
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/core/proxy/integrations/http/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ func isJSON(body []byte) bool {
return json.Unmarshal(body, &js) == nil
}

func isPassThrough(logger *zap.Logger, req *http.Request, destPort uint, opts models.OutgoingOptions) bool {
func IsPassThrough(logger *zap.Logger, req *http.Request, destPort uint, opts models.OutgoingOptions) bool {
passThrough := false

for _, bypass := range opts.Rules {
Expand Down
4 changes: 2 additions & 2 deletions pkg/core/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import (
"go.keploy.io/server/v2/pkg/models"
)

func (c *Core) GetIncoming(ctx context.Context, id uint64, _ models.IncomingOptions) (<-chan *models.TestCase, error) {
return c.Hooks.Record(ctx, id)
func (c *Core) GetIncoming(ctx context.Context, id uint64, opts models.IncomingOptions) (<-chan *models.TestCase, error) {
return c.Hooks.Record(ctx, id, opts)
}

func (c *Core) GetOutgoing(ctx context.Context, id uint64, opts models.OutgoingOptions) (<-chan *models.Mock, error) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/core/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type Hooks interface {
OutgoingInfo
TestBenchInfo
Load(ctx context.Context, id uint64, cfg HookCfg) error
Record(ctx context.Context, id uint64) (<-chan *models.TestCase, error)
Record(ctx context.Context, id uint64, opts models.IncomingOptions) (<-chan *models.TestCase, error)
}

type HookCfg struct {
Expand Down
2 changes: 1 addition & 1 deletion pkg/models/instrument.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type OutgoingOptions struct {
}

type IncomingOptions struct {
//Filters []config.Filter
Filters []config.Filter
}

type SetupOptions struct {
Expand Down
14 changes: 12 additions & 2 deletions pkg/service/record/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,12 @@ func (r *Recorder) Start(ctx context.Context) error {
}
}

incomingOpts := models.IncomingOptions{
Filters: r.config.Record.Filters,
}

// fetching test cases and mocks from the application and inserting them into the database
incomingChan, err = r.instrumentation.GetIncoming(ctx, appID, models.IncomingOptions{})
incomingChan, err = r.instrumentation.GetIncoming(ctx, appID, incomingOpts)
if err != nil {
stopReason = "failed to get incoming frames"
utils.LogError(r.logger, err, stopReason)
Expand All @@ -161,7 +165,13 @@ func (r *Recorder) Start(ctx context.Context) error {
return nil
})

outgoingChan, err = r.instrumentation.GetOutgoing(ctx, appID, models.OutgoingOptions{})
outgoingOpts := models.OutgoingOptions{
Rules: r.config.BypassRules,
MongoPassword: r.config.Test.MongoPassword,
FallBackOnMiss: r.config.Test.FallBackOnMiss,
}

outgoingChan, err = r.instrumentation.GetOutgoing(ctx, appID, outgoingOpts)
if err != nil {
stopReason = "failed to get outgoing frames"
utils.LogError(r.logger, err, stopReason)
Expand Down

0 comments on commit f898233

Please sign in to comment.