From 89cd7ed50e42eddb63d91457efc7c04affa5719e Mon Sep 17 00:00:00 2001 From: Mark Kittisopikul Date: Tue, 28 May 2024 23:14:01 -0400 Subject: [PATCH] Implement ZstdFrameCompressor via endOp (#52) * Implement ZstdFrameCompressor via endOp * Repeat calling compress! with same input until code == 0 with ZSTD_e_end * Adopt additional tests from #53 * Allocate an input buffer when using ZstdFrameCompressor * Simplify, remove buffer, just keep ibuffer pos and size same to complete frame * Reset input and output buffers of Cstream on initialize and finalize * Reset buffers on decompression --- src/CodecZstd.jl | 1 + src/compression.jl | 70 +++++++++++++++++++++++++++++++++++++----- src/decompression.jl | 4 +++ src/libzstd.jl | 18 +++++++++-- test/compress_endOp.jl | 12 ++++++++ test/runtests.jl | 38 +++++++++++++++++++++++ 6 files changed, 134 insertions(+), 9 deletions(-) diff --git a/src/CodecZstd.jl b/src/CodecZstd.jl index 315f5f2..dffbcc1 100644 --- a/src/CodecZstd.jl +++ b/src/CodecZstd.jl @@ -3,6 +3,7 @@ module CodecZstd export ZstdCompressor, ZstdCompressorStream, + ZstdFrameCompressor, ZstdDecompressor, ZstdDecompressorStream diff --git a/src/compression.jl b/src/compression.jl index 36b93a4..cabc3f9 100644 --- a/src/compression.jl +++ b/src/compression.jl @@ -4,10 +4,15 @@ struct ZstdCompressor <: TranscodingStreams.Codec cstream::CStream level::Int + endOp::LibZstd.ZSTD_EndDirective end function Base.show(io::IO, codec::ZstdCompressor) - print(io, summary(codec), "(level=$(codec.level))") + if codec.endOp == LibZstd.ZSTD_e_end + print(io, "ZstdFrameCompressor(level=$(codec.level))") + else + print(io, summary(codec), "(level=$(codec.level))") + end end # Same as the zstd command line tool (v1.2.0). 
@@ -28,6 +33,34 @@ function ZstdCompressor(;level::Integer=DEFAULT_COMPRESSION_LEVEL) end return ZstdCompressor(CStream(), level) end +ZstdCompressor(cstream, level) = ZstdCompressor(cstream, level, :continue) + +""" + ZstdFrameCompressor(;level=$(DEFAULT_COMPRESSION_LEVEL)) + +Create a new zstd compression codec that reads the available input and then +closes the frame, encoding the decompressed size of that frame. + +Arguments +--------- +- `level`: compression level (1..$(MAX_CLEVEL)) +""" +function ZstdFrameCompressor(;level::Integer=DEFAULT_COMPRESSION_LEVEL) + if !(1 ≤ level ≤ MAX_CLEVEL) + throw(ArgumentError("level must be within 1..$(MAX_CLEVEL)")) + end + return ZstdCompressor(CStream(), level, :end) +end +# pretend that ZstdFrameCompressor is a compressor type +function TranscodingStreams.transcode(C::typeof(ZstdFrameCompressor), args...) + codec = C() + initialize(codec) + try + return transcode(codec, args...) + finally + finalize(codec) + end +end const ZstdCompressorStream{S} = TranscodingStream{ZstdCompressor,S} where S<:IO @@ -50,6 +83,8 @@ function TranscodingStreams.initialize(codec::ZstdCompressor) if iserror(code) zstderror(codec.cstream, code) end + reset!(codec.cstream.ibuffer) + reset!(codec.cstream.obuffer) return end @@ -61,6 +96,8 @@ function TranscodingStreams.finalize(codec::ZstdCompressor) end codec.cstream.ptr = C_NULL end + reset!(codec.cstream.ibuffer) + reset!(codec.cstream.obuffer) return end @@ -75,21 +112,40 @@ end function TranscodingStreams.process(codec::ZstdCompressor, input::Memory, output::Memory, error::Error) cstream = codec.cstream - cstream.ibuffer.src = input.ptr - cstream.ibuffer.size = input.size - cstream.ibuffer.pos = 0 + ibuffer_starting_pos = UInt(0) + if codec.endOp == LibZstd.ZSTD_e_end && + cstream.ibuffer.size != cstream.ibuffer.pos + # While saving a frame, the prior process run did not finish writing the frame. + # A positive code indicates the need for additional output buffer space. 
+ # Re-run with the same cstream.ibuffer.size as pledged for the frame, + # otherwise a "Src size is incorrect" error will occur. + + # For the current frame, cstream.ibuffer.size - cstream.ibuffer.pos + # must reflect the remaining data. Thus neither size nor pos can change. + # Store the starting pos since it will be non-zero. + ibuffer_starting_pos = cstream.ibuffer.pos + + # Set the pointer relative to input.ptr such that + # cstream.ibuffer.src + cstream.ibuffer.pos == input.ptr + cstream.ibuffer.src = input.ptr - cstream.ibuffer.pos + else + cstream.ibuffer.src = input.ptr + cstream.ibuffer.size = input.size + cstream.ibuffer.pos = 0 + end cstream.obuffer.dst = output.ptr cstream.obuffer.size = output.size cstream.obuffer.pos = 0 if input.size == 0 code = finish!(cstream) else - code = compress!(cstream) + code = compress!(cstream; endOp = codec.endOp) end - Δin = Int(cstream.ibuffer.pos) + Δin = Int(cstream.ibuffer.pos - ibuffer_starting_pos) Δout = Int(cstream.obuffer.pos) if iserror(code) - error[] = ErrorException("zstd error") + ptr = LibZstd.ZSTD_getErrorName(code) + error[] = ErrorException("zstd error: " * unsafe_string(ptr)) return Δin, Δout, :error else return Δin, Δout, input.size == 0 && code == 0 ? 
:end : :ok diff --git a/src/decompression.jl b/src/decompression.jl index 6767634..765ce2c 100644 --- a/src/decompression.jl +++ b/src/decompression.jl @@ -38,6 +38,8 @@ function TranscodingStreams.initialize(codec::ZstdDecompressor) if iserror(code) zstderror(codec.dstream, code) end + reset!(codec.dstream.ibuffer) + reset!(codec.dstream.obuffer) return end @@ -49,6 +51,8 @@ function TranscodingStreams.finalize(codec::ZstdDecompressor) end codec.dstream.ptr = C_NULL end + reset!(codec.dstream.ibuffer) + reset!(codec.dstream.obuffer) return end diff --git a/src/libzstd.jl b/src/libzstd.jl index 9906b2b..79d021d 100644 --- a/src/libzstd.jl +++ b/src/libzstd.jl @@ -16,12 +16,26 @@ end const MAX_CLEVEL = max_clevel() +# InBuffer is the C struct ZSTD_inBuffer const InBuffer = LibZstd.ZSTD_inBuffer InBuffer() = InBuffer(C_NULL, 0, 0) Base.unsafe_convert(::Type{Ptr{InBuffer}}, buffer::InBuffer) = Ptr{InBuffer}(pointer_from_objref(buffer)) +function reset!(buf::InBuffer) + buf.src = C_NULL + buf.pos = 0 + buf.size = 0 +end + +# OutBuffer is the C struct ZSTD_outBuffer const OutBuffer = LibZstd.ZSTD_outBuffer OutBuffer() = OutBuffer(C_NULL, 0, 0) Base.unsafe_convert(::Type{Ptr{OutBuffer}}, buffer::OutBuffer) = Ptr{OutBuffer}(pointer_from_objref(buffer)) +function reset!(buf::OutBuffer) + buf.dst = C_NULL + buf.pos = 0 + buf.size = 0 +end + # ZSTD_CStream mutable struct CStream @@ -60,9 +74,9 @@ function reset!(cstream::CStream, srcsize::Integer) # explicitly specified. 
srcsize = ZSTD_CONTENTSIZE_UNKNOWN end + reset!(cstream.ibuffer) + reset!(cstream.obuffer) return LibZstd.ZSTD_CCtx_setPledgedSrcSize(cstream, srcsize) - #return ccall((:ZSTD_resetCStream, libzstd), Csize_t, (Ptr{Cvoid}, Culonglong), cstream.ptr, srcsize) - end """ diff --git a/test/compress_endOp.jl b/test/compress_endOp.jl index ad646f0..0594f1f 100644 --- a/test/compress_endOp.jl +++ b/test/compress_endOp.jl @@ -59,3 +59,15 @@ end Base.Libc.free(cstream.obuffer.dst) end end + +@testset "ZstdFrameCompressor" begin + data = rand(1:100, 1024*1024) + compressed = transcode(ZstdFrameCompressor, copy(reinterpret(UInt8, data))) + GC.@preserve compressed begin + @test CodecZstd.find_decompressed_size(pointer(compressed), sizeof(compressed)) == sizeof(data) + end + @test reinterpret(Int, transcode(ZstdDecompressor, compressed)) == data + iob = IOBuffer() + print(iob, ZstdFrameCompressor()) + @test startswith(String(take!(iob)), "ZstdFrameCompressor") +end diff --git a/test/runtests.jl b/test/runtests.jl index cdb1f64..7ca5875 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -44,5 +44,43 @@ Random.seed!(1234) TranscodingStreams.test_roundtrip_lines(ZstdCompressorStream, ZstdDecompressorStream) TranscodingStreams.test_roundtrip_transcode(ZstdCompressor, ZstdDecompressor) + frame_encoder = io -> TranscodingStream(ZstdFrameCompressor(), io) + TranscodingStreams.test_roundtrip_read(frame_encoder, ZstdDecompressorStream) + TranscodingStreams.test_roundtrip_write(frame_encoder, ZstdDecompressorStream) + TranscodingStreams.test_roundtrip_lines(frame_encoder, ZstdDecompressorStream) + TranscodingStreams.test_roundtrip_transcode(ZstdFrameCompressor, ZstdDecompressor) + + @testset "ZstdFrameCompressor streaming edge case" begin + codec = ZstdFrameCompressor() + TranscodingStreams.initialize(codec) + e = TranscodingStreams.Error() + r = TranscodingStreams.startproc(codec, :write, e) + @test r == :ok + # data buffers + data = rand(UInt8, 32*1024*1024) + buffer1 = copy(data) + 
buffer2 = zeros(UInt8, length(data)*2) + GC.@preserve buffer1 buffer2 begin + total_out = 0 + total_in = 0 + while total_in < length(data) || r != :end + in_size = min(length(buffer1) - total_in, 1024*1024) + out_size = min(length(buffer2) - total_out, 1024) + input = TranscodingStreams.Memory(pointer(buffer1, total_in + 1), UInt(in_size)) + output = TranscodingStreams.Memory(pointer(buffer2, total_out + 1), UInt(out_size)) + Δin, Δout, r = TranscodingStreams.process(codec, input, output, e) + if r == :error + throw(e[]) + end + total_out += Δout + total_in += Δin + end + @test r == :end + end + TranscodingStreams.finalize(codec) + resize!(buffer2, total_out) + @test transcode(ZstdDecompressor, buffer2) == data + end + include("compress_endOp.jl") end