Skip to content

Commit

Permalink
Merge pull request #41 from ninedraft/server-request-certs
Browse files Browse the repository at this point in the history
add parsing and fetching server requests
  • Loading branch information
ninedraft authored Jun 18, 2023
2 parents 2fe9c69 + f173401 commit 6ecc348
Show file tree
Hide file tree
Showing 6 changed files with 289 additions and 0 deletions.
5 changes: 5 additions & 0 deletions gemax/fs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gemax_test
import (
"bytes"
"context"
"crypto/x509"
"net/url"
"testing"
"testing/fstest"
Expand Down Expand Up @@ -44,6 +45,10 @@ func (req *incomingRequest) RemoteAddr() string { return req.remoteAddr }

func (req *incomingRequest) URL() *url.URL { return req.url }

func (req *incomingRequest) Certificates() []*x509.Certificate {
return nil
}

type responseWriter struct {
status status.Code
meta string
Expand Down
6 changes: 6 additions & 0 deletions gemax/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ const MaxRequestSize = int64(1024 + len("\r\n"))
type IncomingRequest interface {
URL() *url.URL
RemoteAddr() string
// Certificates returns the TLS certificates provided by the client.
Certificates() []*x509.Certificate
}

var (
Expand Down Expand Up @@ -104,6 +106,10 @@ func (req *incomingRequest) RemoteAddr() string {
return req.remoteAddr
}

func (req *incomingRequest) Certificates() []*x509.Certificate {
return req.certs
}

// - found delimiter -> return data[:delimIndex+1], err
// - found EOF -> return data, err
// - found error -> return data, err
Expand Down
129 changes: 129 additions & 0 deletions gemax/request_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,27 @@
package gemax_test

import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"math/big"
"strings"
"sync"
"testing"
"time"

"github.com/ninedraft/gemax/gemax"

"tailscale.com/net/memnet"
)

var (
clientCert = testCert("client")
serverCert = testCert("server")
)

func TestParseIncomingRequest(t *testing.T) {
Expand Down Expand Up @@ -73,6 +90,88 @@ func TestParseIncomingRequest(t *testing.T) {
"gemini://example.com\r", expect{err: true})
}

func TestRequest_Certificates(test *testing.T) {
test.Parallel()
test.Log("Test that we can get the client certificates from the client request")

wg := sync.WaitGroup{}
defer wg.Wait()

a, b := memnet.NewConn(test.Name(), 10<<20)
defer func() {
_ = a.Close()
_ = b.Close()
}()

deadline := time.Now().Add(5 * time.Second)
_ = a.SetDeadline(deadline)
_ = b.SetDeadline(deadline)

ctx, cancel := context.WithDeadline(context.Background(), deadline)
defer cancel()

test.Log("setting up server")
server := tls.Server(a, &tls.Config{
//nolint:gosec // G402 - it's ok to skip verification for gemini server
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS13,
Certificates: []tls.Certificate{serverCert},
ClientAuth: tls.RequireAnyClientCert,
VerifyPeerCertificate: func([][]byte, [][]*x509.Certificate) error {
return nil
},
VerifyConnection: func(tls.ConnectionState) error {
return nil
},
})
defer func() { _ = server.Close() }()

test.Log("setting up client")
client := tls.Client(b, &tls.Config{
//nolint:gosec // G402 - it's ok to skip verification for gemini server
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS13,
Certificates: []tls.Certificate{clientCert},
VerifyConnection: func(tls.ConnectionState) error { return nil },
VerifyPeerCertificate: func([][]byte, [][]*x509.Certificate) error {
return nil
},
})

wg.Add(1)
go func() {
defer func() { _ = client.Close() }()
defer wg.Done()

test.Log("sending request")
_, errRequest := fmt.Fprintf(client, "gemini://localhost:1968\r\n")
if errRequest != nil {
test.Error("sending request:", errRequest)
return
}
}()

// run handshake manually, because ParseIncomingRequest
// accesses the connection state before the handshake is complete
test.Log("server: handshaking client")
if err := server.HandshakeContext(ctx); err != nil {
test.Fatal("server handshake:", err)
}

test.Log("handling request")
req, errParseReq := gemax.ParseIncomingRequest(server, test.Name())
if errParseReq != nil {
test.Fatal("parsing request:", errParseReq)
}

certs := req.Certificates()
if len(certs) == 0 {
test.Error("no certificates in incoming request")
return
}
assertEq(test, certs[0].Issuer.CommonName, "client", "client cert issuer")
}

func assertEq[E comparable](t *testing.T, got, want E, format string, args ...any) {
t.Helper()

Expand All @@ -81,3 +180,33 @@ func assertEq[E comparable](t *testing.T, got, want E, format string, args ...an
t.Errorf(format, args...)
}
}

func testCert(organization string) tls.Certificate {
privateKey, errGenerate := rsa.GenerateKey(rand.Reader, 2048)
if errGenerate != nil {
panic("failed to generate private key: " + errGenerate.Error())
}

template := x509.Certificate{
SerialNumber: big.NewInt(int64(time.Now().Year())),
Subject: pkix.Name{
CommonName: organization,
Organization: []string{organization},
},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour), // Valid for 1 year
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}

derBytes, errCert := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
if errCert != nil {
panic("failed to create certificate: " + errCert.Error())
}

return tls.Certificate{
Certificate: [][]byte{derBytes},
PrivateKey: privateKey,
}
}
14 changes: 14 additions & 0 deletions gemax/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ func (server *Server) Serve(ctx context.Context, listener net.Listener) error {
go func() {
defer wg.Done()
defer server.removeTrack(track)

if err := handshake(ctx, conn); err != nil {
server.logf("WARN: handshake with %q failed: %v", conn.RemoteAddr(), err)
return
}

server.handle(ctx, conn)
}()
}
Expand Down Expand Up @@ -216,3 +222,11 @@ func (server *Server) buildHosts() {
server.hosts[host] = struct{}{}
}
}

func handshake(ctx context.Context, conn net.Conn) error {
if c, ok := conn.(*tls.Conn); ok {
return c.HandshakeContext(ctx)
}

return nil
}
130 changes: 130 additions & 0 deletions gemax/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gemax_test
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -287,6 +288,135 @@ func TestPageNotFound(test *testing.T) {
})
}

func TestServer_Identity(test *testing.T) {
test.Parallel()
test.Log(
"Check that server fetches client certificates and passes them to the handler",
)

called := make(chan []*x509.Certificate)
server := gemax.Server{
Logf: test.Logf,
Addr: testaddr.Addr(),
Hosts: []string{"example.com"},
Handler: func(_ context.Context, rw gemax.ResponseWriter, req gemax.IncomingRequest) {
rw.WriteStatus(status.Success, "example text")
called <- req.Certificates()
},
}

ctx := context.Background()

go func() {
_ = server.ListenAndServe(ctx, &tls.Config{
//nolint:gosec // G402 - it's ok to skip verification for gemini server
InsecureSkipVerify: true,
Certificates: []tls.Certificate{serverCert},
ClientAuth: tls.RequireAnyClientCert,
})
}()

cfg := &tls.Config{
ServerName: "example.com",
//nolint:gosec // G402 - it's ok to skip verification for gemini server
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS13,
Certificates: []tls.Certificate{clientCert},
VerifyConnection: func(tls.ConnectionState) error { return nil },
VerifyPeerCertificate: func([][]byte, [][]*x509.Certificate) error {
return nil
},
}

var conn net.Conn

// wait for server to start
for i := 0; true; i++ {
c, errDial := tls.Dial("tcp", server.Addr, cfg)

switch {
case i >= 20 && errDial != nil:
test.Fatal("server is not started: %w", errDial)
case errDial != nil:
test.Log("server is not started yet, retrying...")
continue
}
conn = c
break
}

_, _ = fmt.Fprintf(conn, "gemini://example.com/\r\n")
_ = conn.Close()

gotCerts := <-called
if len(gotCerts) != 1 {
test.Fatalf("got %d certificates, want 1", len(gotCerts))
}
assertEq(test, gotCerts[0].Subject.CommonName, "client", "certificate CN")
}

func TestServer_Identity_Error(test *testing.T) {
test.Parallel()
test.Log(
"Check that server handles bad handshake",
)

server := gemax.Server{
Logf: test.Logf,
Addr: testaddr.Addr(),
Hosts: []string{"example.com"},
Handler: func(context.Context, gemax.ResponseWriter, gemax.IncomingRequest) {
test.Errorf("handler must not be called")
},
}

ctx := context.Background()
errTest := errors.New("test error")

go func() {
_ = server.ListenAndServe(ctx, &tls.Config{
//nolint:gosec // G402 - it's ok to skip verification for gemini server
InsecureSkipVerify: true,
Certificates: []tls.Certificate{serverCert},
ClientAuth: tls.RequireAnyClientCert,
VerifyPeerCertificate: func([][]byte, [][]*x509.Certificate) error {
return errTest
},
})
}()

cfg := &tls.Config{
ServerName: "example.com",
//nolint:gosec // G402 - it's ok to skip verification for gemini server
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS13,
Certificates: []tls.Certificate{clientCert},
VerifyConnection: func(tls.ConnectionState) error { return nil },
VerifyPeerCertificate: func([][]byte, [][]*x509.Certificate) error {
return nil
},
}

var conn net.Conn

// wait for server to start
for i := 0; true; i++ {
c, errDial := tls.Dial("tcp", server.Addr, cfg)

switch {
case i >= 20 && errDial != nil:
test.Fatal("server is not started: %w", errDial)
case errDial != nil:
test.Log("server is not started yet, retrying...")
continue
}
conn = c
_, _ = fmt.Fprintf(conn, "gemini://example.com/\r\n")
_ = conn.Close()
break
}
}

func setupServer(t *testing.T, handler gemax.Handler) (*memnet.Listener, *gemax.Server) {
t.Helper()
var server = &gemax.Server{
Expand Down
5 changes: 5 additions & 0 deletions gemax/server_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gemax_test
import (
"bytes"
"context"
"crypto/x509"
urlpkg "net/url"
"reflect"
"strings"
Expand Down Expand Up @@ -71,6 +72,10 @@ func (req *request) RemoteAddr() string {
return req.remoteAddr
}

func (req *request) Certificates() []*x509.Certificate {
return nil
}

type responseRecorder struct {
status status.Code
meta string
Expand Down

0 comments on commit 6ecc348

Please sign in to comment.