Skip to content

Commit

Permalink
update: allow config files set address validation (#73)
Browse files Browse the repository at this point in the history
Add DialedAddressValidator support for JSON and Protobuf config files.

Signed-off-by: Gaukas Wang <[email protected]>
  • Loading branch information
gaukas authored Jun 27, 2024
1 parent 8979246 commit d4faf1b
Show file tree
Hide file tree
Showing 6 changed files with 431 additions and 68 deletions.
56 changes: 56 additions & 0 deletions address_validator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package water

import (
"errors"
)

var (
ErrAddressValidatorNotInitialized = errors.New("address validator not initialized properly")
ErrAddressValidationDenied = errors.New("address validation denied")
)

type addressValidator struct {
catchAll bool
allowlist map[string][]string // map[address]networks
denylist map[string][]string // map[address]networks
}

func (a *addressValidator) validate(network, address string) error {
if a.catchAll {
// only check denylist, otherwise allow
if a.denylist == nil {
return ErrAddressValidatorNotInitialized
}

if deniedNetworks, ok := a.denylist[address]; ok {
if deniedNetworks == nil {
return ErrAddressValidatorNotInitialized
}

for _, deniedNet := range deniedNetworks {
if deniedNet == network {
return ErrAddressValidationDenied
}
}
}
return nil
} else {
// only check allowlist, otherwise deny
if a.allowlist == nil {
return ErrAddressValidatorNotInitialized
}

if allowedNetworks, ok := a.allowlist[address]; ok {
if allowedNetworks == nil {
return ErrAddressValidatorNotInitialized
}

for _, allowedNet := range allowedNetworks {
if allowedNet == network {
return nil
}
}
}
return ErrAddressValidationDenied
}
}
75 changes: 75 additions & 0 deletions address_validator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package water

// package water instead of water_test to access unexported struct addressValidator and its unexported fields/methods

import "testing"

func Test_addressValidator_validate(t *testing.T) {
var a addressValidator

// test catchAll with nil denylist
a.catchAll = true

if err := a.validate("random net", "random address"); err != ErrAddressValidatorNotInitialized {
t.Errorf("Expected ErrAddressValidatorNotInitialized, got %v", err)
}

// test nil denylist entry
a.denylist = map[string][]string{
"denied address": nil,
}

if err := a.validate("random net", "denied address"); err != ErrAddressValidatorNotInitialized {
t.Errorf("Expected ErrAddressValidatorNotInitialized, got %v", err)
}

// test denied address on denied network
a.denylist["denied address"] = []string{"denied net"}

if err := a.validate("denied net", "denied address"); err != ErrAddressValidationDenied {
t.Errorf("Expected ErrAddressValidationDenied, got %v", err)
}

// test random network with denied address
if err := a.validate("random net", "denied address"); err != nil {
t.Errorf("Expected nil, got %v", err)
}

// test random address on denied network
if err := a.validate("denied net", "random address"); err != nil {
t.Errorf("Expected nil, got %v", err)
}

// test not catchAll with nil allowlist
a.catchAll = false

if err := a.validate("random net", "random address"); err != ErrAddressValidatorNotInitialized {
t.Errorf("Expected ErrAddressValidatorNotInitialized, got %v", err)
}

// test nil allowlist entry
a.allowlist = map[string][]string{
"allowed address": nil,
}

if err := a.validate("random net", "allowed address"); err != ErrAddressValidatorNotInitialized {
t.Errorf("Expected ErrAddressValidatorNotInitialized, got %v", err)
}

// test allowed address on allowed network
a.allowlist["allowed address"] = []string{"allowed net"}

if err := a.validate("allowed net", "allowed address"); err != nil {
t.Errorf("Expected nil, got %v", err)
}

// test random network with allowed address
if err := a.validate("random net", "allowed address"); err != ErrAddressValidationDenied {
t.Errorf("Expected ErrAddressValidationDenied, got %v", err)
}

// test random address on allowed network
if err := a.validate("allowed net", "random address"); err != ErrAddressValidationDenied {
t.Errorf("Expected ErrAddressValidationDenied, got %v", err)
}
}
35 changes: 35 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,16 @@ func (c *Config) UnmarshalJSON(data []byte) error {
}
}

if c.DialedAddressValidator == nil {
a := &addressValidator{
catchAll: confJson.Network.AddressValidation.CatchAll,
allowlist: confJson.Network.AddressValidation.Allowlist,
denylist: confJson.Network.AddressValidation.Denylist,
}

c.DialedAddressValidator = a.validate
}

if len(confJson.Network.Listener.Network) > 0 && len(confJson.Network.Listener.Address) > 0 {
c.NetworkListener, err = net.Listen(confJson.Network.Listener.Network, confJson.Network.Listener.Address)
if err != nil {
Expand Down Expand Up @@ -281,6 +291,31 @@ func (c *Config) UnmarshalProto(b []byte) error {
c.TransportModuleConfig = TransportModuleConfigFromBytes(confProto.GetTransportModule().GetConfig())
}

// Parse DialedAddressValidator if not already set
if c.DialedAddressValidator == nil {
a := &addressValidator{
catchAll: confProto.GetNetwork().GetAddressValidation().GetCatchAll(),
}

allowlist := confProto.GetNetwork().GetAddressValidation().GetAllowlist()
if len(allowlist) > 0 {
a.allowlist = make(map[string][]string)
for k, v := range allowlist {
a.allowlist[k] = v.GetNames()
}
}

denylist := confProto.GetNetwork().GetAddressValidation().GetDenylist()
if len(denylist) > 0 {
a.denylist = make(map[string][]string)
for k, v := range denylist {
a.denylist[k] = v.GetNames()
}
}

c.DialedAddressValidator = a.validate
}

// Parse NetworkListener
listenerNetwork, listenerAddress := confProto.GetNetwork().GetListener().GetNetwork(), confProto.GetNetwork().GetListener().GetAddress()
if len(listenerNetwork) > 0 && len(listenerAddress) > 0 {
Expand Down
5 changes: 5 additions & 0 deletions configbuilder/config.json.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ type ConfigJSON struct {

Network struct {
// DialerFunc string `json:"dialer_func,omitempty"` // we have no good way to represent a func in JSON format yet
AddressValidation struct {
CatchAll bool `json:"catch_all,omitempty"` // If set, will allow all unspecified addresses. Otherwise, unspecified addresses will be rejected.
Allowlist map[string][]string `json:"allowlist,omitempty"` // e.g. {"1.1.1.1:443": ["tcp", "udp"], "1.0.0.1:443": ["tcp"], ...}
Denylist map[string][]string `json:"denylist,omitempty"` // e.g. {"1.0.0.0:80": ["udp"], ...}
} `json:"address_validator,omitempty"`
Listener struct {
Network string `json:"network"` // e.g. "tcp"
Address string `json:"address"` // e.g. "0.0.0.0:0"
Expand Down
Loading

0 comments on commit d4faf1b

Please sign in to comment.