Skip to content

Commit 149d416

Browse files
committed
Review DataBufferUtils for error/cancellation memory leaks
Issue: SPR-17408
1 parent 1621125 commit 149d416

File tree

2 files changed

+68
-191
lines changed

2 files changed

+68
-191
lines changed

spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java

Lines changed: 42 additions & 191 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
import java.util.concurrent.atomic.AtomicLong;
3434
import java.util.concurrent.atomic.AtomicReference;
3535
import java.util.function.Consumer;
36-
import java.util.function.IntPredicate;
3736

3837
import org.reactivestreams.Publisher;
3938
import org.reactivestreams.Subscription;
@@ -334,14 +333,23 @@ private static void closeChannel(@Nullable Channel channel) {
334333
public static Flux<DataBuffer> takeUntilByteCount(Publisher<DataBuffer> publisher, long maxByteCount) {
335334
Assert.notNull(publisher, "Publisher must not be null");
336335
Assert.isTrue(maxByteCount >= 0, "'maxByteCount' must be a positive number");
337-
AtomicLong countDown = new AtomicLong(maxByteCount);
338336

339-
return Flux.from(publisher)
340-
.map(buffer -> {
341-
long count = countDown.addAndGet(-buffer.readableByteCount());
342-
return count >= 0 ? buffer : buffer.slice(0, buffer.readableByteCount() + (int) count);
343-
})
344-
.takeUntil(buffer -> countDown.get() <= 0);
337+
return Flux.defer(() -> {
338+
AtomicLong countDown = new AtomicLong(maxByteCount);
339+
340+
return Flux.from(publisher)
341+
.map(buffer -> {
342+
long remainder = countDown.addAndGet(-buffer.readableByteCount());
343+
if (remainder < 0) {
344+
int length = buffer.readableByteCount() + (int) remainder;
345+
return buffer.slice(0, length);
346+
}
347+
else {
348+
return buffer;
349+
}
350+
})
351+
.takeUntil(buffer -> countDown.get() <= 0);
352+
}); // no doOnDiscard necessary, as this method does not drop buffers
345353
}
346354

347355
/**
@@ -355,26 +363,28 @@ public static Flux<DataBuffer> takeUntilByteCount(Publisher<DataBuffer> publishe
355363
public static Flux<DataBuffer> skipUntilByteCount(Publisher<DataBuffer> publisher, long maxByteCount) {
356364
Assert.notNull(publisher, "Publisher must not be null");
357365
Assert.isTrue(maxByteCount >= 0, "'maxByteCount' must be a positive number");
358-
AtomicLong byteCountDown = new AtomicLong(maxByteCount);
359-
360-
return Flux.from(publisher)
361-
.skipUntil(buffer -> {
362-
int delta = -buffer.readableByteCount();
363-
if (byteCountDown.addAndGet(delta) >= 0) {
364-
DataBufferUtils.release(buffer);
365-
return false;
366-
}
367-
return true;
368-
})
369-
.map(buffer -> {
370-
long count = byteCountDown.get();
371-
if (count < 0) {
372-
int skipCount = buffer.readableByteCount() + (int) count;
373-
byteCountDown.set(0);
374-
return buffer.slice(skipCount, buffer.readableByteCount() - skipCount);
375-
}
376-
return buffer;
377-
});
366+
367+
return Flux.defer(() -> {
368+
AtomicLong countDown = new AtomicLong(maxByteCount);
369+
370+
return Flux.from(publisher)
371+
.skipUntil(buffer -> {
372+
long remainder = countDown.addAndGet(-buffer.readableByteCount());
373+
return remainder < 0;
374+
})
375+
.map(buffer -> {
376+
long remainder = countDown.get();
377+
if (remainder < 0) {
378+
countDown.set(0);
379+
int start = buffer.readableByteCount() + (int)remainder;
380+
int length = (int) -remainder;
381+
return buffer.slice(start, length);
382+
}
383+
else {
384+
return buffer;
385+
}
386+
});
387+
}).doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release);
378388
}
379389

380390
/**
@@ -432,24 +442,14 @@ public static Mono<DataBuffer> join(Publisher<DataBuffer> dataBuffers) {
432442
Assert.notNull(dataBuffers, "'dataBuffers' must not be null");
433443

434444
return Flux.from(dataBuffers)
435-
.onErrorResume(DataBufferUtils::exceptionDataBuffer)
436445
.collectList()
437446
.filter(list -> !list.isEmpty())
438-
.flatMap(list -> {
439-
for (int i = 0; i < list.size(); i++) {
440-
DataBuffer dataBuffer = list.get(i);
441-
if (dataBuffer instanceof ExceptionDataBuffer) {
442-
list.subList(0, i).forEach(DataBufferUtils::release);
443-
return Mono.error(((ExceptionDataBuffer) dataBuffer).throwable());
444-
}
445-
}
447+
.map(list -> {
446448
DataBufferFactory bufferFactory = list.get(0).factory();
447-
return Mono.just(bufferFactory.join(list));
448-
});
449-
}
449+
return bufferFactory.join(list);
450+
})
451+
.doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release);
450452

451-
private static Mono<DataBuffer> exceptionDataBuffer(Throwable throwable) {
452-
return Mono.just(new ExceptionDataBuffer(throwable));
453453
}
454454

455455

@@ -638,153 +638,4 @@ private void sinkDataBuffer() {
638638
}
639639
}
640640

641-
/**
642-
* DataBuffer implementation that holds a {@link Throwable}, used in {@link #join(Publisher)}.
643-
*/
644-
private static final class ExceptionDataBuffer implements DataBuffer {
645-
646-
private final Throwable throwable;
647-
648-
649-
public ExceptionDataBuffer(Throwable throwable) {
650-
this.throwable = throwable;
651-
}
652-
653-
public Throwable throwable() {
654-
return this.throwable;
655-
}
656-
657-
// Unsupported
658-
659-
@Override
660-
public DataBufferFactory factory() {
661-
throw new UnsupportedOperationException();
662-
}
663-
664-
@Override
665-
public int indexOf(IntPredicate predicate, int fromIndex) {
666-
throw new UnsupportedOperationException();
667-
}
668-
669-
@Override
670-
public int lastIndexOf(IntPredicate predicate, int fromIndex) {
671-
throw new UnsupportedOperationException();
672-
}
673-
674-
@Override
675-
public int readableByteCount() {
676-
throw new UnsupportedOperationException();
677-
}
678-
679-
@Override
680-
public int writableByteCount() {
681-
throw new UnsupportedOperationException();
682-
}
683-
684-
@Override
685-
public int capacity() {
686-
throw new UnsupportedOperationException();
687-
}
688-
689-
@Override
690-
public DataBuffer capacity(int capacity) {
691-
throw new UnsupportedOperationException();
692-
}
693-
694-
@Override
695-
public int readPosition() {
696-
throw new UnsupportedOperationException();
697-
}
698-
699-
@Override
700-
public DataBuffer readPosition(int readPosition) {
701-
throw new UnsupportedOperationException();
702-
}
703-
704-
@Override
705-
public int writePosition() {
706-
throw new UnsupportedOperationException();
707-
}
708-
709-
@Override
710-
public DataBuffer writePosition(int writePosition) {
711-
throw new UnsupportedOperationException();
712-
}
713-
714-
@Override
715-
public byte getByte(int index) {
716-
throw new UnsupportedOperationException();
717-
}
718-
719-
@Override
720-
public byte read() {
721-
throw new UnsupportedOperationException();
722-
}
723-
724-
@Override
725-
public DataBuffer read(byte[] destination) {
726-
throw new UnsupportedOperationException();
727-
}
728-
729-
@Override
730-
public DataBuffer read(byte[] destination, int offset, int length) {
731-
throw new UnsupportedOperationException();
732-
}
733-
734-
@Override
735-
public DataBuffer write(byte b) {
736-
throw new UnsupportedOperationException();
737-
}
738-
739-
@Override
740-
public DataBuffer write(byte[] source) {
741-
throw new UnsupportedOperationException();
742-
}
743-
744-
@Override
745-
public DataBuffer write(byte[] source, int offset, int length) {
746-
throw new UnsupportedOperationException();
747-
}
748-
749-
@Override
750-
public DataBuffer write(DataBuffer... buffers) {
751-
throw new UnsupportedOperationException();
752-
}
753-
754-
@Override
755-
public DataBuffer write(ByteBuffer... buffers) {
756-
throw new UnsupportedOperationException();
757-
}
758-
759-
@Override
760-
public DataBuffer slice(int index, int length) {
761-
throw new UnsupportedOperationException();
762-
}
763-
764-
@Override
765-
public ByteBuffer asByteBuffer() {
766-
throw new UnsupportedOperationException();
767-
}
768-
769-
@Override
770-
public ByteBuffer asByteBuffer(int index, int length) {
771-
throw new UnsupportedOperationException();
772-
}
773-
774-
@Override
775-
public InputStream asInputStream() {
776-
throw new UnsupportedOperationException();
777-
}
778-
779-
@Override
780-
public InputStream asInputStream(boolean releaseOnClose) {
781-
throw new UnsupportedOperationException();
782-
}
783-
784-
@Override
785-
public OutputStream asOutputStream() {
786-
throw new UnsupportedOperationException();
787-
}
788-
}
789-
790641
}

spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,20 @@ public void takeUntilByteCount() {
412412
.verify(Duration.ofSeconds(5));
413413
}
414414

415+
@Test
416+
public void takeUntilByteCountErrorInFlux() {
417+
DataBuffer foo = stringBuffer("foo");
418+
Flux<DataBuffer> flux =
419+
Flux.just(foo).concatWith(Mono.error(new RuntimeException()));
420+
421+
Flux<DataBuffer> result = DataBufferUtils.takeUntilByteCount(flux, 5L);
422+
423+
StepVerifier.create(result)
424+
.consumeNextWith(stringConsumer("foo"))
425+
.expectError(RuntimeException.class)
426+
.verify(Duration.ofSeconds(5));
427+
}
428+
415429
@Test
416430
public void takeUntilByteCountExact() {
417431

@@ -444,6 +458,18 @@ public void skipUntilByteCount() {
444458
.verify(Duration.ofSeconds(5));
445459
}
446460

461+
@Test
462+
public void skipUntilByteCountErrorInFlux() {
463+
DataBuffer foo = stringBuffer("foo");
464+
Flux<DataBuffer> flux =
465+
Flux.just(foo).concatWith(Mono.error(new RuntimeException()));
466+
Flux<DataBuffer> result = DataBufferUtils.skipUntilByteCount(flux, 3L);
467+
468+
StepVerifier.create(result)
469+
.expectError(RuntimeException.class)
470+
.verify(Duration.ofSeconds(5));
471+
}
472+
447473
@Test
448474
public void skipUntilByteCountShouldSkipAll() {
449475
DataBuffer foo = stringBuffer("foo");

0 commit comments

Comments
 (0)