From f262a2cee6daaf73acd9a0f218f30ea5156b8259 Mon Sep 17 00:00:00 2001 From: BenWhitehead Date: Mon, 4 Aug 2025 17:22:03 -0400 Subject: [PATCH] chore: make MinFlushBufferedWritableByteChannel capable of being non-blocking Default is still blocking, but non-blocking can be chosen now. --- .../MinFlushBufferedWritableByteChannel.java | 57 +++++--- ...nFlushBufferedWritableByteChannelTest.java | 127 ++++++++++++++++++ 2 files changed, 168 insertions(+), 16 deletions(-) diff --git a/google-cloud-storage/src/main/java/com/google/cloud/storage/MinFlushBufferedWritableByteChannel.java b/google-cloud-storage/src/main/java/com/google/cloud/storage/MinFlushBufferedWritableByteChannel.java index 6de770441..30e8206ea 100644 --- a/google-cloud-storage/src/main/java/com/google/cloud/storage/MinFlushBufferedWritableByteChannel.java +++ b/google-cloud-storage/src/main/java/com/google/cloud/storage/MinFlushBufferedWritableByteChannel.java @@ -16,6 +16,8 @@ package com.google.cloud.storage; +import static com.google.common.base.Preconditions.checkState; + import com.google.cloud.storage.BufferedWritableByteChannelSession.BufferedWritableByteChannel; import com.google.cloud.storage.UnbufferedWritableByteChannelSession.UnbufferedWritableByteChannel; import java.io.IOException; @@ -55,10 +57,17 @@ final class MinFlushBufferedWritableByteChannel implements BufferedWritableByteC private final BufferHandle handle; private final UnbufferedWritableByteChannel channel; + private final boolean blocking; MinFlushBufferedWritableByteChannel(BufferHandle handle, UnbufferedWritableByteChannel channel) { + this(handle, channel, true); + } + + MinFlushBufferedWritableByteChannel( + BufferHandle handle, UnbufferedWritableByteChannel channel, boolean blocking) { this.handle = handle; this.channel = channel; + this.blocking = blocking; } @Override @@ -81,27 +90,43 @@ public int write(ByteBuffer src) throws IOException { } int capacity = handle.capacity(); + int position = handle.position(); int bufferPending = capacity - bufferRemaining; int totalPending = Math.addExact(srcRemaining, bufferPending); - if (totalPending >= capacity) { - ByteBuffer[] srcs; - if (enqueuedBytes()) { - ByteBuffer buffer = handle.get(); - Buffers.flip(buffer); - srcs = new ByteBuffer[] {buffer, src}; - } else { - srcs = new ByteBuffer[] {src}; - } - long write = channel.write(srcs); - if (enqueuedBytes()) { - // we didn't write enough bytes to consume the whole buffer. - Buffers.compact(handle.get()); - } else if (handle.position() == handle.capacity()) { + ByteBuffer[] srcs; + boolean usingBuffer = false; + if (enqueuedBytes()) { + usingBuffer = true; + ByteBuffer buffer = handle.get(); + Buffers.flip(buffer); + srcs = new ByteBuffer[] {buffer, src}; + } else { + srcs = new ByteBuffer[] {src}; + } + long written = channel.write(srcs); + checkState(written >= 0, "written >= 0 (%s > 0)", written); + if (usingBuffer) { + if (written >= bufferPending) { // we wrote enough to consume the buffer Buffers.clear(handle.get()); + } else if (written > 0) { + // we didn't write enough bytes to consume the whole buffer. + Buffers.compact(handle.get()); + } else /*if (written == 0)*/ { + // if none of the buffer was consumed, flip it back so we retain all bytes + Buffers.position(handle.get(), position); + Buffers.limit(handle.get(), capacity); } - int srcConsumed = Math.toIntExact(write) - bufferPending; - bytesConsumed += srcConsumed; + } + + int srcConsumed = Math.max(0, Math.toIntExact(written) - bufferPending); + bytesConsumed += srcConsumed; + + if (!blocking && written != totalPending) { + // we're configured in non-blocking mode, and we weren't able to make any progress on our + // call, break out to allow more bytes to be written to us or to allow underlying space + // to clear. + break; } } return bytesConsumed; diff --git a/google-cloud-storage/src/test/java/com/google/cloud/storage/MinFlushBufferedWritableByteChannelTest.java b/google-cloud-storage/src/test/java/com/google/cloud/storage/MinFlushBufferedWritableByteChannelTest.java index 14d146e53..513aa2859 100644 --- a/google-cloud-storage/src/test/java/com/google/cloud/storage/MinFlushBufferedWritableByteChannelTest.java +++ b/google-cloud-storage/src/test/java/com/google/cloud/storage/MinFlushBufferedWritableByteChannelTest.java @@ -27,6 +27,7 @@ import com.google.cloud.storage.DefaultBufferedWritableByteChannelTest.AuditingBufferHandle; import com.google.cloud.storage.DefaultBufferedWritableByteChannelTest.CountingWritableByteChannelAdapter; import com.google.cloud.storage.UnbufferedWritableByteChannelSession.UnbufferedWritableByteChannel; +import com.google.cloud.storage.it.ChecksummedTestContent; import com.google.common.collect.ImmutableList; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -52,14 +53,93 @@ import net.jqwik.api.Provide; import net.jqwik.api.providers.TypeUsage; import org.checkerframework.checker.nullness.qual.NonNull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.slf4j.Marker; +import org.slf4j.MarkerFactory; public final class MinFlushBufferedWritableByteChannelTest { + private static final Logger LOGGER = + LoggerFactory.getLogger(MinFlushBufferedWritableByteChannelTest.class); + private static final Marker TRACE_ENTER = MarkerFactory.getMarker("enter"); + private static final Marker TRACE_EXIT = MarkerFactory.getMarker("exit"); @Example void edgeCases() { JqwikTest.report(TypeUsage.of(WriteOps.class), arbitraryWriteOps()); } + @Example + void nonBlockingWrite0DoesNotBlock() throws IOException { + BufferHandle handle = BufferHandle.allocate(5); + MinFlushBufferedWritableByteChannel c = + new MinFlushBufferedWritableByteChannel(handle, new OnlyConsumeNBytes(0, 1), false); + + ChecksummedTestContent all = ChecksummedTestContent.gen(11); + ByteBuffer s_0_4 = ByteBuffer.wrap(all.slice(0, 4).getBytes()); + ByteBuffer s_4_4 = ByteBuffer.wrap(all.slice(0, 4).getBytes()); + ByteBuffer s_8_3 = ByteBuffer.wrap(all.slice(0, 3).getBytes()); + int written1 = c.write(s_0_4); + assertThat(written1).isEqualTo(4); + assertThat(s_0_4.remaining()).isEqualTo(0); + + int written2 = c.write(s_4_4); + assertThat(written2).isEqualTo(0); + assertThat(s_4_4.remaining()).isEqualTo(4); + + int written3 = c.write(s_8_3); + assertThat(written3).isEqualTo(0); + assertThat(s_8_3.remaining()).isEqualTo(3); + + assertThat(handle.remaining()).isEqualTo(1); + } + + @Example + void nonBlockingWritePartialDoesNotBlock() throws IOException { + BufferHandle handle = BufferHandle.allocate(5); + MinFlushBufferedWritableByteChannel c = + new MinFlushBufferedWritableByteChannel(handle, new OnlyConsumeNBytes(6, 5), false); + + ChecksummedTestContent all = ChecksummedTestContent.gen(11); + ByteBuffer s_0_4 = ByteBuffer.wrap(all.slice(0, 4).getBytes()); + ByteBuffer s_4_4 = ByteBuffer.wrap(all.slice(0, 4).getBytes()); + int written1 = c.write(s_0_4); + assertThat(written1).isEqualTo(4); + assertThat(s_0_4.remaining()).isEqualTo(0); + assertThat(handle.remaining()).isEqualTo(1); + + int written2 = c.write(s_4_4); + assertThat(written2).isEqualTo(1); + assertThat(s_4_4.remaining()).isEqualTo(3); + assertThat(handle.remaining()).isEqualTo(5); + } + + @Example + void illegalStateExceptionIfWrittenLt0() throws IOException { + BufferHandle handle = BufferHandle.allocate(4); + MinFlushBufferedWritableByteChannel c = + new MinFlushBufferedWritableByteChannel( + handle, + new UnbufferedWritableByteChannel() { + @Override + public long write(ByteBuffer[] srcs, int offset, int length) { + return -1; + } + + @Override + public boolean isOpen() { + return true; + } + + @Override + public void close() {} + }); + + ChecksummedTestContent all = ChecksummedTestContent.gen(11); + ByteBuffer s_0_4 = ByteBuffer.wrap(all.slice(0, 4).getBytes()); + assertThrows(IllegalStateException.class, () -> c.write(s_0_4)); + } + @Property void bufferingEagerlyFlushesWhenFull(@ForAll("WriteOps") WriteOps writeOps) throws IOException { ByteBuffer buffer = ByteBuffer.allocate(writeOps.bufferSize); @@ -580,4 +660,51 @@ static WriteOps of(int numBytes, int bufferSize, int writeSize) { dbgExpectedWriteSizes); } } + + private static final class OnlyConsumeNBytes implements UnbufferedWritableByteChannel { + private static final Logger LOGGER = LoggerFactory.getLogger(OnlyConsumeNBytes.class); + private final long bytesToConsume; + private final int consumptionIncrement; + private long bytesConsumed; + + private OnlyConsumeNBytes(int bytesToConsume, int consumptionIncrement) { + this.bytesToConsume = bytesToConsume; + this.consumptionIncrement = consumptionIncrement; + this.bytesConsumed = 0; + } + + @Override + public long write(ByteBuffer[] srcs, int offset, int length) { + LOGGER.info(TRACE_ENTER, "write(srcs : {}, offset : {}, length : {})", srcs, offset, length); + try { + if (bytesConsumed >= bytesToConsume) { + return 0; + } + + long consumed = 0; + int toConsume = consumptionIncrement; + for (int i = offset; i < length && toConsume > 0; i++) { + ByteBuffer src = srcs[i]; + int remaining = src.remaining(); + int position = src.position(); + int consumable = Math.min(toConsume, remaining); + toConsume -= consumable; + consumed += consumable; + src.position(position + consumable); + } + bytesConsumed += consumed; + return consumed; + } finally { + LOGGER.info(TRACE_EXIT, "write(srcs : {}, offset : {}, length : {})", srcs, offset, length); + } + } + + @Override + public boolean isOpen() { + return true; + } + + @Override + public void close() {} + } }