Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add unix sockets support for URLs #874

Draft
wants to merge 12 commits into
base: master
Choose a base branch
from
3 changes: 3 additions & 0 deletions docs/docs/api-access-rules.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ Access Rules have four principal keys:
Oathkeeper Proxy as the Decision API does not forward the request to the
upstream.
- `url` (string): The URL the request will be forwarded to.
- UNIX sockets: use `unix://path/to/unix.sock` format for unix socket.
HTTP path can be set using `path` parameter. TLS can be enabled using
`tls` parameter.
- `preserve_host` (bool): If set to `false` (default), the forwarded request
will include the host and port of the `url` value. If `true`, the host and
port of the ORY Oathkeeper Proxy will be used instead:
Expand Down
19 changes: 19 additions & 0 deletions docs/docs/pipeline/authn.md
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ not currently supported, and will fail silently.

- `check_session_url` (string, required) - The session store to forward request
method/path/headers to for validation.
- UNIX sockets: use `unix://path/to/unix.sock` format for unix socket. HTTP
path can be set using `path` parameter. TLS can be enabled using `tls`
parameter.
- `only` ([]string, optional) - If set, only requests that have at least one of
the set cookies will be forwarded, others will be passed to the next
authenticator. If unset, all requests are forwarded.
Expand Down Expand Up @@ -299,6 +302,19 @@ authenticators:
preserve_query: true
```

```yaml
# Some Access Rule using unix socket endpoint : access-rule-3.yaml
id: access-rule-3
# match: ...
# upstream: ...
authenticators:
- handler: cookie_session
config:
check_session_url: unix://session/store/host.sock?path=/check-session&tls=true
only:
- sessionid
```

### Access Rule Example

```shell
Expand Down Expand Up @@ -345,6 +361,9 @@ not currently supported, and will fail silently.

- `check_session_url` (string, required) - The session store to forward request
method/path/headers to for validation.
- UNIX sockets: use `unix://path/to/unix.sock` format for unix socket. HTTP
path can be set using `path` parameter. TLS can be enabled using `tls`
parameter.
- `preserve_path` (boolean, optional) - If set, any path in `check_session_url`
will be preserved instead of replacing the path with the path of the request
being checked.
Expand Down
6 changes: 6 additions & 0 deletions docs/docs/pipeline/authz.md
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,9 @@ if it returns a "403 Forbidden" response code, the access is denied.
- `remote` (string, required) - The remote authorizer's URL. The remote
authorizer is expected to return either "200 OK" or "403 Forbidden" to
allow/deny access.
- UNIX sockets: use `unix://path/to/unix.sock` format for unix socket. HTTP
path can be set using `path` parameter. TLS can be enabled using `tls`
parameter.
- `headers` (map of strings, optional) - The HTTP headers sent to the remote
authorizer. The values will be parsed by the Go
[`text/template`](https://golang.org/pkg/text/template/) package and applied
Expand Down Expand Up @@ -373,6 +376,9 @@ Forbidden" response code, the access is denied.
- `remote` (string, required) - The remote authorizer's URL. The remote
authorizer is expected to return either "200 OK" or "403 Forbidden" to
allow/deny access.
- UNIX sockets: use `unix://path/to/unix.sock` format for unix socket. HTTP
path can be set using `path` parameter. TLS can be enabled using `tls`
parameter.
- `payload` (string, required) - The request's JSON payload sent to the remote
authorizer. The string will be parsed by the Go
[`text/template`](https://golang.org/pkg/text/template/) package and applied
Expand Down
86 changes: 86 additions & 0 deletions helper/transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package helper

import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
)

type transport struct {
base *http.Transport
dialer *net.Dialer
tlsDialer *tls.Dialer
}

func (t *transport) handleUnixAddr(addr string) (string, error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return "", err

Check warning on line 21 in helper/transport.go

View check run for this annotation

Codecov / codecov/patch

helper/transport.go#L21

Added line #L21 was not covered by tests
}
path, err := url.PathUnescape(host)
if err != nil {
return "", err

Check warning on line 25 in helper/transport.go

View check run for this annotation

Codecov / codecov/patch

helper/transport.go#L25

Added line #L25 was not covered by tests
}
return path, nil
}

func (t *transport) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
if path, err := t.handleUnixAddr(addr); err != nil {
return nil, err

Check warning on line 32 in helper/transport.go

View check run for this annotation

Codecov / codecov/patch

helper/transport.go#L32

Added line #L32 was not covered by tests
} else {
return t.dialer.DialContext(ctx, "unix", path)
}
}

func (t *transport) dialTlsContext(ctx context.Context, network, addr string) (net.Conn, error) {
if path, err := t.handleUnixAddr(addr); err != nil {
return nil, err

Check warning on line 40 in helper/transport.go

View check run for this annotation

Codecov / codecov/patch

helper/transport.go#L40

Added line #L40 was not covered by tests
} else {
return t.tlsDialer.DialContext(ctx, "unix", path)
}
}

func (t *transport) RoundTrip(r *http.Request) (*http.Response, error) {
if r.URL != nil {
switch r.URL.Scheme {
case "http", "https":
return http.DefaultTransport.RoundTrip(r)
case "unix":
urlValues := r.URL.Query()
req := r.Clone(r.Context())
if urlValues.Get("tls") != "" {
req.URL.Scheme = "https"
} else {
req.URL.Scheme = "http"
}
req.URL.Host = url.QueryEscape(r.URL.Path)
req.URL.Path = urlValues.Get("path")
v := req.URL.Query()
v.Del("tls")
v.Del("path")
req.URL.RawQuery = v.Encode()
return t.base.RoundTrip(req)
default:

Check warning on line 66 in helper/transport.go

View check run for this annotation

Codecov / codecov/patch

helper/transport.go#L66

Added line #L66 was not covered by tests
}
}
return nil, fmt.Errorf("invalid request")
}

func NewRoundTripper() http.RoundTripper {
base := http.DefaultTransport.(*http.Transport).Clone()
dialer := &net.Dialer{}
t := &transport{
base: base,
dialer: dialer,
tlsDialer: &tls.Dialer{
NetDialer: dialer,
Config: base.TLSClientConfig,
},
}
t.base.DialContext = t.dialContext
t.base.DialTLSContext = t.dialTlsContext
return t
}
194 changes: 194 additions & 0 deletions helper/transport_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
package helper

import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"io"
"math/big"
"net"
"net/http"
"net/http/httputil"
"net/url"
"path"
"reflect"
"testing"
"time"
)

const pattern = "/test/roundtrip"

func generateCertificate() (*tls.Certificate, error) {
priv, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
return nil, err
}
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour * 24 * 180),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
crt, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
return nil, err
}
return &tls.Certificate{
Certificate: [][]byte{crt},
PrivateKey: priv,
}, nil
}

func TestRoundTrip(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc(pattern, func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusOK)
rw.Write([]byte(r.URL.RawQuery))
})
cert, err := generateCertificate()
if err != nil {
t.Error(err)
}
dt := http.DefaultTransport.(*http.Transport)
dt.TLSHandshakeTimeout = time.Second
dt.IdleConnTimeout = time.Second
dt.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
tests := []struct {
name string
expect func() (net.Listener, *http.Request, bool)
}{
{
"Invalid request",
func() (net.Listener, *http.Request, bool) {
return nil, &http.Request{}, true
},
},
{
"HTTP : Dial error",
func() (net.Listener, *http.Request, bool) {
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", path.Join("127.0.0.1:12345", pattern)), nil)
if err != nil {
t.Error(err)
}
return nil, req, true
},
},
{
"UNIX : Dial error",
func() (net.Listener, *http.Request, bool) {
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", path.Join("unix://path/to/unix.sock?path=%s", pattern)), nil)
if err != nil {
t.Error(err)
}
return nil, req, true
},
},
{
"HTTP : OK",
func() (net.Listener, *http.Request, bool) {
lis, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Error(err)
}
go http.Serve(lis, mux)
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", path.Join(lis.Addr().String(), pattern)), nil)
if err != nil {
t.Error(err)
}
return lis, req, false
},
},
{
"UNIX : OK",
func() (net.Listener, *http.Request, bool) {
lis, err := net.Listen("unix", path.Join(t.TempDir(), "unix.sock"))
if err != nil {
t.Error(err)
}
go http.Serve(lis, mux)
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("unix://%s?path=%s&a=1&b=2", lis.Addr().String(), pattern), nil)
if err != nil {
t.Error(err)
}
return lis, req, false
},
},
{
"HTTP + TLS : OK",
func() (net.Listener, *http.Request, bool) {
lis, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Error(err)
}
lis = tls.NewListener(lis, &tls.Config{
NextProtos: []string{"http/1.1"},
Certificates: []tls.Certificate{*cert},
})
go http.Serve(lis, mux)
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://%s", path.Join(lis.Addr().String(), pattern)), nil)
if err != nil {
t.Error(err)
}
return lis, req, false
},
},
{
"UNIX + TLS : OK",
func() (net.Listener, *http.Request, bool) {
lis, err := net.Listen("unix", path.Join(t.TempDir(), "unix.sock"))
if err != nil {
t.Error(err)
}
lis = tls.NewListener(lis, &tls.Config{
NextProtos: []string{"http/1.1"},
Certificates: []tls.Certificate{*cert},
})
go http.Serve(lis, mux)
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("unix://%s?path=%s&tls=true&a=1&b=2", lis.Addr().String(), pattern), nil)
if err != nil {
t.Error(err)
}
return lis, req, false
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := NewRoundTripper()
lis, xreq, wantErr := tt.expect()
res, err := r.RoundTrip(xreq)
if wantErr && err == nil || !wantErr && err != nil {
t.Errorf("want err : %v, got err : %v", wantErr, err)
}
if res != nil {
if res.StatusCode != http.StatusOK {
t.Errorf("want code : %v, got code : %v", http.StatusOK, res.StatusCode)
s, _ := httputil.DumpRequest(xreq, false)
fmt.Println(string(s))
s, _ = httputil.DumpResponse(res, false)
fmt.Println(string(s))
} else {
testURL := url.URL{}
b, _ := io.ReadAll(res.Body)
testURL.RawQuery = string(b)
for k, v := range xreq.URL.Query() {
if k != "path" && k != "tls" && !reflect.DeepEqual(v, testURL.Query()[k]) {
t.Errorf("want query : %v, got query : %v", testURL.Query()[k], v)
}
}
}
}
if lis != nil {
lis.Close()
}
})
}
}
9 changes: 8 additions & 1 deletion pipeline/authn/authenticator_bearer_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,18 @@ type AuthenticatorBearerTokenConfiguration struct {

type AuthenticatorBearerToken struct {
c configuration.Provider
h *http.Client
}

func NewAuthenticatorBearerToken(c configuration.Provider) *AuthenticatorBearerToken {
return &AuthenticatorBearerToken{
c: c,
h: &http.Client{
Transport: helper.NewRoundTripper(),
CheckRedirect: http.DefaultClient.CheckRedirect,
Jar: http.DefaultClient.Jar,
Timeout: http.DefaultClient.Timeout,
},
}
}

Expand Down Expand Up @@ -85,7 +92,7 @@ func (a *AuthenticatorBearerToken) Authenticate(r *http.Request, session *Authen
return errors.WithStack(ErrAuthenticatorNotResponsible)
}

body, err := forwardRequestToSessionStore(r, cf.CheckSessionURL, cf.PreserveQuery, cf.PreservePath, cf.PreserveHost, cf.SetHeaders)
body, err := forwardRequestToSessionStore(r, a.h, cf.CheckSessionURL, cf.PreserveQuery, cf.PreservePath, cf.PreserveHost, cf.SetHeaders)
if err != nil {
return err
}
Expand Down
Loading