diff --git a/pkg/api/queryapi/compression.go b/pkg/api/queryapi/compression.go
new file mode 100644
index 00000000000..7dd6fcbacab
--- /dev/null
+++ b/pkg/api/queryapi/compression.go
@@ -0,0 +1,87 @@
+package queryapi
+
+import (
+	"io"
+	"net/http"
+	"strings"
+
+	"github.com/klauspost/compress/gzip"
+	"github.com/klauspost/compress/snappy"
+	"github.com/klauspost/compress/zlib"
+	"github.com/klauspost/compress/zstd"
+)
+
+const (
+	acceptEncodingHeader  = "Accept-Encoding"
+	contentEncodingHeader = "Content-Encoding"
+	gzipEncoding          = "gzip"
+	deflateEncoding       = "deflate"
+	snappyEncoding        = "snappy"
+	zstdEncoding          = "zstd"
+)
+
+// compressedResponseWriter wraps an http.ResponseWriter and sends the response
+// body through writer: either a compressor selected from the client's
+// Accept-Encoding header, or the ResponseWriter itself when none applies.
+type compressedResponseWriter struct {
+	http.ResponseWriter
+	writer io.Writer
+}
+
+// Write sends p through the selected writer.
+func (c *compressedResponseWriter) Write(p []byte) (int, error) {
+	return c.writer.Write(p)
+}
+
+// Close finalizes the response body. Every compressor used here implements
+// io.Closer, and closing one flushes its buffered data and ends the stream,
+// so no separate Flush is needed beforehand. A writer that is not an
+// io.Closer is left untouched.
+func (c *compressedResponseWriter) Close() {
+	if closer, ok := c.writer.(io.Closer); ok {
+		// Close has no error result to report a failure through.
+		_ = closer.Close()
+	}
+}
+
+// newCompressedResponseWriter walks the request's Accept-Encoding entries in
+// order and uses the first one that is supported (zstd, snappy, gzip or
+// deflate), setting Content-Encoding to match. Entries are compared literally
+// after trimming spaces. With no match the response is written uncompressed.
+func newCompressedResponseWriter(writer http.ResponseWriter, req *http.Request) *compressedResponseWriter {
+	encodings := strings.Split(req.Header.Get(acceptEncodingHeader), ",")
+	for _, encoding := range encodings {
+		switch strings.TrimSpace(encoding) {
+		case zstdEncoding:
+			encoder, err := zstd.NewWriter(writer)
+			if err == nil {
+				writer.Header().Set(contentEncodingHeader, zstdEncoding)
+				return &compressedResponseWriter{ResponseWriter: writer, writer: encoder}
+			}
+			// On error, keep looking at the remaining encodings.
+		case snappyEncoding:
+			writer.Header().Set(contentEncodingHeader, snappyEncoding)
+			return &compressedResponseWriter{ResponseWriter: writer, writer: snappy.NewBufferedWriter(writer)}
+		case gzipEncoding:
+			writer.Header().Set(contentEncodingHeader, gzipEncoding)
+			return &compressedResponseWriter{ResponseWriter: writer, writer: gzip.NewWriter(writer)}
+		case deflateEncoding:
+			writer.Header().Set(contentEncodingHeader, deflateEncoding)
+			return &compressedResponseWriter{ResponseWriter: writer, writer: zlib.NewWriter(writer)}
+		}
+	}
+	return &compressedResponseWriter{ResponseWriter: writer, writer: writer}
+}
+
+// CompressionHandler is a wrapper around http.Handler which adds suitable
+// response compression based on the client's Accept-Encoding headers.
+type CompressionHandler struct {
+	Handler http.Handler
+}
+
+// ServeHTTP adds compression to the original http.Handler's ServeHTTP() method.
+func (c CompressionHandler) ServeHTTP(writer http.ResponseWriter, req *http.Request) {
+	compWriter := newCompressedResponseWriter(writer, req)
+	c.Handler.ServeHTTP(compWriter, req)
+	compWriter.Close()
+}
diff --git a/pkg/api/queryapi/query_api.go b/pkg/api/queryapi/query_api.go
index e3793ef5bee..5dd125a6c39 100644
--- a/pkg/api/queryapi/query_api.go
+++ b/pkg/api/queryapi/query_api.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"fmt"
 	"net/http"
+	"strconv"
 	"time"

 	"github.com/go-kit/log"
@@ -208,7 +209,7 @@ func (q *QueryAPI) Wrap(f apiFunc) http.HandlerFunc {
 			w.WriteHeader(http.StatusNoContent)
 		}

-	return httputil.CompressionHandler{
+	return CompressionHandler{
 		Handler: http.HandlerFunc(hf),
 	}.ServeHTTP
 }
@@ -237,6 +238,7 @@ func (q *QueryAPI) respond(w http.ResponseWriter, req *http.Request, data interf
 	}

 	w.Header().Set("Content-Type", codec.ContentType().String())
+	w.Header().Set("X-Uncompressed-Length", strconv.Itoa(len(b)))
 	w.WriteHeader(http.StatusOK)
 	if n, err := w.Write(b); err != nil {
 		level.Error(q.logger).Log("error writing response", "url", req.URL, "bytesWritten", n, "err", err)
diff --git a/pkg/frontend/transport/handler.go b/pkg/frontend/transport/handler.go
index c988209fb0a..0f843e9b9cb 100644
--- a/pkg/frontend/transport/handler.go
+++ b/pkg/frontend/transport/handler.go
@@ -76,8 +76,6 @@ const (
 	limitBytesStoreGateway = `exceeded bytes limit`
 )

-var noopResponseSizeLimiter = limiter.NewResponseSizeLimiter(0)
-
 // Config for a Handler.
 type HandlerConfig struct {
 	LogQueriesLongerThan time.Duration `yaml:"log_queries_longer_than"`
@@ -308,7 +306,7 @@ func (f *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		// If the response status code is not 2xx, try to get the
 		// error message from response body.
if resp.StatusCode/100 != 2 { - body, err2 := tripperware.BodyBytes(resp, noopResponseSizeLimiter, f.log) + body, err2 := tripperware.BodyBytes(resp, f.log) if err2 == nil { err = httpgrpc.Errorf(resp.StatusCode, "%s", string(body)) } diff --git a/pkg/querier/tripperware/instantquery/instant_query.go b/pkg/querier/tripperware/instantquery/instant_query.go index 54fe4aeba0d..0d75a777763 100644 --- a/pkg/querier/tripperware/instantquery/instant_query.go +++ b/pkg/querier/tripperware/instantquery/instant_query.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "net/url" + "strconv" "strings" "time" @@ -45,8 +46,15 @@ type instantQueryCodec struct { func NewInstantQueryCodec(compressionStr string, defaultCodecTypeStr string) instantQueryCodec { compression := tripperware.NonCompression // default - if compressionStr == string(tripperware.GzipCompression) { + switch compressionStr { + case string(tripperware.GzipCompression): compression = tripperware.GzipCompression + + case string(tripperware.SnappyCompression): + compression = tripperware.SnappyCompression + + case string(tripperware.ZstdCompression): + compression = tripperware.ZstdCompression } defaultCodecType := tripperware.JsonCodecType // default @@ -100,8 +108,18 @@ func (c instantQueryCodec) DecodeResponse(ctx context.Context, r *http.Response, return nil, err } + responseSize, err := strconv.Atoi(r.Header.Get("X-Uncompressed-Length")) + if err != nil { + log.Error(err) + return nil, err + } + responseSizeLimiter := limiter.ResponseSizeLimiterFromContextWithFallback(ctx) - body, err := tripperware.BodyBytes(r, responseSizeLimiter, log) + if err := responseSizeLimiter.AddResponseBytes(responseSize); err != nil { + return nil, httpgrpc.Errorf(http.StatusUnprocessableEntity, "%s", err.Error()) + } + + body, err := tripperware.BodyBytes(r, log) if err != nil { log.Error(err) return nil, err diff --git a/pkg/querier/tripperware/query.go b/pkg/querier/tripperware/query.go index e20ab6e3c4e..059fa0deb3f 100644 --- 
a/pkg/querier/tripperware/query.go +++ b/pkg/querier/tripperware/query.go @@ -4,7 +4,6 @@ import ( "bytes" "compress/gzip" "context" - "encoding/binary" "fmt" "io" "net/http" @@ -16,6 +15,8 @@ import ( "github.com/go-kit/log" "github.com/gogo/protobuf/proto" jsoniter "github.com/json-iterator/go" + "github.com/klauspost/compress/snappy" + "github.com/klauspost/compress/zstd" "github.com/opentracing/opentracing-go" otlog "github.com/opentracing/opentracing-go/log" "github.com/pkg/errors" @@ -27,7 +28,6 @@ import ( "github.com/cortexproject/cortex/pkg/chunk" "github.com/cortexproject/cortex/pkg/cortexpb" - "github.com/cortexproject/cortex/pkg/util/limiter" "github.com/cortexproject/cortex/pkg/util/runutil" ) @@ -44,6 +44,8 @@ type Compression string const ( GzipCompression Compression = "gzip" + ZstdCompression Compression = "zstd" + SnappyCompression Compression = "snappy" NonCompression Compression = "" JsonCodecType CodecType = "json" ProtobufCodecType CodecType = "protobuf" @@ -434,7 +436,7 @@ type Buffer interface { Bytes() []byte } -func BodyBytes(res *http.Response, responseSizeLimiter *limiter.ResponseSizeLimiter, logger log.Logger) ([]byte, error) { +func BodyBytes(res *http.Response, logger log.Logger) ([]byte, error) { var buf *bytes.Buffer // Attempt to cast the response body to a Buffer and use it if possible. 
@@ -452,11 +454,6 @@ func BodyBytes(res *http.Response, responseSizeLimiter *limiter.ResponseSizeLimi
 		}
 	}

-	responseSize := getResponseSize(res, buf)
-	if err := responseSizeLimiter.AddResponseBytes(responseSize); err != nil {
-		return nil, httpgrpc.Errorf(http.StatusUnprocessableEntity, "%s", err.Error())
-	}
-
 	// if the response is gzipped, lets unzip it here
 	if strings.EqualFold(res.Header.Get("Content-Encoding"), "gzip") {
 		gReader, err := gzip.NewReader(buf)
@@ -468,15 +465,33 @@ func BodyBytes(res *http.Response, responseSizeLimiter *limiter.ResponseSizeLimi
 		return io.ReadAll(gReader)
 	}

+	// if the response is snappy compressed, decode it here
+	if strings.EqualFold(res.Header.Get("Content-Encoding"), "snappy") {
+		sReader := snappy.NewReader(buf)
+		return io.ReadAll(sReader)
+	}
+
+	// if the response is zstd compressed, decode it here
+	if strings.EqualFold(res.Header.Get("Content-Encoding"), "zstd") {
+		zReader, err := zstd.NewReader(buf)
+		if err != nil {
+			return nil, err
+		}
+		defer zReader.Close() // release the decoder; io.NopCloser made this a no-op
+
+		return io.ReadAll(zReader)
+	}
+
 	return buf.Bytes(), nil
 }

 func BodyBytesFromHTTPGRPCResponse(res *httpgrpc.HTTPResponse, logger log.Logger) ([]byte, error) {
-	// if the response is gzipped, lets unzip it here
 	headers := http.Header{}
 	for _, h := range res.Headers {
 		headers[h.Key] = h.Values
 	}
+
+	// if the response is gzipped, lets unzip it here
 	if strings.EqualFold(headers.Get("Content-Encoding"), "gzip") {
 		gReader, err := gzip.NewReader(bytes.NewBuffer(res.Body))
 		if err != nil {
@@ -487,16 +502,24 @@ func BodyBytesFromHTTPGRPCResponse(res *httpgrpc.HTTPResponse, logger log.Logger
 		return io.ReadAll(gReader)
 	}

-	return res.Body, nil
-}
+	// if the response is snappy compressed, decode it here
+	if strings.EqualFold(headers.Get("Content-Encoding"), "snappy") {
+		sReader := snappy.NewReader(bytes.NewBuffer(res.Body))
+		return io.ReadAll(sReader)
+	}
+
+	// if the response is zstd compressed, decode it here
+	if strings.EqualFold(headers.Get("Content-Encoding"), "zstd") {
+		zReader, err := zstd.NewReader(bytes.NewBuffer(res.Body))
+		if err != nil {
+			return nil, err
+		}
+		defer zReader.Close() // release the decoder; io.NopCloser made this a no-op

-func getResponseSize(res *http.Response, buf *bytes.Buffer) int {
-	if strings.EqualFold(res.Header.Get("Content-Encoding"), "gzip") && len(buf.Bytes()) >= 4 {
-		// GZIP body contains the size of the original (uncompressed) input data
-		// modulo 2^32 in the last 4 bytes (https://www.ietf.org/rfc/rfc1952.txt).
-		return int(binary.LittleEndian.Uint32(buf.Bytes()[len(buf.Bytes())-4:]))
+		return io.ReadAll(zReader)
 	}
-	return len(buf.Bytes())
+
+	return res.Body, nil
 }

 // UnmarshalJSON implements json.Unmarshaler.
@@ -755,9 +778,17 @@ func (s *PrometheusResponseStats) MarshalJSON() ([]byte, error) {
 }

 func SetRequestHeaders(h http.Header, defaultCodecType CodecType, compression Compression) {
-	if compression == GzipCompression {
+	switch compression {
+	case GzipCompression:
 		h.Set("Accept-Encoding", string(GzipCompression))
+
+	case SnappyCompression:
+		h.Set("Accept-Encoding", string(SnappyCompression))
+
+	case ZstdCompression:
+		h.Set("Accept-Encoding", string(ZstdCompression))
 	}
+
 	if defaultCodecType == ProtobufCodecType {
 		h.Set("Accept", ApplicationProtobuf+", "+ApplicationJson)
 	} else {
diff --git a/pkg/querier/tripperware/query_test.go b/pkg/querier/tripperware/query_test.go
index 04606df99e6..08f149f43b0 100644
--- a/pkg/querier/tripperware/query_test.go
+++ b/pkg/querier/tripperware/query_test.go
@@ -1,10 +1,7 @@
 package tripperware

 import (
-	"bytes"
-	"compress/gzip"
 	"math"
-	"net/http"
 	"strconv"
 	"testing"
 	"time"
@@ -196,50 +193,3 @@ func generateData(timeseries, datapoints int) (floatMatrix, histogramMatrix []*S
 	}
 	return
 }
-
-func Test_getResponseSize(t *testing.T) {
-	tests := []struct {
-		body    []byte
-		useGzip bool
-	}{
-		{
-			body:    []byte(`foo`),
-			useGzip: false,
-		},
-		{
-			body:    []byte(`foo`),
-			useGzip: true,
-		},
-		{
-			body:    []byte(`{"status":"success","data":{"resultType":"vector","result":[]}}`),
-			useGzip: false,
-		},
-		{
-			body:    []byte(`{"status":"success","data":{"resultType":"vector","result":[]}}`),
-			useGzip: true,
-		},
-	}
-
-	for i, test := range tests {
-		t.Run(strconv.Itoa(i), func(t *testing.T) {
-			expectedBodyLength := len(test.body)
-			buf := &bytes.Buffer{}
-			response := &http.Response{}
-
-			if test.useGzip {
-				response = &http.Response{
-					Header: http.Header{"Content-Encoding": []string{"gzip"}},
-				}
-				w := gzip.NewWriter(buf)
-				_, err := w.Write(test.body)
-				require.NoError(t, err)
-				w.Close()
-			} else {
-				buf = bytes.NewBuffer(test.body)
-			}
-
-			bodyLength := getResponseSize(response, buf)
-			require.Equal(t, expectedBodyLength, bodyLength)
-		})
-	}
-}
diff --git a/pkg/querier/tripperware/queryrange/query_range.go b/pkg/querier/tripperware/queryrange/query_range.go
index 9d82031fc0b..8ac81ef6e31 100644
--- a/pkg/querier/tripperware/queryrange/query_range.go
+++ b/pkg/querier/tripperware/queryrange/query_range.go
@@ -63,8 +63,15 @@ type prometheusCodec struct {

 func NewPrometheusCodec(sharded bool, compressionStr string, defaultCodecTypeStr string) *prometheusCodec { //nolint:revive
 	compression := tripperware.NonCompression // default
-	if compressionStr == string(tripperware.GzipCompression) {
+	switch compressionStr {
+	case string(tripperware.GzipCompression):
 		compression = tripperware.GzipCompression
+
+	case string(tripperware.SnappyCompression):
+		compression = tripperware.SnappyCompression
+
+	case string(tripperware.ZstdCompression):
+		compression = tripperware.ZstdCompression
 	}

 	defaultCodecType := tripperware.JsonCodecType // default
@@ -196,8 +203,26 @@ func (c prometheusCodec) DecodeResponse(ctx context.Context, r *http.Response, _
 		return nil, err
 	}

+	// X-Uncompressed-Length is optional: a response that lacks it, or carries
+	// a malformed or negative value, must not fail the query. In that case the
+	// limiter is charged with the decoded body size further down instead.
+	responseSize, sizeErr := strconv.Atoi(r.Header.Get("X-Uncompressed-Length"))
+	hasSizeHeader := sizeErr == nil && responseSize >= 0
+
 	responseSizeLimiter := limiter.ResponseSizeLimiterFromContextWithFallback(ctx)
-	body, err := tripperware.BodyBytes(r, responseSizeLimiter, log)
+	if hasSizeHeader {
+		// Reject an oversized response before spending memory on decoding it.
+		if err := responseSizeLimiter.AddResponseBytes(responseSize); err != nil {
+			return nil, httpgrpc.Errorf(http.StatusUnprocessableEntity, "%s", err.Error())
+		}
+	}
+
+	body, err := tripperware.BodyBytes(r, log)
+	if err == nil && !hasSizeHeader {
+		if limitErr := responseSizeLimiter.AddResponseBytes(len(body)); limitErr != nil {
+			return nil, httpgrpc.Errorf(http.StatusUnprocessableEntity, "%s", limitErr.Error())
+		}
+	}
 	if err != nil {
 		log.Error(err)
 		return nil, err