Skip to content

Commit

Permalink
Merge branch 'response-buffering'
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinmcconnell committed Jul 26, 2024
2 parents a44867a + a4e6c3b commit 421ac64
Show file tree
Hide file tree
Showing 10 changed files with 339 additions and 150 deletions.
27 changes: 20 additions & 7 deletions internal/cmd/deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ func newDeployCommand() *deployCommand {
deployCommand.cmd = &cobra.Command{
Use: "deploy <service>",
Short: "Deploy a target host",
PreRunE: deployCommand.preRun,
RunE: deployCommand.deploy,
Args: cobra.ExactArgs(1),
ValidArgs: []string{"service"},
Expand All @@ -41,9 +42,10 @@ func newDeployCommand() *deployCommand {

deployCommand.cmd.Flags().DurationVar(&deployCommand.args.TargetOptions.ResponseTimeout, "target-timeout", server.DefaultTargetTimeout, "Maximum time to wait for the target server to respond when serving requests")

deployCommand.cmd.Flags().BoolVar(&deployCommand.args.TargetOptions.BufferRequests, "buffer-requests", false, "Enable request buffering")
deployCommand.cmd.Flags().Int64Var(&deployCommand.args.TargetOptions.MaxRequestMemoryBufferSize, "buffer-memory", server.DefaultMaxRequestMemoryBufferSize, "Max size of request memory buffer")
deployCommand.cmd.Flags().Int64Var(&deployCommand.args.TargetOptions.MaxRequestBodySize, "max-request-body", server.DefaultMaxRequestBodySize, "Max size of request body")
deployCommand.cmd.Flags().BoolVar(&deployCommand.args.TargetOptions.BufferingEnabled, "buffer", false, "Enable buffering")
deployCommand.cmd.Flags().Int64Var(&deployCommand.args.TargetOptions.MaxMemoryBufferSize, "buffer-memory", server.DefaultMaxMemoryBufferSize, "Max size of memory buffer")
deployCommand.cmd.Flags().Int64Var(&deployCommand.args.TargetOptions.MaxRequestBodySize, "max-request-body", server.DefaultMaxRequestBodySize, "Max size of request body when buffering (default of 0 means unlimited)")
deployCommand.cmd.Flags().Int64Var(&deployCommand.args.TargetOptions.MaxResponseBodySize, "max-response-body", server.DefaultMaxRequestBodySize, "Max size of response body when buffering (default of 0 means unlimited)")

deployCommand.cmd.MarkFlagRequired("target")

Expand All @@ -53,10 +55,6 @@ func newDeployCommand() *deployCommand {
func (c *deployCommand) deploy(cmd *cobra.Command, args []string) error {
c.args.Service = args[0]

if c.tls && c.args.Host == "" {
return fmt.Errorf("host must be set when using TLS")
}

if c.tls {
c.args.ServiceOptions.ACMECachePath = globalConfig.CertificatePath()
c.args.ServiceOptions.TLSHostname = c.args.Host
Expand All @@ -72,3 +70,18 @@ func (c *deployCommand) deploy(cmd *cobra.Command, args []string) error {
return client.Call("kamal-proxy.Deploy", c.args, &response)
})
}

func (c *deployCommand) preRun(cmd *cobra.Command, args []string) error {
flagsRequiringBuffering := []string{"max-request-body", "max-response-body", "buffer-memory"}
for _, flag := range flagsRequiringBuffering {
if cmd.Flags().Changed(flag) && !cmd.Flags().Changed("buffer") {
return fmt.Errorf("%s can only be set when buffering is enabled", flag)
}
}

if cmd.Flags().Changed("tls") && !cmd.Flags().Changed("host") {
return fmt.Errorf("host must be set when using TLS")
}

return nil
}
155 changes: 103 additions & 52 deletions internal/server/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,100 +6,151 @@ import (
"io"
"log/slog"
"os"
"sync"
)

var (
ErrMaximumSizeExceeded = errors.New("maximum size exceeded")
ErrWriteAfterRead = errors.New("write after read")
)

type BufferReadCloser struct {
type Buffer struct {
maxBytes int64
maxMemBytes int64

memoryBuffer bytes.Buffer
diskBuffer *os.File
multiReader io.Reader
memoryBuffer bytes.Buffer
memBytesWritten int64
diskBuffer *os.File
diskBytesWritten int64
overflowed bool
reader io.Reader
closeOnce sync.Once
}

func NewBufferReadCloser(r io.ReadCloser, maxBytes, maxMemBytes int64) (*BufferReadCloser, error) {
brc := &BufferReadCloser{
func NewBufferedReadCloser(r io.ReadCloser, maxBytes, maxMemBytes int64) (io.ReadCloser, error) {
buf := &Buffer{
maxBytes: maxBytes,
maxMemBytes: maxMemBytes,
}

err := brc.populate(r)
_, err := io.Copy(buf, r)
if err != nil {
return nil, err
}

return brc, err
return buf, nil
}

func (b *BufferReadCloser) Read(p []byte) (n int, err error) {
return b.multiReader.Read(p)
func NewBufferedWriteCloser(maxBytes, maxMemBytes int64) *Buffer {
return &Buffer{
maxBytes: maxBytes,
maxMemBytes: maxMemBytes,
}
}

func (b *BufferReadCloser) Close() error {
if b.diskBuffer != nil {
b.diskBuffer.Close()
os.Remove(b.diskBuffer.Name())
slog.Debug("Buffer: removing spill", "file", b.diskBuffer.Name())
func (b *Buffer) Write(p []byte) (int, error) {
if b.reader != nil {
return 0, ErrWriteAfterRead
}
return nil
}

func (b *BufferReadCloser) populate(r io.ReadCloser) error {
defer r.Close()
length := int64(len(p))
totalWritten := b.memBytesWritten + b.diskBytesWritten

moreDataRemaining, err := b.populateMemoryBuffer(r)
if err != nil {
return err
if b.maxBytes > 0 && totalWritten+length > b.maxBytes {
b.overflowed = true
return 0, ErrMaximumSizeExceeded
}

if !moreDataRemaining {
b.multiReader = &b.memoryBuffer
return nil
if b.diskBuffer != nil {
return b.writeToDisk(p)
}

if b.memBytesWritten+length <= b.maxMemBytes {
return b.writeToMemory(p)
}

err = b.populateDiskBuffer(r)
// We're writing past the memory buffer, so we need to start the spill to disk
err := b.createSpill()
if err != nil {
return err
return 0, err
}

b.multiReader = io.MultiReader(&b.memoryBuffer, b.diskBuffer)
memWritten, err := b.writeToMemory(p[:b.maxMemBytes-b.memBytesWritten])
if err != nil {
return memWritten, err
}

diskWritten, err := b.writeToDisk(p[memWritten:])
return memWritten + diskWritten, err
}

func (b *Buffer) Read(p []byte) (n int, err error) {
b.setReader()
return b.reader.Read(p)
}

func (b *Buffer) Overflowed() bool {
return b.overflowed
}

func (b *Buffer) Send(w io.Writer) error {
b.setReader()
_, err := io.Copy(w, b.reader)
return err
}

func (b *Buffer) Close() error {
b.closeOnce.Do(func() {
b.discardSpill()
})

return nil
}

func (b *BufferReadCloser) populateMemoryBuffer(r io.ReadCloser) (bool, error) {
limitReader := io.LimitReader(r, b.maxMemBytes)
copied, err := b.memoryBuffer.ReadFrom(limitReader)
if err != nil {
return false, err
}
func (b *Buffer) writeToMemory(p []byte) (int, error) {
n, err := b.memoryBuffer.Write(p)
b.memBytesWritten += int64(n)
return n, err
}

moreDataRemaining := copied == b.maxMemBytes
return moreDataRemaining, nil
func (b *Buffer) writeToDisk(p []byte) (int, error) {
n, err := b.diskBuffer.Write(p)
b.diskBytesWritten += int64(n)
return n, err
}

func (b *BufferReadCloser) populateDiskBuffer(r io.ReadCloser) error {
var err error
func (b *Buffer) setReader() {
if b.reader == nil {
if b.diskBuffer != nil {
b.diskBuffer.Seek(0, 0)
b.reader = io.MultiReader(&b.memoryBuffer, b.diskBuffer)
} else {
b.reader = &b.memoryBuffer
}
}
}

b.diskBuffer, err = os.CreateTemp("", "proxy-buffer")
func (b *Buffer) createSpill() error {
f, err := os.CreateTemp("", "proxy-buffer-")
if err != nil {
slog.Error("Buffer: failed to create spill file", "error", err)
return err
}

slog.Debug("Buffer: spilling request to disk", "file", b.diskBuffer.Name())
b.diskBuffer = f
slog.Debug("Buffer: spilling to disk", "file", b.diskBuffer.Name())

maxDiskBytes := b.maxBytes - b.maxMemBytes
limitReader := io.LimitReader(r, maxDiskBytes)
copied, err := io.Copy(b.diskBuffer, limitReader)
if err != nil {
return err
}
return nil
}

if copied == maxDiskBytes {
b.Close()
return ErrMaximumSizeExceeded
}
func (b *Buffer) discardSpill() {
if b.diskBuffer != nil {
b.diskBuffer.Close()

b.diskBuffer.Seek(0, 0)
return err
slog.Debug("Buffer: removing spill", "file", b.diskBuffer.Name())
err := os.Remove(b.diskBuffer.Name())
if err != nil {
slog.Error("Buffer: failed to remove spill", "file", b.diskBuffer.Name(), "error", err)
}
}
}
74 changes: 64 additions & 10 deletions internal/server/buffer_middleware.go
Original file line number Diff line number Diff line change
@@ -1,26 +1,30 @@
package server

import (
"bufio"
"log/slog"
"net"
"net/http"
)

type BufferMiddleware struct {
maxBytes int64
maxMemBytes int64
next http.Handler
maxMemBytes int64
maxRequestBytes int64
maxResponseBytes int64
next http.Handler
}

func WithBufferMiddleware(maxBytes, maxMemBytes int64, next http.Handler) http.Handler {
func WithBufferMiddleware(maxMemBytes, maxRequestBytes, maxResponseBytes int64, next http.Handler) http.Handler {
return &BufferMiddleware{
maxBytes: maxBytes,
maxMemBytes: maxMemBytes,
next: next,
maxMemBytes: maxMemBytes,
maxRequestBytes: maxRequestBytes,
maxResponseBytes: maxResponseBytes,
next: next,
}
}

func (h *BufferMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
buffer, err := NewBufferReadCloser(r.Body, h.maxBytes, h.maxMemBytes)
requestBuffer, err := NewBufferedReadCloser(r.Body, h.maxRequestBytes, h.maxMemBytes)
if err != nil {
if err == ErrMaximumSizeExceeded {
http.Error(w, "Request too large", http.StatusRequestEntityTooLarge)
Expand All @@ -31,6 +35,56 @@ func (h *BufferMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

r.Body = buffer
h.next.ServeHTTP(w, r)
responseBuffer := NewBufferedWriteCloser(h.maxResponseBytes, h.maxMemBytes)
responseWriter := &bufferedResponseWriter{ResponseWriter: w, statusCode: http.StatusOK, buffer: responseBuffer}
defer responseBuffer.Close()

r.Body = requestBuffer
h.next.ServeHTTP(responseWriter, r)

err = responseWriter.Send()
if err != nil {
slog.Error("Error sending response", "path", r.URL.Path, "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
}

type bufferedResponseWriter struct {
http.ResponseWriter
statusCode int
buffer *Buffer
hijacked bool
}

func (w *bufferedResponseWriter) Send() error {
if w.buffer.Overflowed() {
return ErrMaximumSizeExceeded
}

if w.hijacked {
return nil
}

w.ResponseWriter.WriteHeader(w.statusCode)
return w.buffer.Send(w.ResponseWriter)
}

func (w *bufferedResponseWriter) Header() http.Header {
return w.ResponseWriter.Header()
}

func (w *bufferedResponseWriter) WriteHeader(statusCode int) {
w.statusCode = statusCode
}

func (w *bufferedResponseWriter) Write(data []byte) (int, error) {
return w.buffer.Write(data)
}

func (w *bufferedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hijacker, ok := w.ResponseWriter.(http.Hijacker); ok {
w.hijacked = true
return hijacker.Hijack()
}
return nil, nil, http.ErrNotSupported
}
Loading

0 comments on commit 421ac64

Please sign in to comment.