diff --git a/utils/compression/compressor_test.go b/utils/compression/compressor_test.go index 5e3aa6ad078a..0a83295b0dbe 100644 --- a/utils/compression/compressor_test.go +++ b/utils/compression/compressor_test.go @@ -9,6 +9,7 @@ import ( "runtime" "testing" + "github.com/DataDog/zstd" "github.com/stretchr/testify/require" _ "embed" @@ -149,6 +150,19 @@ func TestNewCompressorWithInvalidLimit(t *testing.T) { } } +func TestNewZstdCompressorWithLevel(t *testing.T) { + compressor, err := NewZstdCompressorWithLevel(maxMessageSize, zstd.BestSpeed) + require.NoError(t, err) + + data := utils.RandomBytes(4096) + compressed, err := compressor.Compress(data) + require.NoError(t, err) + + decompressed, err := compressor.Decompress(compressed) + require.NoError(t, err) + require.Equal(t, data, decompressed) +} + func FuzzZstdCompressor(f *testing.F) { fuzzHelper(f, TypeZstd) } diff --git a/utils/compression/zstd_compressor.go b/utils/compression/zstd_compressor.go index b23336bee9ec..cd17f2c69288 100644 --- a/utils/compression/zstd_compressor.go +++ b/utils/compression/zstd_compressor.go @@ -22,6 +22,10 @@ var ( ) func NewZstdCompressor(maxSize int64) (Compressor, error) { + return NewZstdCompressorWithLevel(maxSize, zstd.DefaultCompression) +} + +func NewZstdCompressorWithLevel(maxSize int64, level int) (Compressor, error) { if maxSize == math.MaxInt64 { // "Decompress" creates "io.LimitReader" with max size + 1: // if the max size + 1 overflows, "io.LimitReader" reads nothing @@ -29,21 +33,22 @@ func NewZstdCompressor(maxSize int64) (Compressor, error) { // require max size < math.MaxInt64 to prevent int64 overflows return nil, ErrInvalidMaxSizeCompressor } - return &zstdCompressor{ maxSize: maxSize, + level: level, }, nil } type zstdCompressor struct { maxSize int64 + level int } func (z *zstdCompressor) Compress(msg []byte) ([]byte, error) { if int64(len(msg)) > z.maxSize { return nil, fmt.Errorf("%w: (%d) > (%d)", ErrMsgTooLarge, len(msg), z.maxSize) } - return zstd.Compress(nil, msg) + return zstd.CompressLevel(nil, msg, z.level) } func (z *zstdCompressor) Decompress(msg []byte) ([]byte, error) {