diff --git a/test/framework/src/main/java/org/elasticsearch/indices/CrankyCircuitBreakerService.java b/test/framework/src/main/java/org/elasticsearch/indices/CrankyCircuitBreakerService.java index 52cd1a4ef652c..529af99e6dd4d 100644 --- a/test/framework/src/main/java/org/elasticsearch/indices/CrankyCircuitBreakerService.java +++ b/test/framework/src/main/java/org/elasticsearch/indices/CrankyCircuitBreakerService.java @@ -29,7 +29,7 @@ public class CrankyCircuitBreakerService extends CircuitBreakerService { */ public static final String ERROR_MESSAGE = "cranky breaker"; - private final CircuitBreaker breaker = new CircuitBreaker() { + public static final class CrankyCircuitBreaker implements CircuitBreaker { private final AtomicLong used = new AtomicLong(); @Override @@ -82,7 +82,9 @@ public Durability getDurability() { public void setLimitAndOverhead(long limit, double overhead) { } - }; + } + + private final CrankyCircuitBreaker breaker = new CrankyCircuitBreaker(); @Override public CircuitBreaker getBreaker(String name) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/LocalCircuitBreaker.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/LocalCircuitBreaker.java index 2b11b06a74702..cf570415498bb 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/LocalCircuitBreaker.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/LocalCircuitBreaker.java @@ -28,6 +28,7 @@ public final class LocalCircuitBreaker implements CircuitBreaker, Releasable { private final long maxOverReservedBytes; private long reservedBytes; private final AtomicBoolean closed = new AtomicBoolean(false); + private volatile Thread activeThread; public record SizeSettings(long overReservedBytes, long maxOverReservedBytes) { public SizeSettings(Settings settings) { @@ -57,6 +58,7 @@ public void circuitBreak(String fieldName, long bytesNeeded) { @Override public void 
addEstimateBytesAndMaybeBreak(long bytes, String label) throws CircuitBreakingException { + assert assertSingleThread(); if (bytes <= reservedBytes) { reservedBytes -= bytes; maybeReduceReservedBytes(); @@ -68,6 +70,7 @@ public void addEstimateBytesAndMaybeBreak(long bytes, String label) throws Circu @Override public void addWithoutBreaking(long bytes) { + assert assertSingleThread(); if (bytes <= reservedBytes) { reservedBytes -= bytes; maybeReduceReservedBytes(); @@ -130,6 +133,7 @@ public void setLimitAndOverhead(long limit, double overhead) { @Override public void close() { + assert assertSingleThread(); if (closed.compareAndSet(false, true)) { breaker.addWithoutBreaking(-reservedBytes); } @@ -139,4 +143,34 @@ public void close() { public String toString() { return "LocalCircuitBreaker[" + reservedBytes + "/" + overReservedBytes + ":" + maxOverReservedBytes + "]"; } + + private boolean assertSingleThread() { + Thread activeThread = this.activeThread; + Thread currentThread = Thread.currentThread(); + assert activeThread == null || activeThread == currentThread + : "Local breaker must be accessed by a single thread at a time: expected [" + + activeThread + + "] != actual [" + + currentThread + + "]"; + return true; + } + + /** + * Marks the beginning of a run loop for assertion purposes. + * Sets the current thread as the only thread allowed to access this breaker. + */ + public boolean assertBeginRunLoop() { + activeThread = Thread.currentThread(); + return true; + } + + /** + * Marks the end of a run loop for assertion purposes. + * Clears the active thread to allow other threads to access this breaker. 
+ */ + public boolean assertEndRunLoop() { + activeThread = null; + return true; + } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java index 631006a1cf8f3..ef7eef4c111bf 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java @@ -181,10 +181,13 @@ SubscribableListener run(TimeValue maxTime, int maxIterations, LongSupplie while (true) { IsBlockedResult isBlocked = Operator.NOT_BLOCKED; try { + assert driverContext.assertBeginRunLoop(); isBlocked = runSingleLoopIteration(); } catch (DriverEarlyTerminationException unused) { closeEarlyFinishedOperators(); assert isFinished() : "not finished after early termination"; + } finally { + assert driverContext.assertEndRunLoop(); } totalIterationsThisRun++; iterationsSinceLastStatusUpdate++; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverContext.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverContext.java index 1877f564677ba..f26a1d11ad059 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverContext.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverContext.java @@ -10,9 +10,9 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.common.breaker.CircuitBreaker; -import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.LocalCircuitBreaker; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; @@ -74,14 +74,6 @@ private 
DriverContext(BigArrays bigArrays, BlockFactory blockFactory, WarningsMo this.warningsMode = warningsMode; } - public static DriverContext getLocalDriver() { - return new DriverContext( - BigArrays.NON_RECYCLING_INSTANCE, - // TODO maybe this should have a small fixed limit? - new BlockFactory(new NoopCircuitBreaker(CircuitBreaker.REQUEST), BigArrays.NON_RECYCLING_INSTANCE) - ); - } - public BigArrays bigArrays() { return bigArrays; } @@ -208,6 +200,26 @@ public enum WarningsMode { IGNORE } + /** + * Marks the beginning of a run loop for assertion purposes. + */ + public boolean assertBeginRunLoop() { + if (blockFactory.breaker() instanceof LocalCircuitBreaker localBreaker) { + assert localBreaker.assertBeginRunLoop(); + } + return true; + } + + /** + * Marks the end of a run loop for assertion purposes. + */ + public boolean assertEndRunLoop() { + if (blockFactory.breaker() instanceof LocalCircuitBreaker localBreaker) { + assert localBreaker.assertEndRunLoop(); + } + return true; + } + private static class AsyncActions { private final SubscribableListener completion = new SubscribableListener<>(); private final AtomicBoolean finished = new AtomicBoolean(); diff --git a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/OperatorTestCase.java b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/OperatorTestCase.java index 698ae339060db..56ae2fb4119a8 100644 --- a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/OperatorTestCase.java +++ b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/OperatorTestCase.java @@ -23,6 +23,7 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.AsyncOperator; import org.elasticsearch.compute.operator.Driver; import org.elasticsearch.compute.operator.DriverContext; import 
org.elasticsearch.compute.operator.DriverRunner; @@ -247,9 +248,20 @@ public void testSimpleFinishClose() { try (var operator = simple().get(driverContext)) { assert operator.needsInput(); for (Page page : input) { - operator.addInput(page); + if (operator.needsInput()) { + operator.addInput(page); + } else { + page.releaseBlocks(); + } } operator.finish(); + // for async operator, we need to wait for async actions to finish. + if (operator instanceof AsyncOperator || randomBoolean()) { + driverContext.finish(); + PlainActionFuture waitForAsync = new PlainActionFuture<>(); + driverContext.waitForAsyncActions(waitForAsync); + waitForAsync.actionGet(TimeValue.timeValueSeconds(30)); + } } } diff --git a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/TestDriverFactory.java b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/TestDriverFactory.java index bd1d4d5fc53d1..f3a0d4b7d0797 100644 --- a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/TestDriverFactory.java +++ b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/TestDriverFactory.java @@ -7,12 +7,18 @@ package org.elasticsearch.compute.test; +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.LocalCircuitBreaker; import org.elasticsearch.compute.operator.Driver; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.SinkOperator; import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.indices.CrankyCircuitBreakerService; import java.util.List; @@ -38,6 +44,20 @@ public static Driver create( SinkOperator sink, Releasable releasable ) { + // Do not wrap the local 
breaker for small local breakers, as the output might not match expectations. + if (driverContext.breaker() instanceof CrankyCircuitBreakerService.CrankyCircuitBreaker == false + && driverContext.breaker() instanceof LocalCircuitBreaker == false + && driverContext.breaker().getLimit() >= ByteSizeValue.ofMb(100).getBytes() + && Randomness.get().nextBoolean()) { + final int overReservedBytes = Randomness.get().nextInt(1024 * 1024); + final int maxOverReservedBytes = overReservedBytes + Randomness.get().nextInt(1024 * 1024); + var localBreaker = new LocalCircuitBreaker(driverContext.breaker(), overReservedBytes, maxOverReservedBytes); + BlockFactory localBlockFactory = driverContext.blockFactory().newChildFactory(localBreaker); + driverContext = new DriverContext(localBlockFactory.bigArrays(), localBlockFactory); + } + if (driverContext.breaker() instanceof LocalCircuitBreaker localBreaker) { + releasable = Releasables.wrap(releasable, localBreaker); + } return new Driver( "unset", "test-task", diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/RerankOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/RerankOperator.java index 00f4261b4707a..0b5d0384c56c1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/RerankOperator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/RerankOperator.java @@ -19,6 +19,7 @@ import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; @@ -26,7 +27,7 @@ import java.util.List; -public class RerankOperator extends AsyncOperator { +public class RerankOperator extends AsyncOperator { // 
Move to a setting. private static final int MAX_INFERENCE_WORKER = 10; @@ -85,20 +86,16 @@ public RerankOperator( } @Override - protected void performAsync(Page inputPage, ActionListener listener) { + protected void performAsync(Page inputPage, ActionListener listener) { // Ensure input page blocks are released when the listener is called. - final ActionListener outputListener = ActionListener.runAfter(listener, () -> { releasePageOnAnyThread(inputPage); }); - + listener = listener.delegateResponse((l, e) -> { + releasePageOnAnyThread(inputPage); + l.onFailure(e); + }); try { - inferenceRunner.doInference( - buildInferenceRequest(inputPage), - ActionListener.wrap( - inferenceResponse -> outputListener.onResponse(buildOutput(inputPage, inferenceResponse)), - outputListener::onFailure - ) - ); + inferenceRunner.doInference(buildInferenceRequest(inputPage), listener.map(resp -> new OngoingRerank(inputPage, resp))); } catch (Exception e) { - outputListener.onFailure(e); + listener.onFailure(e); } } @@ -108,13 +105,18 @@ protected void doClose() { } @Override - protected void releaseFetchedOnAnyThread(Page page) { - releasePageOnAnyThread(page); + protected void releaseFetchedOnAnyThread(OngoingRerank result) { + releasePageOnAnyThread(result.inputPage); } @Override public Page getOutput() { - return fetchFromBuffer(); + var fetched = fetchFromBuffer(); + if (fetched == null) { + return null; + } else { + return fetched.buildOutput(blockFactory, scoreChannel); + } } @Override @@ -122,77 +124,87 @@ public String toString() { return "RerankOperator[inference_id=[" + inferenceId + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]"; } - private Page buildOutput(Page inputPage, InferenceAction.Response inferenceResponse) { - if (inferenceResponse.getResults() instanceof RankedDocsResults rankedDocsResults) { - return buildOutput(inputPage, rankedDocsResults); - - } - - throw new IllegalStateException( - "Inference result has wrong type. 
Got [" - + inferenceResponse.getResults().getClass() - + "] while expecting [" - + RankedDocsResults.class - + "]" - ); - } - - private Page buildOutput(Page inputPage, RankedDocsResults rankedDocsResults) { - int blockCount = Integer.max(inputPage.getBlockCount(), scoreChannel + 1); - Block[] blocks = new Block[blockCount]; + private InferenceAction.Request buildInferenceRequest(Page inputPage) { + try (BytesRefBlock encodedRowsBlock = (BytesRefBlock) rowEncoder.eval(inputPage)) { + assert (encodedRowsBlock.getPositionCount() == inputPage.getPositionCount()); + String[] inputs = new String[inputPage.getPositionCount()]; + BytesRef buffer = new BytesRef(); - try { - for (int b = 0; b < blockCount; b++) { - if (b == scoreChannel) { - blocks[b] = buildScoreBlock(inputPage, rankedDocsResults); + for (int pos = 0; pos < inputPage.getPositionCount(); pos++) { + if (encodedRowsBlock.isNull(pos)) { + inputs[pos] = ""; } else { - blocks[b] = inputPage.getBlock(b); - blocks[b].incRef(); + buffer = encodedRowsBlock.getBytesRef(encodedRowsBlock.getFirstValueIndex(pos), buffer); + inputs[pos] = BytesRefs.toString(buffer); } } - return new Page(blocks); - } catch (Exception e) { - Releasables.closeExpectNoException(blocks); - throw (e); + + return InferenceAction.Request.builder(inferenceId, TaskType.RERANK).setInput(List.of(inputs)).setQuery(queryText).build(); } } - private Block buildScoreBlock(Page inputPage, RankedDocsResults rankedDocsResults) { - Double[] sortedRankedDocsScores = new Double[inputPage.getPositionCount()]; + public static final class OngoingRerank { + final Page inputPage; + final Double[] rankedScores; + + OngoingRerank(Page inputPage, InferenceAction.Response resp) { + if (resp.getResults() instanceof RankedDocsResults == false) { + releasePageOnAnyThread(inputPage); + throw new IllegalStateException( + "Inference result has wrong type. 
Got [" + + resp.getResults().getClass() + + "] while expecting [" + + RankedDocsResults.class + + "]" + ); - try (DoubleBlock.Builder scoreBlockFactory = blockFactory.newDoubleBlockBuilder(inputPage.getPositionCount())) { + } + final var results = (RankedDocsResults) resp.getResults(); + this.inputPage = inputPage; + this.rankedScores = extractRankedScores(inputPage.getPositionCount(), results); + } + + private static Double[] extractRankedScores(int positionCount, RankedDocsResults rankedDocsResults) { + Double[] sortedRankedDocsScores = new Double[positionCount]; for (RankedDocsResults.RankedDoc rankedDoc : rankedDocsResults.getRankedDocs()) { sortedRankedDocsScores[rankedDoc.index()] = (double) rankedDoc.relevanceScore(); } + return sortedRankedDocsScores; + } - for (int pos = 0; pos < inputPage.getPositionCount(); pos++) { - if (sortedRankedDocsScores[pos] != null) { - scoreBlockFactory.appendDouble(sortedRankedDocsScores[pos]); - } else { - scoreBlockFactory.appendNull(); + Page buildOutput(BlockFactory blockFactory, int scoreChannel) { + int blockCount = Integer.max(inputPage.getBlockCount(), scoreChannel + 1); + Block[] blocks = new Block[blockCount]; + Page outputPage = null; + try (Releasable ignored = inputPage::releaseBlocks) { + for (int b = 0; b < blockCount; b++) { + if (b == scoreChannel) { + blocks[b] = buildScoreBlock(blockFactory); + } else { + blocks[b] = inputPage.getBlock(b); + blocks[b].incRef(); + } + } + outputPage = new Page(blocks); + return outputPage; + } finally { + if (outputPage == null) { + Releasables.closeExpectNoException(blocks); } } - - return scoreBlockFactory.build(); } - } - - private InferenceAction.Request buildInferenceRequest(Page inputPage) { - try (BytesRefBlock encodedRowsBlock = (BytesRefBlock) rowEncoder.eval(inputPage)) { - assert (encodedRowsBlock.getPositionCount() == inputPage.getPositionCount()); - String[] inputs = new String[inputPage.getPositionCount()]; - BytesRef buffer = new BytesRef(); - for (int pos = 0; 
pos < inputPage.getPositionCount(); pos++) { - if (encodedRowsBlock.isNull(pos)) { - inputs[pos] = ""; - } else { - buffer = encodedRowsBlock.getBytesRef(encodedRowsBlock.getFirstValueIndex(pos), buffer); - inputs[pos] = BytesRefs.toString(buffer); + private Block buildScoreBlock(BlockFactory blockFactory) { + try (DoubleBlock.Builder scoreBlockFactory = blockFactory.newDoubleBlockBuilder(rankedScores.length)) { + for (Double rankedScore : rankedScores) { + if (rankedScore != null) { + scoreBlockFactory.appendDouble(rankedScore); + } else { + scoreBlockFactory.appendNull(); + } } + return scoreBlockFactory.build(); } - - return InferenceAction.Request.builder(inferenceId, TaskType.RERANK).setInput(List.of(inputs)).setQuery(queryText).build(); } } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/RerankOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/RerankOperatorTests.java index c24602cd74d9b..b1335901361b2 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/RerankOperatorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/RerankOperatorTests.java @@ -30,6 +30,7 @@ import org.elasticsearch.compute.test.OperatorTestCase; import org.elasticsearch.compute.test.RandomBlock; import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.threadpool.FixedExecutorBuilder; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; @@ -97,16 +98,23 @@ private InferenceRunner mockedSimpleInferenceRunner() { InferenceRunner inferenceRunner = mock(InferenceRunner.class); when(inferenceRunner.getThreadContext()).thenReturn(threadPool.getThreadContext()); doAnswer(invocation -> { - @SuppressWarnings("unchecked") - ActionListener listener = (ActionListener) invocation.getArgument( - 1, - ActionListener.class - ); - InferenceAction.Response 
inferenceResponse = mock(InferenceAction.Response.class); - when(inferenceResponse.getResults()).thenReturn( - mockedRankedDocResults(invocation.getArgument(0, InferenceAction.Request.class)) - ); - listener.onResponse(inferenceResponse); + Runnable sendResponse = () -> { + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArgument( + 1, + ActionListener.class + ); + InferenceAction.Response inferenceResponse = mock(InferenceAction.Response.class); + when(inferenceResponse.getResults()).thenReturn( + mockedRankedDocResults(invocation.getArgument(0, InferenceAction.Request.class)) + ); + listener.onResponse(inferenceResponse); + }; + if (randomBoolean()) { + sendResponse.run(); + } else { + threadPool.schedule(sendResponse, TimeValue.timeValueNanos(between(1, 1_000)), threadPool.executor(ESQL_TEST_EXECUTOR)); + } return null; }).when(inferenceRunner).doInference(any(), any()); @@ -137,7 +145,8 @@ protected Matcher expectedToStringOfSimple() { @Override protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { - return new AbstractBlockSourceOperator(blockFactory, 8 * 1024) { + final int minPageSize = Math.max(1, size / 100); + return new AbstractBlockSourceOperator(blockFactory, between(minPageSize, size)) { @Override protected int remaining() { return size - currentPosition;