Commit 355b851

divijvaidya authored and luben committed
Fix bug in ZstdBufferDecompressingStream when array backing ByteBuffer is shared
1 parent 709ffb5 commit 355b851

File tree

2 files changed: +17 -5 lines changed

src/main/java/com/github/luben/zstd/ZstdBufferDecompressingStreamNoFinalizer.java

Lines changed: 5 additions & 2 deletions
@@ -43,7 +43,7 @@ long initDStream(long stream) {
     }
 
     @Override
-    long decompressStream(long stream, ByteBuffer dst, int dstOffset, int dstSize, ByteBuffer src, int srcOffset, int srcSize) {
+    long decompressStream(long stream, ByteBuffer dst, int dstBufPos, int dstSize, ByteBuffer src, int srcBufPos, int srcSize) {
         if (!src.hasArray()) {
             throw new IllegalArgumentException("provided source ByteBuffer lacks array");
         }
@@ -53,7 +53,10 @@ long decompressStream(long stream, ByteBuffer dst, int dstOffset, int dstSize, B
         byte[] targetArr = dst.array();
         byte[] sourceArr = src.array();
 
-        return decompressStreamNative(stream, targetArr, dstOffset, dstSize, sourceArr, srcOffset, srcSize);
+        // We are interested in array data corresponding to the pos represented by the ByteBuffer view.
+        // A ByteBuffer may share an underlying array with other ByteBuffers. In such scenario, we need to adjust the
+        // index of the array by adding an offset using arrayOffset().
+        return decompressStreamNative(stream, targetArr, dstBufPos + dst.arrayOffset(), dstSize, sourceArr, srcBufPos + src.arrayOffset(), srcSize);
     }
 
     public static int recommendedTargetBufferSize() {
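
The root cause is easy to reproduce in isolation: a ByteBuffer created by slice() reports positions relative to its own view, while array() returns the entire shared backing array, so array indices must be shifted by arrayOffset(). A minimal sketch of that behavior (the class name is illustrative, not part of the repository):

import java.nio.ByteBuffer;

public class ArrayOffsetDemo {
    public static void main(String[] args) {
        ByteBuffer parent = ByteBuffer.allocate(16);
        parent.put(new byte[] {9, 9, 9, 9});  // bytes that are not part of the view below
        ByteBuffer view = parent.slice();     // shares parent's backing array, starting at index 4

        view.put((byte) 42);                  // writes at view position 0
        byte[] backing = view.array();        // the whole shared array, not just the view

        System.out.println(view.arrayOffset());          // 4: where the view starts in the array
        System.out.println(backing[0]);                  // 9: the wrong byte if the offset is ignored
        System.out.println(backing[view.arrayOffset()]); // 42: the byte the view actually wrote
    }
}

Before this commit, decompressStreamNative received the view-relative positions unadjusted, so any buffer with a non-zero arrayOffset() made the native code read from and write to the wrong region of the shared array.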

src/test/scala/Zstd.scala

Lines changed: 12 additions & 3 deletions
@@ -679,8 +679,15 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
       val channel = FileChannel.open(file.toPath, StandardOpenOption.READ)
       // write some garbage bytes at the beginning of buffer containing compressed data to prove that
       // this buffer's position doesn't have to start from 0.
-      val garbageBytes = "garbage bytes".getBytes(Charset.defaultCharset());
-      val readBuffer = ByteBuffer.allocate(channel.size().toInt + garbageBytes.length)
+      val garbageBytes = "garbage bytes".getBytes(Charset.defaultCharset())
+      // add some extra bytes to the underlying array of the ByteBuffer. The ByteBuffer view does not include these
+      // extra bytes. These are added to the underlying array to test for scenarios where the ByteBuffer view is a slice
+      // of the underlying array.
+      val extraBytes = "extra bytes".getBytes(Charset.defaultCharset())
+      // Create a read buffer with extraBytes, we will later carve a slice out of it to store the compressed data.
+      val bigReadBuffer = ByteBuffer.allocate(channel.size().toInt + garbageBytes.length + extraBytes.length)
+      bigReadBuffer.put(extraBytes)
+      val readBuffer = bigReadBuffer.slice()
       readBuffer.put(garbageBytes)
       channel.read(readBuffer)
       // set pos to 0 and limit to containing bytes
@@ -694,7 +701,9 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
       var pos = 0
       // write some garbage bytes at the beginning of buffer containing uncompressed data to prove that
       // this buffer's position doesn't have to start from 0.
-      val block = ByteBuffer.allocate(1 + garbageBytes.length)
+      val bigBlock = ByteBuffer.allocate(1 + garbageBytes.length + extraBytes.length)
+      bigBlock.put(extraBytes)
+      var block = bigBlock.slice()
       while (pos < length && zis.hasRemaining) {
         block.clear
         block.put(garbageBytes)
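
For consistency with the fix above, here is the scenario the test constructs, sketched in Java against zstd-jni's public one-shot and streaming APIs (the payload, padding size, and class name are illustrative assumptions):

import com.github.luben.zstd.Zstd;
import com.github.luben.zstd.ZstdBufferDecompressingStream;
import java.nio.ByteBuffer;

public class SharedArraySliceDemo {
    public static void main(String[] args) throws Exception {
        byte[] compressed = Zstd.compress("hello".getBytes("UTF-8"), 3);

        // Park the compressed bytes inside a slice, so the source buffer's
        // arrayOffset() is non-zero -- exactly the case the commit fixes.
        ByteBuffer big = ByteBuffer.allocate(8 + compressed.length);
        big.put(new byte[8]);            // extra bytes outside the view
        ByteBuffer source = big.slice(); // shares big's backing array
        source.put(compressed);
        source.flip();

        ByteBuffer target = ByteBuffer.allocate(128);
        ZstdBufferDecompressingStream zis = new ZstdBufferDecompressingStream(source);
        while (zis.hasRemaining()) {
            zis.read(target);
        }
        zis.close();

        target.flip();
        byte[] out = new byte[target.remaining()];
        target.get(out);
        System.out.println(new String(out, "UTF-8")); // "hello" with the fix applied
    }
}

Without the arrayOffset() adjustment, the native decompressor would be handed the eight padding bytes at the start of big's backing array instead of the zstd frame, so the read would fail or return corrupt data; with this commit the sliced buffer decompresses correctly.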
