Skip to content

Commit e3938f0

Browse files
committed
chore: make MinFlushBufferedWritableByteChannel capable of being non-blocking
Default is still blocking, but non-blocking can be chosen now.
1 parent 90bddd1 commit e3938f0

File tree

2 files changed

+168
-16
lines changed

2 files changed

+168
-16
lines changed

google-cloud-storage/src/main/java/com/google/cloud/storage/MinFlushBufferedWritableByteChannel.java

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package com.google.cloud.storage;
1818

19+
import static com.google.common.base.Preconditions.checkState;
20+
1921
import com.google.cloud.storage.BufferedWritableByteChannelSession.BufferedWritableByteChannel;
2022
import com.google.cloud.storage.UnbufferedWritableByteChannelSession.UnbufferedWritableByteChannel;
2123
import java.io.IOException;
@@ -55,10 +57,17 @@ final class MinFlushBufferedWritableByteChannel implements BufferedWritableByteC
5557
private final BufferHandle handle;
5658

5759
private final UnbufferedWritableByteChannel channel;
60+
private final boolean blocking;
5861

5962
MinFlushBufferedWritableByteChannel(BufferHandle handle, UnbufferedWritableByteChannel channel) {
63+
this(handle, channel, true);
64+
}
65+
66+
MinFlushBufferedWritableByteChannel(
67+
BufferHandle handle, UnbufferedWritableByteChannel channel, boolean blocking) {
6068
this.handle = handle;
6169
this.channel = channel;
70+
this.blocking = blocking;
6271
}
6372

6473
@Override
@@ -81,27 +90,43 @@ public int write(ByteBuffer src) throws IOException {
8190
}
8291

8392
int capacity = handle.capacity();
93+
int position = handle.position();
8494
int bufferPending = capacity - bufferRemaining;
8595
int totalPending = Math.addExact(srcRemaining, bufferPending);
86-
if (totalPending >= capacity) {
87-
ByteBuffer[] srcs;
88-
if (enqueuedBytes()) {
89-
ByteBuffer buffer = handle.get();
90-
Buffers.flip(buffer);
91-
srcs = new ByteBuffer[] {buffer, src};
92-
} else {
93-
srcs = new ByteBuffer[] {src};
94-
}
95-
long write = channel.write(srcs);
96-
if (enqueuedBytes()) {
97-
// we didn't write enough bytes to consume the whole buffer.
98-
Buffers.compact(handle.get());
99-
} else if (handle.position() == handle.capacity()) {
96+
ByteBuffer[] srcs;
97+
boolean usingBuffer = false;
98+
if (enqueuedBytes()) {
99+
usingBuffer = true;
100+
ByteBuffer buffer = handle.get();
101+
Buffers.flip(buffer);
102+
srcs = new ByteBuffer[] {buffer, src};
103+
} else {
104+
srcs = new ByteBuffer[] {src};
105+
}
106+
long written = channel.write(srcs);
107+
checkState(written >= 0, "written >= 0 (%s > 0)", written);
108+
if (usingBuffer) {
109+
if (written >= bufferPending) {
100110
// we wrote enough to consume the buffer
101111
Buffers.clear(handle.get());
112+
} else if (written > 0) {
113+
// we didn't write enough bytes to consume the whole buffer.
114+
Buffers.compact(handle.get());
115+
} else /*if (written == 0)*/ {
116+
// if none of the buffer was consumed, flip it back so we retain all bytes
117+
Buffers.position(handle.get(), position);
118+
Buffers.limit(handle.get(), capacity);
102119
}
103-
int srcConsumed = Math.toIntExact(write) - bufferPending;
104-
bytesConsumed += srcConsumed;
120+
}
121+
122+
int srcConsumed = Math.max(0, Math.toIntExact(written) - bufferPending);
123+
bytesConsumed += srcConsumed;
124+
125+
if (!blocking && written != totalPending) {
126+
// we're configured in non-blocking mode, and we weren't able to make any progress on our
127+
// call, break out to allow more bytes to be written to us or to allow underlying space
128+
// to clear.
129+
break;
105130
}
106131
}
107132
return bytesConsumed;

google-cloud-storage/src/test/java/com/google/cloud/storage/MinFlushBufferedWritableByteChannelTest.java

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import com.google.cloud.storage.DefaultBufferedWritableByteChannelTest.AuditingBufferHandle;
2828
import com.google.cloud.storage.DefaultBufferedWritableByteChannelTest.CountingWritableByteChannelAdapter;
2929
import com.google.cloud.storage.UnbufferedWritableByteChannelSession.UnbufferedWritableByteChannel;
30+
import com.google.cloud.storage.it.ChecksummedTestContent;
3031
import com.google.common.collect.ImmutableList;
3132
import java.io.ByteArrayOutputStream;
3233
import java.io.IOException;
@@ -52,14 +53,93 @@
5253
import net.jqwik.api.Provide;
5354
import net.jqwik.api.providers.TypeUsage;
5455
import org.checkerframework.checker.nullness.qual.NonNull;
56+
import org.slf4j.Logger;
57+
import org.slf4j.LoggerFactory;
58+
import org.slf4j.Marker;
59+
import org.slf4j.MarkerFactory;
5560

5661
public final class MinFlushBufferedWritableByteChannelTest {
62+
private static final Logger LOGGER =
63+
LoggerFactory.getLogger(MinFlushBufferedWritableByteChannelTest.class);
64+
private static final Marker TRACE_ENTER = MarkerFactory.getMarker("enter");
65+
private static final Marker TRACE_EXIT = MarkerFactory.getMarker("exit");
5766

5867
@Example
5968
void edgeCases() {
6069
JqwikTest.report(TypeUsage.of(WriteOps.class), arbitraryWriteOps());
6170
}
6271

72+
@Example
73+
void nonBlockingWrite0DoesNotBlock() throws IOException {
74+
BufferHandle handle = BufferHandle.allocate(5);
75+
MinFlushBufferedWritableByteChannel c =
76+
new MinFlushBufferedWritableByteChannel(handle, new OnlyConsumeNBytes(0, 1), false);
77+
78+
ChecksummedTestContent all = ChecksummedTestContent.gen(11);
79+
ByteBuffer s_0_4 = ByteBuffer.wrap(all.slice(0, 4).getBytes());
80+
ByteBuffer s_4_4 = ByteBuffer.wrap(all.slice(0, 4).getBytes());
81+
ByteBuffer s_8_3 = ByteBuffer.wrap(all.slice(0, 3).getBytes());
82+
int written1 = c.write(s_0_4);
83+
assertThat(written1).isEqualTo(4);
84+
assertThat(s_0_4.remaining()).isEqualTo(0);
85+
86+
int written2 = c.write(s_4_4);
87+
assertThat(written2).isEqualTo(0);
88+
assertThat(s_4_4.remaining()).isEqualTo(4);
89+
90+
int written3 = c.write(s_8_3);
91+
assertThat(written3).isEqualTo(0);
92+
assertThat(s_8_3.remaining()).isEqualTo(3);
93+
94+
assertThat(handle.remaining()).isEqualTo(1);
95+
}
96+
97+
@Example
98+
void nonBlockingWritePartialDoesNotBlock() throws IOException {
99+
BufferHandle handle = BufferHandle.allocate(5);
100+
MinFlushBufferedWritableByteChannel c =
101+
new MinFlushBufferedWritableByteChannel(handle, new OnlyConsumeNBytes(6, 5), false);
102+
103+
ChecksummedTestContent all = ChecksummedTestContent.gen(11);
104+
ByteBuffer s_0_4 = ByteBuffer.wrap(all.slice(0, 4).getBytes());
105+
ByteBuffer s_4_4 = ByteBuffer.wrap(all.slice(0, 4).getBytes());
106+
int written1 = c.write(s_0_4);
107+
assertThat(written1).isEqualTo(4);
108+
assertThat(s_0_4.remaining()).isEqualTo(0);
109+
assertThat(handle.remaining()).isEqualTo(1);
110+
111+
int written2 = c.write(s_4_4);
112+
assertThat(written2).isEqualTo(1);
113+
assertThat(s_4_4.remaining()).isEqualTo(3);
114+
assertThat(handle.remaining()).isEqualTo(5);
115+
}
116+
117+
@Example
118+
void illegalStateExceptionIfWrittenLt0() throws IOException {
119+
BufferHandle handle = BufferHandle.allocate(4);
120+
MinFlushBufferedWritableByteChannel c =
121+
new MinFlushBufferedWritableByteChannel(
122+
handle,
123+
new UnbufferedWritableByteChannel() {
124+
@Override
125+
public long write(ByteBuffer[] srcs, int offset, int length) {
126+
return -1;
127+
}
128+
129+
@Override
130+
public boolean isOpen() {
131+
return true;
132+
}
133+
134+
@Override
135+
public void close() {}
136+
});
137+
138+
ChecksummedTestContent all = ChecksummedTestContent.gen(11);
139+
ByteBuffer s_0_4 = ByteBuffer.wrap(all.slice(0, 4).getBytes());
140+
assertThrows(IllegalStateException.class, () -> c.write(s_0_4));
141+
}
142+
63143
@Property
64144
void bufferingEagerlyFlushesWhenFull(@ForAll("WriteOps") WriteOps writeOps) throws IOException {
65145
ByteBuffer buffer = ByteBuffer.allocate(writeOps.bufferSize);
@@ -580,4 +660,51 @@ static WriteOps of(int numBytes, int bufferSize, int writeSize) {
580660
dbgExpectedWriteSizes);
581661
}
582662
}
663+
664+
private static final class OnlyConsumeNBytes implements UnbufferedWritableByteChannel {
665+
private static final Logger LOGGER = LoggerFactory.getLogger(OnlyConsumeNBytes.class);
666+
private final long bytesToConsume;
667+
private final int consumptionIncrement;
668+
private long bytesConsumed;
669+
670+
private OnlyConsumeNBytes(int bytesToConsume, int consumptionIncrement) {
671+
this.bytesToConsume = bytesToConsume;
672+
this.consumptionIncrement = consumptionIncrement;
673+
this.bytesConsumed = 0;
674+
}
675+
676+
@Override
677+
public long write(ByteBuffer[] srcs, int offset, int length) {
678+
LOGGER.info(TRACE_ENTER, "write(srcs : {}, offset : {}, length : {})", srcs, offset, length);
679+
try {
680+
if (bytesConsumed >= bytesToConsume) {
681+
return 0;
682+
}
683+
684+
long consumed = 0;
685+
int toConsume = consumptionIncrement;
686+
for (int i = offset; i < length && toConsume > 0; i++) {
687+
ByteBuffer src = srcs[i];
688+
int remaining = src.remaining();
689+
int position = src.position();
690+
int consumable = Math.min(toConsume, remaining);
691+
toConsume -= consumable;
692+
consumed += consumable;
693+
src.position(position + consumable);
694+
}
695+
bytesConsumed += consumed;
696+
return consumed;
697+
} finally {
698+
LOGGER.info(TRACE_EXIT, "write(srcs : {}, offset : {}, length : {})", srcs, offset, length);
699+
}
700+
}
701+
702+
@Override
703+
public boolean isOpen() {
704+
return true;
705+
}
706+
707+
@Override
708+
public void close() {}
709+
}
583710
}

0 commit comments

Comments
 (0)