Skip to content

Commit

Permalink
Add defensive code to handle a corrupted bitstream
Browse files Browse the repository at this point in the history
  • Loading branch information
flanglet committed Jun 23, 2024
1 parent 5bbc6a1 commit 4bc9413
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 53 deletions.
25 changes: 16 additions & 9 deletions v2/entropy/BinaryEntropyCodec.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,17 @@ func (this *BinaryEntropyEncoder) Write(block []byte) (int, error) {
length = 64
}

bufSize := length + (length >> 3)

if len(this.buffer) < bufSize {
this.buffer = make([]byte, bufSize)
}

// Split block into chunks, read bit array from bitstream and decode chunk
for startChunk < end {
chunkSize := min(length, end-startChunk)

if len(this.buffer) < (chunkSize + (chunkSize >> 3)) {
this.buffer = make([]byte, chunkSize+(chunkSize>>3))
}

this.index = 0
buf := block[startChunk : startChunk+chunkSize]
this.index = 0

for i := range buf {
this.EncodeByte(buf[i])
Expand Down Expand Up @@ -286,15 +287,21 @@ func (this *BinaryEntropyDecoder) Read(block []byte) (int, error) {
length = 64
}

bufSize := length + (length >> 3)

if len(this.buffer) < bufSize {
this.buffer = make([]byte, bufSize)
}

// Split block into chunks, read bit array from bitstream and decode chunk
for startChunk < end {
chunkSize := min(length, end-startChunk)
szBytes := ReadVarInt(this.bitstream)

if len(this.buffer) < chunkSize+(chunkSize>>3) {
this.buffer = make([]byte, chunkSize+(chunkSize>>3))
if szBytes > uint32(bufSize) {
return startChunk, errors.New("Binary entropy codec: Invalid bitstream")
}

szBytes := ReadVarInt(this.bitstream)
this.current = this.bitstream.ReadBits(56)

if szBytes != 0 {
Expand Down
18 changes: 12 additions & 6 deletions v2/entropy/HuffmanCodec.go
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,11 @@ func (this *HuffmanDecoder) readLengths() (int, error) {

// max(CodeLen) must be <= _HUF_MAX_SYMBOL_SIZE
func (this *HuffmanDecoder) buildDecodingTable(count int) bool {
// Initialize table with non zero value.
// If the bitstream is altered, the decoder may access these default table values.
// The number of consumed bits cannot be 0.
for i := range this.table {
this.table[i] = 0
this.table[i] = 8
}

length := 0
Expand All @@ -637,19 +640,17 @@ func (this *HuffmanDecoder) buildDecodingTable(count int) bool {
length = int(this.sizes[s])
}

// code -> size, symbol
val := (uint16(s) << 8) | uint16(this.sizes[s])
code := this.codes[s]

// All DECODING_BATCH_SIZE bit values read from the bit stream and
// starting with the same prefix point to symbol s
idx := code << (shift - length)
idx := this.codes[s] << (shift - length)
end := idx + (1 << (shift - length))

if int(end) > len(this.table) {
return false
}

// code -> size, symbol
val := (uint16(s) << 8) | uint16(this.sizes[s])
t := this.table[idx:end]

for j := range t {
Expand Down Expand Up @@ -831,6 +832,11 @@ func (this *HuffmanDecoder) Read(block []byte) (int, error) {
bits += 8
}

// Sanity check
if bits > 64 {
return n, errors.New("Invalid bitstream: incorrect symbol size")
}

var val uint16

if bits >= _HUF_MAX_SYMBOL_SIZE_V4 {
Expand Down
56 changes: 41 additions & 15 deletions v2/transform/EXECodec.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,23 +109,23 @@ func (this *EXECodec) Forward(src, dst []byte) (uint, uint, error) {
count := len(src)

if count < _EXE_MIN_BLOCK_SIZE {
return 0, 0, fmt.Errorf("Block too small - size: %d, min %d)", count, _EXE_MIN_BLOCK_SIZE)
return 0, 0, fmt.Errorf("ExeCodec forward failed: Block too small - size: %d, min %d)", count, _EXE_MIN_BLOCK_SIZE)
}

if count > _EXE_MAX_BLOCK_SIZE {
return 0, 0, fmt.Errorf("Block too big - size: %d, max %d", count, _EXE_MAX_BLOCK_SIZE)
return 0, 0, fmt.Errorf("ExeCodec forward failed: Block too big - size: %d, max %d", count, _EXE_MAX_BLOCK_SIZE)
}

if n := this.MaxEncodedLen(count); len(dst) < n {
return 0, 0, fmt.Errorf("Output buffer too small - size: %d, required %d", len(dst), n)
return 0, 0, fmt.Errorf("ExeCodec forward failed: Output buffer too small - size: %d, required %d", len(dst), n)
}

if this.ctx != nil {
if val, containsKey := (*this.ctx)["dataType"]; containsKey {
dt := val.(internal.DataType)

if dt != internal.DT_UNDEFINED && dt != internal.DT_EXE && dt != internal.DT_BIN {
return 0, 0, fmt.Errorf("Input is not an executable, skip")
return 0, 0, fmt.Errorf("ExeCodec forward failed: Input is not an executable, skip")
}
}
}
Expand All @@ -139,7 +139,7 @@ func (this *EXECodec) Forward(src, dst []byte) (uint, uint, error) {
(*this.ctx)["dataType"] = internal.DataType(mode & _EXE_MASK_DT)
}

return 0, 0, fmt.Errorf("Input is not an executable, skip")
return 0, 0, fmt.Errorf("ExeCodec forward failed: Input is not an executable, skip")
}

mode &= ^byte(_EXE_MASK_DT)
Expand All @@ -156,7 +156,7 @@ func (this *EXECodec) Forward(src, dst []byte) (uint, uint, error) {
return this.forwardARM(src, dst, codeStart, codeEnd)
}

return 0, 0, fmt.Errorf("Input is not a supported executable format, skip")
return 0, 0, fmt.Errorf("ExeCodec forward failed: Input is not a supported executable format, skip")
}

func (this *EXECodec) forwardX86(src, dst []byte, codeStart, codeEnd int) (uint, uint, error) {
Expand All @@ -167,6 +167,10 @@ func (this *EXECodec) forwardX86(src, dst []byte, codeStart, codeEnd int) (uint,
dst[0] = _EXE_X86
matches = 0

if codeStart > len(src) || codeEnd > len(src) {
return 0, 0, fmt.Errorf("ExeCodec forward failed: Input is not a supported executable format")
}

if codeStart > 0 {
copy(dst[dstIdx:], src[0:codeStart])
dstIdx += codeStart
Expand Down Expand Up @@ -232,14 +236,14 @@ func (this *EXECodec) forwardX86(src, dst []byte, codeStart, codeEnd int) (uint,
}

if matches < 16 {
return uint(srcIdx), uint(dstIdx), errors.New("Too few calls/jumps, skip")
return uint(srcIdx), uint(dstIdx), errors.New("ExeCodec forward: Too few calls/jumps, skip")
}

count := len(src)

// Cap expansion due to false positives
if srcIdx < codeEnd || dstIdx+(count-srcIdx) > dstEnd {
return uint(srcIdx), uint(dstIdx), errors.New("Too many false positives, skip")
return uint(srcIdx), uint(dstIdx), errors.New("ExeCodec forward: Too many false positives, skip")
}

binary.LittleEndian.PutUint32(dst[1:], uint32(codeStart))
Expand Down Expand Up @@ -272,7 +276,7 @@ func (this *EXECodec) Inverse(src, dst []byte) (uint, uint, error) {
return this.inverseARM(src, dst)
}

return 0, 0, errors.New("Invalid data: unknown binary type")
return 0, 0, errors.New("ExeCodec inverse failed: unknown binary type")
}

func (this *EXECodec) inverseX86(src, dst []byte) (uint, uint, error) {
Expand All @@ -281,6 +285,11 @@ func (this *EXECodec) inverseX86(src, dst []byte) (uint, uint, error) {
codeStart := int(binary.LittleEndian.Uint32(src[1:]))
codeEnd := int(binary.LittleEndian.Uint32(src[5:]))

// Sanity check
if codeStart+srcIdx > len(src) || codeStart+dstIdx > len(dst) || codeEnd > len(src) {
return 0, 0, errors.New("ExeCodec inverse failed: invalid bitstream")
}

if codeStart > 0 {
copy(dst[dstIdx:], src[srcIdx:srcIdx+codeStart])
dstIdx += codeStart
Expand Down Expand Up @@ -334,8 +343,12 @@ func (this *EXECodec) inverseX86(src, dst []byte) (uint, uint, error) {
}

count := len(src)
copy(dst[dstIdx:], src[srcIdx:count])
dstIdx += (count - srcIdx)

if srcIdx < count {
copy(dst[dstIdx:], src[srcIdx:count])
dstIdx += (count - srcIdx)
}

return uint(count), uint(dstIdx), nil
}

Expand Down Expand Up @@ -399,6 +412,10 @@ func (this *EXECodec) forwardARM(src, dst []byte, codeStart, codeEnd int) (uint,
dst[0] = _EXE_ARM64
matches = 0

if codeStart > len(src) || codeEnd > len(src) {
return 0, 0, fmt.Errorf("ExeCodec forward failed: Input is not a supported executable format")
}

if codeStart > 0 {
copy(dst[dstIdx:], src[0:codeStart])
dstIdx += codeStart
Expand Down Expand Up @@ -472,14 +489,14 @@ func (this *EXECodec) forwardARM(src, dst []byte, codeStart, codeEnd int) (uint,
}

if matches < 16 {
return uint(srcIdx), uint(dstIdx), errors.New("Too few calls/jumps, skip")
return uint(srcIdx), uint(dstIdx), errors.New("ExeCodec forward: Too few calls/jumps, skip")
}

count := len(src)

// Cap expansion due to false positives
if srcIdx < codeEnd || dstIdx+(count-srcIdx) > dstEnd {
return uint(srcIdx), uint(dstIdx), errors.New("Too many false positives, skip")
return uint(srcIdx), uint(dstIdx), errors.New("ExeCodec forward: Too many false positives, skip")
}

binary.LittleEndian.PutUint32(dst[1:], uint32(codeStart))
Expand All @@ -495,6 +512,11 @@ func (this *EXECodec) inverseARM(src, dst []byte) (uint, uint, error) {
codeStart := int(binary.LittleEndian.Uint32(src[1:]))
codeEnd := int(binary.LittleEndian.Uint32(src[5:]))

// Sanity check
if codeStart+srcIdx > len(src) || codeStart+dstIdx > len(dst) || codeEnd > len(src) {
return 0, 0, errors.New("ExeCodec inverse failed: invalid bitstream")
}

if codeStart > 0 {
copy(dst[dstIdx:], src[srcIdx:srcIdx+codeStart])
dstIdx += codeStart
Expand Down Expand Up @@ -544,8 +566,12 @@ func (this *EXECodec) inverseARM(src, dst []byte) (uint, uint, error) {
}

count := len(src)
copy(dst[dstIdx:], src[srcIdx:count])
dstIdx += (count - srcIdx)

if srcIdx < count {
copy(dst[dstIdx:], src[srcIdx:count])
dstIdx += (count - srcIdx)
}

return uint(count), uint(dstIdx), nil
}

Expand Down
12 changes: 5 additions & 7 deletions v2/transform/LZCodec.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,7 @@ func readLengthLZ(block []byte) (int, int) {
}

func emitLiteralsLZ(src, dst []byte) {
for i := 0; i < len(src); i += 8 {
copy(dst[i:], src[i:i+8])
}
copy(dst, src)
}

func (this *LZXCodec) hash(p []byte) uint32 {
Expand Down Expand Up @@ -733,17 +731,17 @@ func (this *LZXCodec) inverseV3(src, dst []byte) (uint, uint, error) {
mIdx := int(binary.LittleEndian.Uint32(src[4:]))
mLenIdx := int(binary.LittleEndian.Uint32(src[8:]))

// Sanity checks
if (tkIdx < 0) || (mIdx < 0) || (mLenIdx < 0) {
return 0, 0, errors.New("LZCodec: inverse transform failed, invalid data")
}

mIdx += tkIdx
mLenIdx += mIdx

if (tkIdx > count) || (mIdx > count) || (mLenIdx > count) {
if (tkIdx > count) || (mIdx > count-tkIdx) || (mLenIdx > count-tkIdx-mIdx) {
return 0, 0, errors.New("LZCodec: inverse transform failed, invalid data")
}

mIdx += tkIdx
mLenIdx += mIdx
srcEnd := tkIdx - 13
dstEnd := len(dst) - 16
maxDist := _LZX_MAX_DISTANCE2
Expand Down
Loading

0 comments on commit 4bc9413

Please sign in to comment.