Skip to content

Commit 952315c

Browse files
committed
DataBufferUtils does not release DataBuffer on error cases
This commit makes sure that in DataBufferUtils.write, any received data buffers are returned as part of the returned flux, even when an error occurs or is received. Issue: SPR-16782 (cherry picked from commit 1a0522b)
1 parent a006073 commit 952315c

File tree

2 files changed

+239
-24
lines changed

2 files changed

+239
-24
lines changed

spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import java.util.concurrent.Callable;
3232
import java.util.concurrent.atomic.AtomicBoolean;
3333
import java.util.concurrent.atomic.AtomicLong;
34+
import java.util.concurrent.atomic.AtomicReference;
3435
import java.util.function.Consumer;
3536
import java.util.function.IntPredicate;
3637

@@ -336,6 +337,7 @@ public static Flux<DataBuffer> write(Publisher<DataBuffer> source, WritableByteC
336337
sink.next(dataBuffer);
337338
}
338339
catch (IOException ex) {
340+
sink.next(dataBuffer);
339341
sink.error(ex);
340342
}
341343

@@ -355,6 +357,26 @@ public static Flux<DataBuffer> write(Publisher<DataBuffer> source, WritableByteC
355357
* @param channel the channel to write to
356358
* @return a flux containing the same buffers as in {@code source}, that starts the writing
357359
* process when subscribed to, and that publishes any writing errors and the completion signal
360+
* @since 5.0.10
361+
*/
362+
public static Flux<DataBuffer> write(
363+
Publisher<DataBuffer> source, AsynchronousFileChannel channel) {
364+
return write(source, channel, 0);
365+
}
366+
367+
368+
/**
369+
* Write the given stream of {@link DataBuffer DataBuffers} to the given {@code AsynchronousFileChannel}.
370+
* Does <strong>not</strong> close the channel when the flux is terminated, and does
371+
* <strong>not</strong> {@linkplain #release(DataBuffer) release} the data buffers in the
372+
* source. If releasing is required, then subscribe to the returned {@code Flux} with a
373+
* {@link #releaseConsumer()}.
374+
* <p>Note that the writing process does not start until the returned {@code Flux} is subscribed to.
375+
* @param source the stream of data buffers to be written
376+
* @param channel the channel to write to
377+
* @param position the file position at which the write is to begin; must be non-negative
378+
* @return a flux containing the same buffers as in {@code source}, that starts the writing
379+
* process when subscribed to, and that publishes any writing errors and the completion signal
358380
*/
359381
public static Flux<DataBuffer> write(
360382
Publisher<DataBuffer> source, AsynchronousFileChannel channel, long position) {
@@ -610,10 +632,11 @@ private static class AsynchronousFileChannelWriteCompletionHandler extends BaseS
610632

611633
private final AtomicBoolean completed = new AtomicBoolean();
612634

635+
private final AtomicReference<Throwable> error = new AtomicReference<>();
636+
613637
private final AtomicLong position;
614638

615-
@Nullable
616-
private DataBuffer dataBuffer;
639+
private final AtomicReference<DataBuffer> dataBuffer = new AtomicReference<>();
617640

618641
public AsynchronousFileChannelWriteCompletionHandler(
619642
FluxSink<DataBuffer> sink, AsynchronousFileChannel channel, long position) {
@@ -630,21 +653,27 @@ protected void hookOnSubscribe(Subscription subscription) {
630653

631654
@Override
632655
protected void hookOnNext(DataBuffer value) {
633-
this.dataBuffer = value;
656+
if (!this.dataBuffer.compareAndSet(null, value)) {
657+
throw new IllegalStateException();
658+
}
634659
ByteBuffer byteBuffer = value.asByteBuffer();
635660
this.channel.write(byteBuffer, this.position.get(), byteBuffer, this);
636661
}
637662

638663
@Override
639664
protected void hookOnError(Throwable throwable) {
640-
this.sink.error(throwable);
665+
this.error.set(throwable);
666+
667+
if (this.dataBuffer.get() == null) {
668+
this.sink.error(throwable);
669+
}
641670
}
642671

643672
@Override
644673
protected void hookOnComplete() {
645674
this.completed.set(true);
646675

647-
if (this.dataBuffer == null) {
676+
if (this.dataBuffer.get() == null) {
648677
this.sink.complete();
649678
}
650679
}
@@ -656,11 +685,13 @@ public void completed(Integer written, ByteBuffer byteBuffer) {
656685
this.channel.write(byteBuffer, pos, byteBuffer, this);
657686
return;
658687
}
659-
if (this.dataBuffer != null) {
660-
this.sink.next(this.dataBuffer);
661-
this.dataBuffer = null;
688+
sinkDataBuffer();
689+
690+
Throwable throwable = this.error.get();
691+
if (throwable != null) {
692+
this.sink.error(throwable);
662693
}
663-
if (this.completed.get()) {
694+
else if (this.completed.get()) {
664695
this.sink.complete();
665696
}
666697
else {
@@ -670,8 +701,16 @@ public void completed(Integer written, ByteBuffer byteBuffer) {
670701

671702
@Override
672703
public void failed(Throwable exc, ByteBuffer byteBuffer) {
704+
sinkDataBuffer();
673705
this.sink.error(exc);
674706
}
707+
708+
private void sinkDataBuffer() {
709+
DataBuffer dataBuffer = this.dataBuffer.get();
710+
Assert.state(dataBuffer != null, "DataBuffer should not be null");
711+
this.sink.next(dataBuffer);
712+
this.dataBuffer.set(null);
713+
}
675714
}
676715

677716

spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java

Lines changed: 191 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616

1717
package org.springframework.core.io.buffer;
1818

19+
import java.io.IOException;
1920
import java.io.OutputStream;
2021
import java.net.URI;
2122
import java.nio.ByteBuffer;
2223
import java.nio.channels.AsynchronousFileChannel;
24+
import java.nio.channels.CompletionHandler;
2325
import java.nio.channels.FileChannel;
2426
import java.nio.channels.ReadableByteChannel;
2527
import java.nio.channels.WritableByteChannel;
@@ -29,7 +31,7 @@
2931
import java.nio.file.Paths;
3032
import java.nio.file.StandardOpenOption;
3133
import java.time.Duration;
32-
import java.util.stream.Collectors;
34+
import java.util.concurrent.CountDownLatch;
3335

3436
import io.netty.buffer.ByteBuf;
3537
import org.junit.Test;
@@ -160,9 +162,7 @@ public void writeOutputStream() throws Exception {
160162
.expectComplete()
161163
.verify(Duration.ofSeconds(5));
162164

163-
String result = Files.readAllLines(tempFile)
164-
.stream()
165-
.collect(Collectors.joining());
165+
String result = String.join("", Files.readAllLines(tempFile));
166166

167167
assertEquals("foobarbazqux", result);
168168
os.close();
@@ -188,14 +188,60 @@ public void writeWritableByteChannel() throws Exception {
188188
.expectComplete()
189189
.verify(Duration.ofSeconds(5));
190190

191-
String result = Files.readAllLines(tempFile)
192-
.stream()
193-
.collect(Collectors.joining());
191+
String result = String.join("", Files.readAllLines(tempFile));
194192

195193
assertEquals("foobarbazqux", result);
196194
channel.close();
197195
}
198196

197+
@Test
198+
public void writeWritableByteChannelErrorInFlux() throws Exception {
199+
DataBuffer foo = stringBuffer("foo");
200+
DataBuffer bar = stringBuffer("bar");
201+
Flux<DataBuffer> flux = Flux.just(foo, bar).concatWith(Flux.error(new RuntimeException()));
202+
203+
Path tempFile = Files.createTempFile("DataBufferUtilsTests", null);
204+
WritableByteChannel channel = Files.newByteChannel(tempFile, StandardOpenOption.WRITE);
205+
206+
Flux<DataBuffer> writeResult = DataBufferUtils.write(flux, channel);
207+
StepVerifier.create(writeResult)
208+
.consumeNextWith(stringConsumer("foo"))
209+
.consumeNextWith(stringConsumer("bar"))
210+
.expectError()
211+
.verify(Duration.ofSeconds(5));
212+
213+
String result = String.join("", Files.readAllLines(tempFile));
214+
215+
assertEquals("foobar", result);
216+
channel.close();
217+
}
218+
219+
@Test
220+
public void writeWritableByteChannelErrorInWrite() throws Exception {
221+
DataBuffer foo = stringBuffer("foo");
222+
DataBuffer bar = stringBuffer("bar");
223+
Flux<DataBuffer> flux = Flux.just(foo, bar);
224+
225+
WritableByteChannel channel = mock(WritableByteChannel.class);
226+
when(channel.write(any()))
227+
.thenAnswer(invocation -> {
228+
ByteBuffer buffer = invocation.getArgument(0);
229+
int written = buffer.remaining();
230+
buffer.position(buffer.limit());
231+
return written;
232+
})
233+
.thenThrow(new IOException());
234+
235+
Flux<DataBuffer> writeResult = DataBufferUtils.write(flux, channel);
236+
StepVerifier.create(writeResult)
237+
.consumeNextWith(stringConsumer("foo"))
238+
.consumeNextWith(stringConsumer("bar"))
239+
.expectError(IOException.class)
240+
.verify();
241+
242+
channel.close();
243+
}
244+
199245
@Test
200246
public void writeAsynchronousFileChannel() throws Exception {
201247
DataBuffer foo = stringBuffer("foo");
@@ -208,7 +254,7 @@ public void writeAsynchronousFileChannel() throws Exception {
208254
AsynchronousFileChannel channel =
209255
AsynchronousFileChannel.open(tempFile, StandardOpenOption.WRITE);
210256

211-
Flux<DataBuffer> writeResult = DataBufferUtils.write(flux, channel, 0);
257+
Flux<DataBuffer> writeResult = DataBufferUtils.write(flux, channel);
212258
StepVerifier.create(writeResult)
213259
.consumeNextWith(stringConsumer("foo"))
214260
.consumeNextWith(stringConsumer("bar"))
@@ -217,14 +263,142 @@ public void writeAsynchronousFileChannel() throws Exception {
217263
.expectComplete()
218264
.verify(Duration.ofSeconds(5));
219265

220-
String result = Files.readAllLines(tempFile)
221-
.stream()
222-
.collect(Collectors.joining());
266+
String result = String.join("", Files.readAllLines(tempFile));
223267

224268
assertEquals("foobarbazqux", result);
225269
channel.close();
226270
}
227271

272+
@Test
273+
public void writeAsynchronousFileChannelErrorInFlux() throws Exception {
274+
DataBuffer foo = stringBuffer("foo");
275+
DataBuffer bar = stringBuffer("bar");
276+
Flux<DataBuffer> flux =
277+
Flux.just(foo, bar).concatWith(Mono.error(new RuntimeException()));
278+
279+
Path tempFile = Files.createTempFile("DataBufferUtilsTests", null);
280+
AsynchronousFileChannel channel =
281+
AsynchronousFileChannel.open(tempFile, StandardOpenOption.WRITE);
282+
283+
Flux<DataBuffer> writeResult = DataBufferUtils.write(flux, channel);
284+
StepVerifier.create(writeResult)
285+
.consumeNextWith(stringConsumer("foo"))
286+
.consumeNextWith(stringConsumer("bar"))
287+
.expectError(RuntimeException.class)
288+
.verify();
289+
290+
String result = String.join("", Files.readAllLines(tempFile));
291+
292+
assertEquals("foobar", result);
293+
channel.close();
294+
}
295+
296+
297+
@Test
298+
public void writeAsynchronousFileChannelErrorInWrite() throws Exception {
299+
DataBuffer foo = stringBuffer("foo");
300+
DataBuffer bar = stringBuffer("bar");
301+
Flux<DataBuffer> flux = Flux.just(foo, bar);
302+
303+
AsynchronousFileChannel channel = mock(AsynchronousFileChannel.class);
304+
doAnswer(invocation -> {
305+
ByteBuffer buffer = invocation.getArgument(0);
306+
long pos = invocation.getArgument(1);
307+
CompletionHandler<Integer, ByteBuffer> completionHandler = invocation.getArgument(3);
308+
309+
assertEquals(0, pos);
310+
311+
int written = buffer.remaining();
312+
buffer.position(buffer.limit());
313+
completionHandler.completed(written, buffer);
314+
315+
return null;
316+
})
317+
.doAnswer(invocation -> {
318+
ByteBuffer buffer = invocation.getArgument(0);
319+
CompletionHandler<Integer, ByteBuffer> completionHandler =
320+
invocation.getArgument(3);
321+
completionHandler.failed(new IOException(), buffer);
322+
return null;
323+
})
324+
.when(channel).write(isA(ByteBuffer.class), anyLong(), isA(ByteBuffer.class),
325+
isA(CompletionHandler.class));
326+
327+
Flux<DataBuffer> writeResult = DataBufferUtils.write(flux, channel);
328+
StepVerifier.create(writeResult)
329+
.consumeNextWith(stringConsumer("foo"))
330+
.consumeNextWith(stringConsumer("bar"))
331+
.expectError(IOException.class)
332+
.verify();
333+
334+
channel.close();
335+
}
336+
337+
@Test
338+
public void readAndWriteByteChannel() throws Exception {
339+
Path source = Paths.get(
340+
DataBufferUtilsTests.class.getResource("DataBufferUtilsTests.txt").toURI());
341+
Flux<DataBuffer> sourceFlux =
342+
DataBufferUtils
343+
.readByteChannel(() -> FileChannel.open(source, StandardOpenOption.READ),
344+
this.bufferFactory, 3);
345+
346+
Path destination = Files.createTempFile("DataBufferUtilsTests", null);
347+
WritableByteChannel channel = Files.newByteChannel(destination, StandardOpenOption.WRITE);
348+
349+
DataBufferUtils.write(sourceFlux, channel)
350+
.subscribe(DataBufferUtils.releaseConsumer(),
351+
throwable -> fail(throwable.getMessage()),
352+
() -> {
353+
try {
354+
String expected = String.join("", Files.readAllLines(source));
355+
String result = String.join("", Files.readAllLines(destination));
356+
357+
assertEquals(expected, result);
358+
channel.close();
359+
360+
}
361+
catch (IOException e) {
362+
fail(e.getMessage());
363+
}
364+
});
365+
}
366+
367+
@Test
368+
public void readAndWriteAsynchronousFileChannel() throws Exception {
369+
Path source = Paths.get(
370+
DataBufferUtilsTests.class.getResource("DataBufferUtilsTests.txt").toURI());
371+
Flux<DataBuffer> sourceFlux = DataBufferUtils.readAsynchronousFileChannel(
372+
() -> AsynchronousFileChannel.open(source, StandardOpenOption.READ),
373+
this.bufferFactory, 3);
374+
375+
Path destination = Files.createTempFile("DataBufferUtilsTests", null);
376+
AsynchronousFileChannel channel =
377+
AsynchronousFileChannel.open(destination, StandardOpenOption.WRITE);
378+
379+
CountDownLatch latch = new CountDownLatch(1);
380+
381+
DataBufferUtils.write(sourceFlux, channel)
382+
.subscribe(DataBufferUtils::release,
383+
throwable -> fail(throwable.getMessage()),
384+
() -> {
385+
try {
386+
String expected = String.join("", Files.readAllLines(source));
387+
String result = String.join("", Files.readAllLines(destination));
388+
389+
assertEquals(expected, result);
390+
channel.close();
391+
latch.countDown();
392+
393+
}
394+
catch (IOException e) {
395+
fail(e.getMessage());
396+
}
397+
});
398+
399+
latch.await();
400+
}
401+
228402
@Test
229403
public void takeUntilByteCount() {
230404

@@ -314,7 +488,8 @@ public void SPR16070() throws Exception {
314488
.thenAnswer(putByte('c'))
315489
.thenReturn(-1);
316490

317-
Flux<DataBuffer> read = DataBufferUtils.readByteChannel(() -> channel, this.bufferFactory, 1);
491+
Flux<DataBuffer> read =
492+
DataBufferUtils.readByteChannel(() -> channel, this.bufferFactory, 1);
318493

319494
StepVerifier.create(read)
320495
.consumeNextWith(stringConsumer("a"))
@@ -343,9 +518,10 @@ public void join() {
343518

344519
StepVerifier.create(result)
345520
.consumeNextWith(dataBuffer -> {
346-
assertEquals("foobarbaz", DataBufferTestUtils.dumpString(dataBuffer, StandardCharsets.UTF_8));
347-
release(dataBuffer);
348-
})
521+
assertEquals("foobarbaz",
522+
DataBufferTestUtils.dumpString(dataBuffer, StandardCharsets.UTF_8));
523+
release(dataBuffer);
524+
})
349525
.verifyComplete();
350526
}
351527

0 commit comments

Comments
 (0)