Skip to content

Commit

Permalink
Merge pull request #38 from ninedraft/fix-request-parsing
Browse files Browse the repository at this point in the history
fix request parsing
  • Loading branch information
ninedraft authored Jun 17, 2023
2 parents 3158c2a + f00ceca commit 0f98247
Show file tree
Hide file tree
Showing 17 changed files with 900 additions and 144 deletions.
102 changes: 92 additions & 10 deletions gemax/request.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
package gemax

import (
"bytes"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"io/fs"
"net/url"
"strings"

"golang.org/x/exp/slices"
)

var requestSuffix = []byte("\n")

// MaxRequestSize is the maximum incoming request size in bytes.
const MaxRequestSize = 1026
const MaxRequestSize = int64(1024 + len("\r\n"))

// IncomingRequest describes a server side request object.
type IncomingRequest interface {
Expand All @@ -21,31 +29,72 @@ var (
errDotPath = errors.New("dots in path are not permitted")
)

var ErrBadRequest = errors.New("bad request")

// ParseIncomingRequest reads a single request line from re and builds an
// IncomingRequest from it together with remoteAddr.
//
// Input is capped at MaxRequestSize bytes and is read up to the first '\n'.
// The line must end with "\n"; trailing "\r" and "\n" bytes are stripped
// before the remainder is parsed with url.ParseRequestURI. The parsed URL
// must have a non-empty scheme and a path accepted by isValidPath; an empty
// path is replaced with "/".
//
// A missing line terminator, an unparsable URI, a missing scheme and an
// invalid path are all reported as errors that match ErrBadRequest.
// If re is a *tls.Conn, a copy of the connection's peer certificates is
// stored in the returned request.
//
// NOTE(review): this span is a rendered diff, so lines of the previous
// implementation (fmt.Fscanf based) appear next to the new ones below; the
// summary above describes the new lines only.
func ParseIncomingRequest(re io.Reader, remoteAddr string) (IncomingRequest, error) {
var reader = io.LimitReader(re, MaxRequestSize)
var u string
var _, errReadRequest = fmt.Fscanf(reader, "%s\r\n", &u)
if errReadRequest != nil {
return nil, fmt.Errorf("bad request: %w", errReadRequest)
var certs []*x509.Certificate
if tlsConn, ok := re.(*tls.Conn); ok {
certs = slices.Clone(tlsConn.ConnectionState().PeerCertificates)
}

re = io.LimitReader(re, MaxRequestSize)

line, errLine := readUntil(re, '\n')
if errLine != nil {
return nil, errLine
}

if !bytes.HasSuffix(line, requestSuffix) {
return nil, ErrBadRequest
}
var parsed, errParse = url.ParseRequestURI(u)

line = bytes.TrimRight(line, "\r\n")

parsed, errParse := url.ParseRequestURI(string(line))
if errParse != nil {
return nil, fmt.Errorf("bad request: %w", errParse)
return nil, fmt.Errorf("%w: %w", ErrBadRequest, errParse)
}

if parsed.Scheme == "" {
return nil, fmt.Errorf("%w: missing scheme", ErrBadRequest)
}
if strings.Contains(parsed.Path, "/..") {
return nil, errDotPath

if !isValidPath(parsed.Path) {
return nil, fmt.Errorf("%w: %w", ErrBadRequest, errDotPath)
}

if parsed.Path == "" {
parsed.Path = "/"
}

return &incomingRequest{
url: parsed,
remoteAddr: remoteAddr,
certs: certs,
}, nil
}

// isValidPath reports whether a URL path is acceptable for a request.
//
// One leading and one trailing slash are ignored. An empty remainder is
// valid, the bare elements "." and ".." are rejected, and anything else is
// delegated to fs.ValidPath.
func isValidPath(path string) bool {
	trimmed := strings.TrimSuffix(strings.TrimPrefix(path, "/"), "/")

	if trimmed == "" {
		return true
	}
	// fs.ValidPath accepts ".", so it has to be rejected explicitly here.
	if trimmed == "." || trimmed == ".." {
		return false
	}

	return fs.ValidPath(trimmed)
}

// incomingRequest is the request value built by ParseIncomingRequest.
type incomingRequest struct {
// url is the parsed request URL.
url *url.URL
// remoteAddr is the remote address string passed to ParseIncomingRequest.
remoteAddr string
// certs holds a copy of the peer certificates when the request was read
// from a *tls.Conn; otherwise it is nil.
certs []*x509.Certificate
}

func (req *incomingRequest) URL() *url.URL {
Expand All @@ -55,3 +104,36 @@ func (req *incomingRequest) URL() *url.URL {
// RemoteAddr returns the remote address string stored in the request.
func (req *incomingRequest) RemoteAddr() string {
return req.remoteAddr
}

// readUntil reads from re until the first occurrence of delim, until the
// reader is exhausted, or until a read fails, whichever happens first.
//
// Results:
//   - delimiter found: the data up to and including delim, and a nil error.
//     An error returned by the same Read call is dropped, because the line is
//     already complete (same contract as bufio.Reader.ReadBytes: the error is
//     non-nil if and only if the data does not end in delim).
//   - EOF before a delimiter: the data read so far and an error matching both
//     ErrBadRequest and io.ErrUnexpectedEOF.
//   - any other read error before a delimiter: the data read so far and that
//     error.
func readUntil(re io.Reader, delim byte) ([]byte, error) {
	buf := make([]byte, 0, MaxRequestSize/4)
	for {
		if len(buf) == cap(buf) {
			// Add more capacity (let append pick how much).
			buf = append(buf, 0)[:len(buf)]
		}

		scanned := len(buf)
		n, err := re.Read(buf[scanned:cap(buf)])
		buf = buf[:scanned+n]

		// Process the bytes before looking at err: a reader may return the
		// last chunk together with io.EOF. Only the fresh bytes can hold the
		// delimiter, earlier ones were checked on previous iterations.
		if i := bytes.IndexByte(buf[scanned:], delim); i >= 0 {
			return buf[:scanned+i+1], nil
		}

		if err == nil {
			continue
		}
		if errors.Is(err, io.EOF) {
			// EOF, but no delimiter found.
			return buf, errors.Join(ErrBadRequest, io.ErrUnexpectedEOF)
		}
		return buf, err
	}
}
83 changes: 83 additions & 0 deletions gemax/request_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package gemax_test

import (
"strings"
"testing"

"github.com/ninedraft/gemax/gemax"
)

// TestParseIncomingRequest feeds raw request lines to
// gemax.ParseIncomingRequest and checks the error outcome, the remote address
// and the resulting URL of every case.
func TestParseIncomingRequest(t *testing.T) {
	t.Parallel()
	t.Log("parsing incoming request line")

	const remoteAddr = "remote-addr"

	cases := []struct {
		name    string
		input   string
		wantErr bool
		wantURL string
	}{
		{
			name:    "valid",
			input:   "gemini://example.com\r\n",
			wantURL: "gemini://example.com/",
		},
		{
			name:    "valid no \\r",
			input:   "gemini://example.com\n",
			wantURL: "gemini://example.com/",
		},
		{
			name:    "valid with path",
			input:   "gemini://example.com/path\r\n",
			wantURL: "gemini://example.com/path",
		},
		{
			name:    "valid with path and query",
			input:   "gemini://example.com/path?query=value\r\n",
			wantURL: "gemini://example.com/path?query=value",
		},
		{
			name:    "valid http",
			input:   "http://example.com\r\n",
			wantURL: "http://example.com/",
		},
		{
			name:    "too long",
			input:   "http://example.com/" + strings.Repeat("a", 2048) + "\r\n",
			wantErr: true,
		},
		{
			name:    "empty",
			input:   "",
			wantErr: true,
		},
		{
			name:    "no new \\r\\n",
			input:   "gemini://example.com",
			wantErr: true,
		},
		{
			name:    "no \\n",
			input:   "gemini://example.com\r",
			wantErr: true,
		},
	}

	for _, c := range cases {
		c := c // per-iteration copy for the parallel closure (needed before Go 1.22)
		t.Run(c.name, func(t *testing.T) {
			t.Parallel()

			parsed, err := gemax.ParseIncomingRequest(strings.NewReader(c.input), remoteAddr)

			if gotErr := err != nil; gotErr != c.wantErr {
				t.Errorf("error = %v, want error = %v", err, c.wantErr)
			}

			if parsed == nil {
				// A nil request is only acceptable together with an error.
				if err == nil {
					t.Error("parsed = nil, want not nil")
				}
				return
			}

			assertEq(t, parsed.RemoteAddr(), remoteAddr, "remote addr")
			assertEq(t, parsed.URL().String(), c.wantURL, "url")
		})
	}
}

// assertEq marks the test as failed when got differs from want. It reports
// the two values first and then the caller-supplied message built from format
// and args. It does not stop the test.
func assertEq[E comparable](t *testing.T, got, want E, format string, args ...any) {
	t.Helper()

	if got == want {
		return
	}

	t.Errorf("got %v, want %v", got, want)
	t.Errorf(format, args...)
}
13 changes: 3 additions & 10 deletions gemax/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package gemax
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/url"
Expand Down Expand Up @@ -146,15 +145,9 @@ func (server *Server) handle(ctx context.Context, conn net.Conn) {
}
}()
var req, errParseReq = ParseIncomingRequest(conn, conn.RemoteAddr().String())
var code = status.Success
switch {
case errors.Is(errParseReq, errDotPath):
code = status.PermanentFailure
case errParseReq != nil:
code = status.BadRequest
}
if errParseReq != nil {
server.logf("WARN: bad request: remote_ip=%s, code=%s: %v", conn.RemoteAddr(), code, errParseReq)
const code = status.BadRequest
server.logf("WARN: bad request: remote_addr=%s, code=%s: %v", conn.RemoteAddr(), code, errParseReq)
rw.WriteStatus(code, status.Text(code))
return
}
Expand Down Expand Up @@ -217,7 +210,7 @@ func (server *Server) validHost(u *url.URL) bool {

func (server *Server) buildHosts() {
if server.hosts == nil {
server.hosts = map[string]struct{}{}
server.hosts = make(map[string]struct{}, len(server.Hosts))
}
for _, host := range server.Hosts {
server.hosts[host] = struct{}{}
Expand Down
Loading

0 comments on commit 0f98247

Please sign in to comment.