diff --git a/proxy/proxy.go b/proxy/proxy.go index 7e62c0b..afe455d 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -66,6 +66,11 @@ func (ph HTTPProxyHandler) proxy(response http.ResponseWriter, request *http.Req return } if resp.StatusCode != http.StatusSwitchingProtocols { + if request.Method == http.MethodGet && utils.IsChunkedEncoding(resp) { + log.Debug("[dispatch] Forward chunked response") + utils.ForwardChunked(response, resp) + return + } log.Debug("[dispatch] Forward http response") if err := utils.Forward(resp, response); err != nil { log.Errorf("[dispatch] forward docker socket response failed %v", err) diff --git a/utils/utils.go b/utils/utils.go index 4e7f6f6..9dcd669 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "io/ioutil" + "net" "net/http" "strconv" "strings" @@ -24,6 +25,87 @@ func Initialize(bufSize int) { debug = log.GetLevel() == log.DebugLevel } +func IsChunkedEncoding(src *http.Response) bool { + for _, value := range src.TransferEncoding { + if lower := strings.ToLower(strings.Trim(value, " ")); lower == "chunked" { + return true + } + } + return false +} + +func ForwardChunked(response http.ResponseWriter, resp *http.Response) { + log.Info("[ForwardChunked] Will forward chunked response") + // we will hijack connection and link with dockerd connection + // test response writer could be hijacked + if hijacker, ok := response.(http.Hijacker); ok { + // test resp body is writable + doForwardChunked(response, resp, hijacker) + return + } + log.Error("[ForwardChunked] can't Hijack ServerResponseWriter") + if err := WriteBadGateWayResponse( + response, + HTTPSimpleMessageResponseBody{ + Message: "Can't Hijack ServerResponseWriter", + }, + ); err != nil { + log.Errorf("[ForwardChunked] write bad gateway response %v", err) + } +} + +func doForwardChunked(response http.ResponseWriter, resp *http.Response, hijacker http.Hijacker) { + var err error + // first we send response to non overrided client, make sure it's ready for new protocol + if err = writeChunkedResponseHeader( + response, + http.StatusSwitchingProtocols, + resp.Header, + ); err != nil { + log.Errorf("[doForwardChunked] write chunked response header failed %v", err) + return + } + var conn net.Conn + log.Info("[doForwardChunked] Hijack server http connection") + if conn, _, err = hijacker.Hijack(); err != nil { + log.Errorf("[doForwardChunked] Hijack ServerResponseWriter failed %v", err) + return + } + defer forwardChunked(conn, resp.Body) + // link client conn and server conn + log.Info("[doForwardChunked] completed") +} + +func forwardChunked(client io.ReadWriteCloser, server io.ReadCloser) { + log.Info("[forwardChunked] Starting forward chunked stream") + if _, err := io.Copy(client, server); err != nil { + if err == io.EOF { + log.Info("[forwardChunked] forwardChunked encounter EOF") + return + } + log.Errorf("[forwardChunked] forwardChunked end with %v", err) + } + log.Infof("[forwardChunked] End forward chunked stream") +} + +func writeChunkedResponseHeader(response http.ResponseWriter, statusCode int, header http.Header) error { + log.Infof("[WriteToServerResponse] Write ServerResponse, statusCode = %v", statusCode) + PrintHeaders("ServerResponse", header) + responseHeader := response.Header() + for key, values := range header { + for _, value := range values { + responseHeader.Add(key, value) + } + } + response.WriteHeader(statusCode) + if flusher, ok := response.(http.Flusher); ok { + flusher.Flush() + } else { + log.Error("[WriteToServerResponse] Can't make flush to http.flusher") + } + return nil +} + // Forward . func Forward(src *http.Response, dst http.ResponseWriter) error { copyHeader(src, dst) @@ -100,6 +182,15 @@ func copyBody(reader io.ReadCloser, dst io.Writer) (err error) { return } +func isChunked(src *http.Response) bool { + for _, value := range src.TransferEncoding { + if lower := strings.ToLower(strings.Trim(value, " ")); lower == "chunked" { + return true + } + } + return false +} + func copyHeader(src *http.Response, dst http.ResponseWriter) { PrintHeaders("ClientResponse", src.Header) log.Debugf("[copyHeader] ContentLength = %v", src.ContentLength)