forked from microsoft/kiota-http-go
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcompression_handler.go
144 lines (118 loc) · 4.16 KB
/
compression_handler.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
package nethttplibrary
import (
"bytes"
"compress/gzip"
"io"
"net/http"
abstractions "github.com/microsoft/kiota-abstractions-go"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
type (
	// CompressionHandler is a pipeline middleware that compresses outgoing
	// request bodies according to its CompressionOptions.
	CompressionHandler struct {
		options CompressionOptions
	}

	// CompressionOptions configures the CompressionHandler middleware.
	CompressionOptions struct {
		// enableCompression switches request-body compression on or off.
		enableCompression bool
	}

	// compression is the request option the handler looks up in a request's
	// context; it lets a single request override the handler's defaults.
	compression interface {
		abstractions.RequestOption
		ShouldCompress() bool
	}
)

// compressKey is the request-option key under which compression settings are
// stored and looked up.
var compressKey = abstractions.RequestOptionKey{
	Key: "CompressionHandler",
}
// NewCompressionHandler returns a compression middleware whose default
// configuration has compression enabled.
func NewCompressionHandler() *CompressionHandler {
	return NewCompressionHandlerWithOptions(NewCompressionOptions(true))
}
// NewCompressionHandlerWithOptions returns a compression middleware that uses
// option as its default configuration. A compression option stored in an
// individual request's context takes precedence over it for that request.
func NewCompressionHandlerWithOptions(option CompressionOptions) *CompressionHandler {
	handler := CompressionHandler{options: option}
	return &handler
}
// NewCompressionOptions builds a configuration for the CompressionHandler;
// enableCompression controls whether request bodies get compressed.
func NewCompressionOptions(enableCompression bool) CompressionOptions {
	var options CompressionOptions
	options.enableCompression = enableCompression
	return options
}
// GetKey returns the key that identifies CompressionOptions among the request
// options stored in a request's context.
func (CompressionOptions) GetKey() abstractions.RequestOptionKey {
	return compressKey
}
// ShouldCompress reports whether these options enable request-body
// compression.
func (o CompressionOptions) ShouldCompress() bool {
	return o.enableCompression
}
// Intercept is invoked by the middleware pipeline. It gzip-compresses the
// request body, when applicable, before handing the request to the next
// middleware.
//
// The per-request compression option stored in the request context under
// compressKey is used when present; otherwise the handler's own options apply.
// Compression is skipped when:
//   - it is disabled;
//   - the request has no body;
//   - the body already declares a Content-Encoding;
//   - the request carries a Content-Range header.
//
// If the server answers the compressed request with 415 Unsupported Media
// Type, that response is discarded. The request is then sent once more with
// the original, uncompressed body.
func (c *CompressionHandler) Intercept(pipeline Pipeline, middlewareIndex int, req *http.Request) (*http.Response, error) {
	reqOption, ok := req.Context().Value(compressKey).(compression)
	if !ok {
		reqOption = c.options
	}

	obsOptions := GetObservabilityOptionsFromRequest(req)
	ctx := req.Context()
	var span trace.Span
	if obsOptions != nil {
		ctx, span = otel.GetTracerProvider().Tracer(obsOptions.GetTracerInstrumentationName()).Start(ctx, "CompressionHandler_Intercept")
		span.SetAttributes(attribute.Bool("com.microsoft.kiota.handler.compression.enable", true))
		defer span.End()
		req = req.WithContext(ctx)
	}

	// Forward untouched when compression is off or cannot apply:
	//   - an empty body gains nothing from compression;
	//   - an already-encoded body would end up double-encoded;
	//   - a Content-Range describes byte offsets of the uncompressed payload
	//     (e.g. chunked upload sessions).
	if !reqOption.ShouldCompress() ||
		req.Body == nil || req.Body == http.NoBody ||
		req.Header.Get("Content-Encoding") != "" ||
		req.Header.Get("Content-Range") != "" {
		return pipeline.Next(req, middlewareIndex)
	}
	if span != nil {
		span.SetAttributes(attribute.Bool("http.request_body_compressed", true))
	}

	unCompressedBody, err := io.ReadAll(req.Body)
	// The original body is replaced below and would otherwise never be closed.
	// A close error is irrelevant once the content has been read.
	_ = req.Body.Close()
	if err != nil {
		if span != nil {
			span.RecordError(err)
		}
		return nil, err
	}

	compressedReader, _, err := compressReqBody(unCompressedBody)
	if err != nil {
		if span != nil {
			span.RecordError(err)
		}
		return nil, err
	}
	// Hold the compressed bytes so GetBody can replay exactly what was sent.
	compressedBody, err := io.ReadAll(compressedReader)
	if err != nil {
		if span != nil {
			span.RecordError(err)
		}
		return nil, err
	}

	req.Header.Set("Content-Encoding", "gzip")
	setBufferedRequestBody(req, compressedBody)
	if span != nil {
		span.SetAttributes(attribute.Int64("http.request_content_length", req.ContentLength))
	}

	// Send the request with the compressed body.
	resp, err := pipeline.Next(req, middlewareIndex)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode != http.StatusUnsupportedMediaType {
		return resp, nil
	}

	// The server rejected the gzip payload. Drain (bounded) and close the
	// rejected response so its connection can be reused. Errors are ignored
	// because this response is discarded.
	if resp.Body != nil {
		const maxDrainBytes = 64 << 10
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxDrainBytes))
		_ = resp.Body.Close()
	}

	// Retry once with the original, uncompressed body.
	req.Header.Del("Content-Encoding")
	setBufferedRequestBody(req, unCompressedBody)
	if span != nil {
		span.SetAttributes(
			attribute.Int64("http.request_content_length", req.ContentLength),
			attribute.Int("http.status_code", http.StatusUnsupportedMediaType),
		)
	}
	return pipeline.Next(req, middlewareIndex)
}

// setBufferedRequestBody installs payload as req's body. It keeps Body,
// GetBody and ContentLength consistent, so a transport that needs to resend
// the request replays exactly these bytes rather than a stale body.
func setBufferedRequestBody(req *http.Request, payload []byte) {
	req.Body = io.NopCloser(bytes.NewReader(payload))
	req.GetBody = func() (io.ReadCloser, error) {
		return io.NopCloser(bytes.NewReader(payload)), nil
	}
	req.ContentLength = int64(len(payload))
}
// compressReqBody gzips reqBody and returns a reader over the compressed
// bytes together with their length. It returns an error if the gzip stream
// cannot be written or finalized.
func compressReqBody(reqBody []byte) (io.ReadCloser, int, error) {
	compressed := new(bytes.Buffer)
	zw := gzip.NewWriter(compressed)

	_, err := zw.Write(reqBody)
	if err == nil {
		// Close flushes pending data and writes the gzip footer. It must run
		// before the buffer length is taken.
		err = zw.Close()
	}
	if err != nil {
		return nil, 0, err
	}
	return io.NopCloser(compressed), compressed.Len(), nil
}