Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package com.google.cloud.storage;

import static com.google.common.base.Preconditions.checkState;

import com.google.cloud.storage.BufferedWritableByteChannelSession.BufferedWritableByteChannel;
import com.google.cloud.storage.UnbufferedWritableByteChannelSession.UnbufferedWritableByteChannel;
import java.io.IOException;
Expand Down Expand Up @@ -55,10 +57,17 @@ final class MinFlushBufferedWritableByteChannel implements BufferedWritableByteC
private final BufferHandle handle;

private final UnbufferedWritableByteChannel channel;
private final boolean blocking;

MinFlushBufferedWritableByteChannel(BufferHandle handle, UnbufferedWritableByteChannel channel) {
this(handle, channel, true);
}

MinFlushBufferedWritableByteChannel(
BufferHandle handle, UnbufferedWritableByteChannel channel, boolean blocking) {
this.handle = handle;
this.channel = channel;
this.blocking = blocking;
}

@Override
Expand All @@ -81,27 +90,43 @@ public int write(ByteBuffer src) throws IOException {
}

int capacity = handle.capacity();
int position = handle.position();
int bufferPending = capacity - bufferRemaining;
int totalPending = Math.addExact(srcRemaining, bufferPending);
if (totalPending >= capacity) {
ByteBuffer[] srcs;
if (enqueuedBytes()) {
ByteBuffer buffer = handle.get();
Buffers.flip(buffer);
srcs = new ByteBuffer[] {buffer, src};
} else {
srcs = new ByteBuffer[] {src};
}
long write = channel.write(srcs);
if (enqueuedBytes()) {
// we didn't write enough bytes to consume the whole buffer.
Buffers.compact(handle.get());
} else if (handle.position() == handle.capacity()) {
ByteBuffer[] srcs;
boolean usingBuffer = false;
if (enqueuedBytes()) {
usingBuffer = true;
ByteBuffer buffer = handle.get();
Buffers.flip(buffer);
srcs = new ByteBuffer[] {buffer, src};
} else {
srcs = new ByteBuffer[] {src};
}
long written = channel.write(srcs);
checkState(written >= 0, "written >= 0 (%s > 0)", written);
if (usingBuffer) {
if (written >= bufferPending) {
// we wrote enough to consume the buffer
Buffers.clear(handle.get());
} else if (written > 0) {
// we didn't write enough bytes to consume the whole buffer.
Buffers.compact(handle.get());
} else /*if (written == 0)*/ {
// if none of the buffer was consumed, flip it back so we retain all bytes
Buffers.position(handle.get(), position);
Buffers.limit(handle.get(), capacity);
}
int srcConsumed = Math.toIntExact(write) - bufferPending;
bytesConsumed += srcConsumed;
}

int srcConsumed = Math.max(0, Math.toIntExact(written) - bufferPending);
bytesConsumed += srcConsumed;

if (!blocking && written != totalPending) {
// we're configured in non-blocking mode, and we weren't able to make any progress on our
// call, break out to allow more bytes to be written to us or to allow underlying space
// to clear.
break;
}
}
return bytesConsumed;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import com.google.cloud.storage.DefaultBufferedWritableByteChannelTest.AuditingBufferHandle;
import com.google.cloud.storage.DefaultBufferedWritableByteChannelTest.CountingWritableByteChannelAdapter;
import com.google.cloud.storage.UnbufferedWritableByteChannelSession.UnbufferedWritableByteChannel;
import com.google.cloud.storage.it.ChecksummedTestContent;
import com.google.common.collect.ImmutableList;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
Expand All @@ -52,14 +53,93 @@
import net.jqwik.api.Provide;
import net.jqwik.api.providers.TypeUsage;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.Marker;
import org.slf4j.MarkerFactory;

public final class MinFlushBufferedWritableByteChannelTest {
private static final Logger LOGGER =
LoggerFactory.getLogger(MinFlushBufferedWritableByteChannelTest.class);
private static final Marker TRACE_ENTER = MarkerFactory.getMarker("enter");
private static final Marker TRACE_EXIT = MarkerFactory.getMarker("exit");

@Example
void edgeCases() {
JqwikTest.report(TypeUsage.of(WriteOps.class), arbitraryWriteOps());
}

@Example
void nonBlockingWrite0DoesNotBlock() throws IOException {
BufferHandle handle = BufferHandle.allocate(5);
MinFlushBufferedWritableByteChannel c =
new MinFlushBufferedWritableByteChannel(handle, new OnlyConsumeNBytes(0, 1), false);

ChecksummedTestContent all = ChecksummedTestContent.gen(11);
ByteBuffer s_0_4 = ByteBuffer.wrap(all.slice(0, 4).getBytes());
ByteBuffer s_4_4 = ByteBuffer.wrap(all.slice(0, 4).getBytes());
ByteBuffer s_8_3 = ByteBuffer.wrap(all.slice(0, 3).getBytes());
int written1 = c.write(s_0_4);
assertThat(written1).isEqualTo(4);
assertThat(s_0_4.remaining()).isEqualTo(0);

int written2 = c.write(s_4_4);
assertThat(written2).isEqualTo(0);
assertThat(s_4_4.remaining()).isEqualTo(4);

int written3 = c.write(s_8_3);
assertThat(written3).isEqualTo(0);
assertThat(s_8_3.remaining()).isEqualTo(3);

assertThat(handle.remaining()).isEqualTo(1);
}

@Example
void nonBlockingWritePartialDoesNotBlock() throws IOException {
BufferHandle handle = BufferHandle.allocate(5);
MinFlushBufferedWritableByteChannel c =
new MinFlushBufferedWritableByteChannel(handle, new OnlyConsumeNBytes(6, 5), false);

ChecksummedTestContent all = ChecksummedTestContent.gen(11);
ByteBuffer s_0_4 = ByteBuffer.wrap(all.slice(0, 4).getBytes());
ByteBuffer s_4_4 = ByteBuffer.wrap(all.slice(0, 4).getBytes());
int written1 = c.write(s_0_4);
assertThat(written1).isEqualTo(4);
assertThat(s_0_4.remaining()).isEqualTo(0);
assertThat(handle.remaining()).isEqualTo(1);

int written2 = c.write(s_4_4);
assertThat(written2).isEqualTo(1);
assertThat(s_4_4.remaining()).isEqualTo(3);
assertThat(handle.remaining()).isEqualTo(5);
}

@Example
void illegalStateExceptionIfWrittenLt0() throws IOException {
BufferHandle handle = BufferHandle.allocate(4);
MinFlushBufferedWritableByteChannel c =
new MinFlushBufferedWritableByteChannel(
handle,
new UnbufferedWritableByteChannel() {
@Override
public long write(ByteBuffer[] srcs, int offset, int length) {
return -1;
}

@Override
public boolean isOpen() {
return true;
}

@Override
public void close() {}
});

ChecksummedTestContent all = ChecksummedTestContent.gen(11);
ByteBuffer s_0_4 = ByteBuffer.wrap(all.slice(0, 4).getBytes());
assertThrows(IllegalStateException.class, () -> c.write(s_0_4));
}

@Property
void bufferingEagerlyFlushesWhenFull(@ForAll("WriteOps") WriteOps writeOps) throws IOException {
ByteBuffer buffer = ByteBuffer.allocate(writeOps.bufferSize);
Expand Down Expand Up @@ -580,4 +660,51 @@ static WriteOps of(int numBytes, int bufferSize, int writeSize) {
dbgExpectedWriteSizes);
}
}

private static final class OnlyConsumeNBytes implements UnbufferedWritableByteChannel {
private static final Logger LOGGER = LoggerFactory.getLogger(OnlyConsumeNBytes.class);
private final long bytesToConsume;
private final int consumptionIncrement;
private long bytesConsumed;

private OnlyConsumeNBytes(int bytesToConsume, int consumptionIncrement) {
this.bytesToConsume = bytesToConsume;
this.consumptionIncrement = consumptionIncrement;
this.bytesConsumed = 0;
}

@Override
public long write(ByteBuffer[] srcs, int offset, int length) {
LOGGER.info(TRACE_ENTER, "write(srcs : {}, offset : {}, length : {})", srcs, offset, length);
try {
if (bytesConsumed >= bytesToConsume) {
return 0;
}

long consumed = 0;
int toConsume = consumptionIncrement;
for (int i = offset; i < length && toConsume > 0; i++) {
ByteBuffer src = srcs[i];
int remaining = src.remaining();
int position = src.position();
int consumable = Math.min(toConsume, remaining);
toConsume -= consumable;
consumed += consumable;
src.position(position + consumable);
}
bytesConsumed += consumed;
return consumed;
} finally {
LOGGER.info(TRACE_EXIT, "write(srcs : {}, offset : {}, length : {})", srcs, offset, length);
}
}

@Override
public boolean isOpen() {
return true;
}

@Override
public void close() {}
}
}
Loading