diff --git a/gemax/fs_test.go b/gemax/fs_test.go index 369e000..0c5481e 100644 --- a/gemax/fs_test.go +++ b/gemax/fs_test.go @@ -3,6 +3,7 @@ package gemax_test import ( "bytes" "context" + "crypto/x509" "net/url" "testing" "testing/fstest" @@ -44,6 +45,10 @@ func (req *incomingRequest) RemoteAddr() string { return req.remoteAddr } func (req *incomingRequest) URL() *url.URL { return req.url } +func (req *incomingRequest) Certificates() []*x509.Certificate { + return nil +} + type responseWriter struct { status status.Code meta string diff --git a/gemax/request.go b/gemax/request.go index 92be17c..9aa073d 100644 --- a/gemax/request.go +++ b/gemax/request.go @@ -23,6 +23,8 @@ const MaxRequestSize = int64(1024 + len("\r\n")) type IncomingRequest interface { URL() *url.URL RemoteAddr() string + // Certificates returns the TLS certificates provided by the client. + Certificates() []*x509.Certificate } var ( @@ -104,6 +106,10 @@ func (req *incomingRequest) RemoteAddr() string { return req.remoteAddr } +func (req *incomingRequest) Certificates() []*x509.Certificate { + return req.certs +} + // - found delimiter -> return data[:delimIndex+1], err // - found EOF -> return data, err // - found error -> return data, err diff --git a/gemax/request_test.go b/gemax/request_test.go index f656921..731c489 100644 --- a/gemax/request_test.go +++ b/gemax/request_test.go @@ -1,10 +1,27 @@ package gemax_test import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "fmt" + "math/big" "strings" + "sync" "testing" + "time" "github.com/ninedraft/gemax/gemax" + + "tailscale.com/net/memnet" +) + +var ( + clientCert = testCert("client") + serverCert = testCert("server") ) func TestParseIncomingRequest(t *testing.T) { @@ -73,6 +90,88 @@ func TestParseIncomingRequest(t *testing.T) { "gemini://example.com\r", expect{err: true}) } +func TestRequest_Certificates(test *testing.T) { + test.Parallel() + test.Log("Test that we can get the client certificates from the client request") + + wg := sync.WaitGroup{} + defer wg.Wait() + + a, b := memnet.NewConn(test.Name(), 10<<20) + defer func() { + _ = a.Close() + _ = b.Close() + }() + + deadline := time.Now().Add(5 * time.Second) + _ = a.SetDeadline(deadline) + _ = b.SetDeadline(deadline) + + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + + test.Log("setting up server") + server := tls.Server(a, &tls.Config{ + //nolint:gosec // G402 - it's ok to skip verification for gemini server + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS13, + Certificates: []tls.Certificate{serverCert}, + ClientAuth: tls.RequireAnyClientCert, + VerifyPeerCertificate: func([][]byte, [][]*x509.Certificate) error { + return nil + }, + VerifyConnection: func(tls.ConnectionState) error { + return nil + }, + }) + defer func() { _ = server.Close() }() + + test.Log("setting up client") + client := tls.Client(b, &tls.Config{ + //nolint:gosec // G402 - it's ok to skip verification for gemini server + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS13, + Certificates: []tls.Certificate{clientCert}, + VerifyConnection: func(tls.ConnectionState) error { return nil }, + VerifyPeerCertificate: func([][]byte, [][]*x509.Certificate) error { + return nil + }, + }) + + wg.Add(1) + go func() { + defer func() { _ = client.Close() }() + defer wg.Done() + + test.Log("sending request") + _, errRequest := fmt.Fprintf(client, "gemini://localhost:1968\r\n") + if errRequest != nil { + test.Error("sending request:", errRequest) + return + } + }() + + // run handshake manually, because ParseIncomingRequest + // accesses the connection state before the handshake is complete + test.Log("server: handshaking client") + if err := server.HandshakeContext(ctx); err != nil { + test.Fatal("server handshake:", err) + } + + test.Log("handling request") + req, errParseReq := gemax.ParseIncomingRequest(server, test.Name()) + if errParseReq != nil { + test.Fatal("parsing request:", errParseReq) + } + + certs := req.Certificates() + if len(certs) == 0 { + test.Error("no certificates in incoming request") + return + } + assertEq(test, certs[0].Issuer.CommonName, "client", "client cert issuer") +} + func assertEq[E comparable](t *testing.T, got, want E, format string, args ...any) { t.Helper() @@ -81,3 +180,33 @@ func assertEq[E comparable](t *testing.T, got, want E, format string, args ...an t.Errorf(format, args...) } } + +func testCert(organization string) tls.Certificate { + privateKey, errGenerate := rsa.GenerateKey(rand.Reader, 2048) + if errGenerate != nil { + panic("failed to generate private key: " + errGenerate.Error()) + } + + template := x509.Certificate{ + SerialNumber: big.NewInt(int64(time.Now().Year())), + Subject: pkix.Name{ + CommonName: organization, + Organization: []string{organization}, + }, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), // Valid for 1 year + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + derBytes, errCert := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + if errCert != nil { + panic("failed to create certificate: " + errCert.Error()) + } + + return tls.Certificate{ + Certificate: [][]byte{derBytes}, + PrivateKey: privateKey, + } +} diff --git a/gemax/server.go b/gemax/server.go index 3a98a0a..ea82b94 100644 --- a/gemax/server.go +++ b/gemax/server.go @@ -97,6 +97,12 @@ func (server *Server) Serve(ctx context.Context, listener net.Listener) error { go func() { defer wg.Done() defer server.removeTrack(track) + + if err := handshake(ctx, conn); err != nil { + server.logf("WARN: handshake with %q failed: %v", conn.RemoteAddr(), err) + return + } + server.handle(ctx, conn) }() } @@ -216,3 +222,11 @@ func (server *Server) buildHosts() { server.hosts[host] = struct{}{} } } + +func handshake(ctx context.Context, conn net.Conn) error { + if c, ok := conn.(*tls.Conn); ok { + return c.HandshakeContext(ctx) + } + + return nil +} diff --git a/gemax/server_test.go b/gemax/server_test.go index 5fbead5..2b88037 100644 --- a/gemax/server_test.go +++ b/gemax/server_test.go @@ -3,6 +3,7 @@ package gemax_test import ( "context" "crypto/tls" + "crypto/x509" "errors" "fmt" "io" @@ -287,6 +288,135 @@ func TestPageNotFound(test *testing.T) { }) } +func TestServer_Identity(test *testing.T) { + test.Parallel() + test.Log( + "Check that server fetches client certificates and passes them to the handler", + ) + + called := make(chan []*x509.Certificate) + server := gemax.Server{ + Logf: test.Logf, + Addr: testaddr.Addr(), + Hosts: []string{"example.com"}, + Handler: func(_ context.Context, rw gemax.ResponseWriter, req gemax.IncomingRequest) { + rw.WriteStatus(status.Success, "example text") + called <- req.Certificates() + }, + } + + ctx := context.Background() + + go func() { + _ = server.ListenAndServe(ctx, &tls.Config{ + //nolint:gosec // G402 - it's ok to skip verification for gemini server + InsecureSkipVerify: true, + Certificates: []tls.Certificate{serverCert}, + ClientAuth: tls.RequireAnyClientCert, + }) + }() + + cfg := &tls.Config{ + ServerName: "example.com", + //nolint:gosec // G402 - it's ok to skip verification for gemini server + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS13, + Certificates: []tls.Certificate{clientCert}, + VerifyConnection: func(tls.ConnectionState) error { return nil }, + VerifyPeerCertificate: func([][]byte, [][]*x509.Certificate) error { + return nil + }, + } + + var conn net.Conn + + // wait for server to start + for i := 0; true; i++ { + c, errDial := tls.Dial("tcp", server.Addr, cfg) + + switch { + case i >= 20 && errDial != nil: + test.Fatal("server is not started: %w", errDial) + case errDial != nil: + test.Log("server is not started yet, retrying...") + continue + } + conn = c + break + } + + _, _ = fmt.Fprintf(conn, "gemini://example.com/\r\n") + _ = conn.Close() + + gotCerts := <-called + if len(gotCerts) != 1 { + test.Fatalf("got %d certificates, want 1", len(gotCerts)) + } + assertEq(test, gotCerts[0].Subject.CommonName, "client", "certificate CN") +} + +func TestServer_Identity_Error(test *testing.T) { + test.Parallel() + test.Log( + "Check that server handles bad handshake", + ) + + server := gemax.Server{ + Logf: test.Logf, + Addr: testaddr.Addr(), + Hosts: []string{"example.com"}, + Handler: func(context.Context, gemax.ResponseWriter, gemax.IncomingRequest) { + test.Errorf("handler must not be called") + }, + } + + ctx := context.Background() + errTest := errors.New("test error") + + go func() { + _ = server.ListenAndServe(ctx, &tls.Config{ + //nolint:gosec // G402 - it's ok to skip verification for gemini server + InsecureSkipVerify: true, + Certificates: []tls.Certificate{serverCert}, + ClientAuth: tls.RequireAnyClientCert, + VerifyPeerCertificate: func([][]byte, [][]*x509.Certificate) error { + return errTest + }, + }) + }() + + cfg := &tls.Config{ + ServerName: "example.com", + //nolint:gosec // G402 - it's ok to skip verification for gemini server + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS13, + Certificates: []tls.Certificate{clientCert}, + VerifyConnection: func(tls.ConnectionState) error { return nil }, + VerifyPeerCertificate: func([][]byte, [][]*x509.Certificate) error { + return nil + }, + } + + var conn net.Conn + + // wait for server to start + for i := 0; true; i++ { + c, errDial := tls.Dial("tcp", server.Addr, cfg) + + switch { + case i >= 20 && errDial != nil: + test.Fatal("server is not started: %w", errDial) + case errDial != nil: + test.Log("server is not started yet, retrying...") + continue + } + conn = c + _, _ = fmt.Fprintf(conn, "gemini://example.com/\r\n") + _ = conn.Close() + break + } +} + func setupServer(t *testing.T, handler gemax.Handler) (*memnet.Listener, *gemax.Server) { t.Helper() var server = &gemax.Server{ diff --git a/gemax/server_utils_test.go b/gemax/server_utils_test.go index 6792fcd..539d002 100644 --- a/gemax/server_utils_test.go +++ b/gemax/server_utils_test.go @@ -3,6 +3,7 @@ package gemax_test import ( "bytes" "context" + "crypto/x509" urlpkg "net/url" "reflect" "strings" @@ -71,6 +72,10 @@ func (req *request) RemoteAddr() string { return req.remoteAddr } +func (req *request) Certificates() []*x509.Certificate { + return nil +} + type responseRecorder struct { status status.Code meta string