Skip to content

Commit

Permalink
Merge pull request #37 from ninedraft/limit-connections
Browse files Browse the repository at this point in the history
  • Loading branch information
ninedraft authored Nov 12, 2022
2 parents d44fd05 + 9400f5e commit ab106b5
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jobs:
test:
strategy:
matrix:
go-version: [1.18.x, 1.19.*]
go-version: [1.19.*]
os: [ubuntu-latest, macos-latest, windows-latest]
runs-on: ${{ matrix.os }}
steps:
Expand Down
38 changes: 37 additions & 1 deletion gemax/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@ import (
"sync"

"github.com/ninedraft/gemax/gemax/status"
"golang.org/x/net/netutil"
)

// DefaultMaxConnections default number of maximum connections.
const DefaultMaxConnections = 256

// Handler describes a gemini protocol handler.
type Handler func(ctx context.Context, rw ResponseWriter, req IncomingRequest)

Expand All @@ -25,6 +29,11 @@ type Server struct {
ConnContext func(ctx context.Context, conn net.Conn) context.Context
Logf func(format string, args ...interface{})

// Maximum number of simultaneous connections served by Server.
// 0 - DefaultMaxConnections
// <0 - no limitation
MaxConnections int

mu sync.RWMutex
conns map[*connTrack]struct{}
listeners map[net.Listener]struct{}
Expand All @@ -42,15 +51,26 @@ func (server *Server) init() {
}

// ListenAndServe starts a TLS gemini server at specified server.
// It will block until context is canceled.
// It respects the MaxConnections setting.
// It will await all running handlers to end.
func (server *Server) ListenAndServe(ctx context.Context, tlsCfg *tls.Config) error {
server.init()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var lc = net.ListenConfig{}

var tcpListener, errListener = lc.Listen(ctx, "tcp", server.Addr)
if errListener != nil {
return fmt.Errorf("creating listener: %w", errListener)
}

if n := server.maxConnections(); n >= 0 {
var limited = netutil.LimitListener(tcpListener, n)
server.addListener(limited)
tcpListener = limited
}

var listener = tls.NewListener(tcpListener, tlsCfg)
go func() {
<-ctx.Done()
Expand All @@ -62,22 +82,38 @@ func (server *Server) ListenAndServe(ctx context.Context, tlsCfg *tls.Config) er
}

// Serve starts server on provided listener. Provided context will be passed to handlers.
// Serve will await all running handlers to end.
func (server *Server) Serve(ctx context.Context, listener net.Listener) error {
server.init()
server.addListener(listener)
var wg sync.WaitGroup
for {
var conn, errAccept = listener.Accept()
if errAccept != nil {
wg.Wait()
return fmt.Errorf("gemini server: %w", errAccept)
}
var track = server.addConn(conn)
wg.Add(1)
go func() {
defer wg.Done()
defer server.removeTrack(track)
server.handle(ctx, conn)
}()
}
}

func (server *Server) maxConnections() int {
switch {
case server.MaxConnections > 0:
return server.MaxConnections
case server.MaxConnections == 0:
return DefaultMaxConnections
default:
return -1
}
}

// Stop gracefully shuts down the server: closes all connections.
func (server *Server) Stop() {
server.closeAll()
Expand Down Expand Up @@ -118,7 +154,7 @@ func (server *Server) handle(ctx context.Context, conn net.Conn) {
code = status.BadRequest
}
if errParseReq != nil {
server.logf("WARN: bad request: remote_ip=%s, code=%s", conn.RemoteAddr(), code)
server.logf("WARN: bad request: remote_ip=%s, code=%s: %v", conn.RemoteAddr(), code, errParseReq)
rw.WriteStatus(code, status.Text(code))
return
}
Expand Down
81 changes: 81 additions & 0 deletions gemax/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"io"
"net"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -52,6 +54,7 @@ func TestServerBadRequest(test *testing.T) {
}

func TestServerInvalidHost(test *testing.T) {
test.Parallel()
var listener, server = setupEchoServer(test)
server.Hosts = []string{"example.com"}
defer func() { _ = listener.Close() }()
Expand All @@ -70,6 +73,7 @@ func TestServerInvalidHost(test *testing.T) {
}

func TestServerCancelListen(test *testing.T) {
test.Parallel()
var server = &gemax.Server{
Addr: testaddr.Addr(),
Logf: test.Logf,
Expand Down Expand Up @@ -99,6 +103,7 @@ func TestServerCancelListen(test *testing.T) {
}

func TestListenAndServe(test *testing.T) {
test.Parallel()
var server = &gemax.Server{
Addr: "localhost:40423",
Logf: test.Logf,
Expand Down Expand Up @@ -139,6 +144,82 @@ func TestListenAndServe(test *testing.T) {
test.Logf("%s / %v", data, errRead)
}

func TestLimitedListen(test *testing.T) {
test.Parallel()
var trigger = make(chan struct{})
var counter atomic.Int64

var server = &gemax.Server{
Addr: testaddr.Addr(),
Logf: test.Logf,
MaxConnections: 2,
Handler: func(_ context.Context, rw gemax.ResponseWriter, _ gemax.IncomingRequest) {
counter.Add(1)
<-trigger
_, _ = io.WriteString(rw, "example text")
},
}
test.Logf("loading test certs")
var cert, errCert = tls.LoadX509KeyPair("testdata/cert.pem", "testdata/key.pem")
if errCert != nil {
test.Fatal(errCert)
}
var cfg = &tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{cert},
}
var ctx, cancel = context.WithCancel(context.Background())
test.Cleanup(cancel)
test.Logf("starting test server")

var wg = sync.WaitGroup{}
defer wg.Wait()

wg.Add(1)
go func() {
defer wg.Done()
test.Logf("test server: listening on %q", server.Addr)
var err = server.ListenAndServe(ctx, cfg)
switch {
case err == nil, errors.Is(err, net.ErrClosed):
return
default:
test.Errorf("test server: listening: %v", err)
}
}()
time.Sleep(time.Second)

var client = &gemax.Client{}

wg.Add(2 * server.MaxConnections)
for i := 0; i < 2*server.MaxConnections; i++ {
go func() {
defer wg.Done()
var resp, errFetch = client.Fetch(ctx, "gemini://"+server.Addr)
switch {
case errFetch == nil:
// pass
case errors.Is(errFetch, context.Canceled):
return
default:
test.Error("fetching: ", errFetch)
return
}
defer func() { _ = resp.Close() }()
expectResponse(test, resp, "example text")
var data, errRead = io.ReadAll(resp)
test.Logf("%s / %v", data, errRead)
}()
}

time.Sleep(time.Second)
if counter.Load() > int64(server.MaxConnections) {
test.Errorf("number of simultaneous connections must not exceed %d", server.MaxConnections)
}
cancel()
close(trigger)
}

// emulates michael-lazar/gemini-diagnostics localhost $PORT --checks='URLDotEscape'
func TestURLDotEscape(test *testing.T) {
var listener, server = setupEchoServer(test)
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
module github.com/ninedraft/gemax

go 1.19

require golang.org/x/net v0.2.0
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
golang.org/x/net v0.2.0 h1:sZfSu1wtKLGlWI4ZZayP0ck9Y73K1ynO6gqzTdBVdPU=
golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=

0 comments on commit ab106b5

Please sign in to comment.