Skip to content

Commit

Permalink
#7 Add IP validation, go lint
Browse files Browse the repository at this point in the history
  • Loading branch information
thegodenage committed Mar 13, 2024
1 parent 154296d commit 3077d70
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 20 deletions.
3 changes: 3 additions & 0 deletions internal/ddos/ip_treeset.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ func (ip *SyncIPTreeSetProvider) GetSet() *treeset.Set {
return ip.set
}

//nolint:unused
func (ip *SyncIPTreeSetProvider) startUpdating() {
go func() {
for {
Expand All @@ -59,6 +60,7 @@ func (ip *SyncIPTreeSetProvider) startUpdating() {
}()
}

//nolint:unused
func (ip *SyncIPTreeSetProvider) tryUpdateSet() {
ip.mu.Lock()
defer ip.mu.Unlock()
Expand All @@ -77,6 +79,7 @@ func (ip *SyncIPTreeSetProvider) tryUpdateSet() {

type fetchIPSliceFunc func(ctx context.Context) ([]net.IP, error)

//nolint:unused
func ipComparator(a interface{}, b interface{}) int {
aIP := a.(net.IP)
bIP := b.(net.IP)
Expand Down
6 changes: 3 additions & 3 deletions internal/request/ip.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ const (
HeaderXClusterClientIP = "X-Cluster-Client-Ip"
)

// headers are the request headers that can provide us with the ip addresses
// headers are the Request headers that can provide us with the ip addresses
var headers = []string{
HeaderXForwardedFor,
HeaderXForwarded,
Expand All @@ -42,15 +42,15 @@ func GetRealIPAddress(r http.Request) (net.IP, error) {
if err != nil {
addrStr, err = getIPStringFromRequestRemoteAddress(r)
if err != nil {
return nil, fmt.Errorf("get ip string from request remote address")
return nil, fmt.Errorf("get ip string from Request remote address")
}
}

if ipAddr := net.ParseIP(addrStr); ipAddr != nil {
return ipAddr, nil
}

return nil, errors.New("cannot get ip address from the request")
return nil, errors.New("cannot get ip address from the Request")
}

func getIPStringFromHeaders(r http.Request) (string, error) {
Expand Down
15 changes: 8 additions & 7 deletions internal/request/wrapper.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
package request

import (
"net"
"net/http"
)

type Wrapper struct {
request *http.Request
Request *http.Request
IPAddress *net.IP
}

func NewRequestWrapper(r *http.Request) *Wrapper {
return &Wrapper{request: r}
}

func (w *Wrapper) Request() *http.Request {
return w.request
func NewRequestWrapper(r *http.Request, ipAddress *net.IP) *Wrapper {
return &Wrapper{
Request: r,
IPAddress: ipAddress,
}
}
12 changes: 9 additions & 3 deletions internal/waf/guard/ddos.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package guard

import (
"context"
"fmt"

"waffle/internal/ddos"
"waffle/internal/request"
)
Expand All @@ -9,7 +12,10 @@ type DDOS struct {
ipValidator ddos.IPValidator
}

func (D *DDOS) Validate(rw *request.Wrapper) error {
//TODO implement me
panic("implement me")
func (d *DDOS) Validate(ctx context.Context, rw *request.Wrapper) error {
if err := d.ipValidator.Validate(ctx, rw.IPAddress); err != nil {
return fmt.Errorf("validate ip using ip validator: %w", err)
}

return nil
}
9 changes: 5 additions & 4 deletions internal/waf/guard/guard.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package guard

import (
"context"
"fmt"
"sync"

Expand All @@ -10,7 +11,7 @@ import (
// Defender must be implemented by the struct
// representing defense rule (or set of rules).
type Defender interface {
Validate(rw *request.Wrapper) error
Validate(ctx context.Context, rw *request.Wrapper) error
}

// DefenseCoordinator coordinates defense. It validates request against set of defenders.
Expand All @@ -24,7 +25,7 @@ func NewDefenseCoordinator(defenders []Defender) *DefenseCoordinator {
return &DefenseCoordinator{defenders: defenders}
}

func (d *DefenseCoordinator) Validate(rw *request.Wrapper) error {
func (d *DefenseCoordinator) Validate(ctx context.Context, rw *request.Wrapper) error {
var wg sync.WaitGroup

errChan := make(chan error)
Expand All @@ -36,7 +37,7 @@ func (d *DefenseCoordinator) Validate(rw *request.Wrapper) error {
wg.Add(1)

go func(rw *request.Wrapper, d Defender, errChan chan error) {
if err := d.Validate(rw); err != nil {
if err := d.Validate(ctx, rw); err != nil {
errChan <- err
}

Expand All @@ -48,7 +49,7 @@ func (d *DefenseCoordinator) Validate(rw *request.Wrapper) error {
}()

select {
case <-rw.Request().Context().Done():
case <-rw.Request.Context().Done():
return nil
case err, ok := <-errChan:
if !ok {
Expand Down
5 changes: 3 additions & 2 deletions internal/waf/guard/xss.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package guard

import (
"context"
"errors"
"io"

Expand All @@ -13,8 +14,8 @@ type XSS struct {

// Validate validates if given input is XSS. It only returns error
// if given input is XSS, in other cases it returns nil.
func (X *XSS) Validate(rw *request.Wrapper) error {
body, err := io.ReadAll(rw.Request().Body)
func (X *XSS) Validate(_ context.Context, rw *request.Wrapper) error {
body, err := io.ReadAll(rw.Request.Body)
if err != nil {
return nil
}
Expand Down
6 changes: 5 additions & 1 deletion internal/waf/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ func NewHandler(
var _ http.Handler = (*Handler)(nil)

func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

ipAddr, err := request.GetRealIPAddress(*r)
if err != nil {
w.WriteHeader(http.StatusForbidden)
Expand All @@ -47,7 +49,9 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
tmp := h.limiter.SetRate(r.Context(), ipAddr, time.Now().Add(time.Second*5))
_, _ = w.Write([]byte(tmp))

if err := h.defender.Validate(guard.NewRequestWrapper(r)); err != nil {
requestWrapper := request.NewRequestWrapper(r, &ipAddr)

if err := h.defender.Validate(ctx, requestWrapper); err != nil {
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte(err.Error()))
return
Expand Down

0 comments on commit 3077d70

Please sign in to comment.