diff --git a/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/GenerativeForkIT.java b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/GenerativeForkIT.java new file mode 100644 index 0000000000000..d95cd0aecda0c --- /dev/null +++ b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/GenerativeForkIT.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.qa.single_node; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters; + +import org.elasticsearch.test.TestClustersThreadFilter; +import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.xpack.esql.CsvSpecReader; +import org.elasticsearch.xpack.esql.qa.rest.generative.GenerativeForkRestTest; +import org.junit.ClassRule; + +@ThreadLeakFilters(filters = TestClustersThreadFilter.class) +public class GenerativeForkIT extends GenerativeForkRestTest { + @ClassRule + public static ElasticsearchCluster cluster = Clusters.testCluster(spec -> spec.plugin("inference-service-test")); + + @Override + protected String getTestRestCluster() { + return cluster.getHttpAddresses(); + } + + public GenerativeForkIT( + String fileName, + String groupName, + String testName, + Integer lineNumber, + CsvSpecReader.CsvTestCase testCase, + String instructions, + Mode mode + ) { + super(fileName, groupName, testName, lineNumber, testCase, instructions, mode); + } + + @Override + protected boolean enableRoundingDoubleValuesOnAsserting() { + // This suite runs with more than one node and three shards in serverless + return cluster.getNumNodes() > 1; + } + + @Override + protected boolean supportsSourceFieldMapping() { + return cluster.getNumNodes() == 1; + } +} diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java index 69df40899a0a8..e4c8b67d4eb72 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java @@ -259,15 +259,19 @@ protected boolean supportsSourceFieldMapping() throws IOException { return true; } - protected final void doTest() throws Throwable { + protected void doTest() throws Throwable { + doTest(testCase.query); + } + + protected final void doTest(String query) throws Throwable { RequestObjectBuilder builder = new RequestObjectBuilder(randomFrom(XContentType.values())); - if (testCase.query.toUpperCase(Locale.ROOT).contains("LOOKUP_\uD83D\uDC14")) { + if (query.toUpperCase(Locale.ROOT).contains("LOOKUP_\uD83D\uDC14")) { builder.tables(tables()); } Map prevTooks = supportsTook() ? tooks() : null; - Map answer = runEsql(builder.query(testCase.query), testCase.assertWarnings(deduplicateExactWarnings())); + Map answer = runEsql(builder.query(query), testCase.assertWarnings(deduplicateExactWarnings())); var expectedColumnsWithValues = loadCsvSpecValues(testCase.expectedResults); diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/GenerativeForkRestTest.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/GenerativeForkRestTest.java new file mode 100644 index 0000000000000..9cfbc7e69b6c5 --- /dev/null +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/GenerativeForkRestTest.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.qa.rest.generative; + +import org.elasticsearch.xpack.esql.CsvSpecReader; +import org.elasticsearch.xpack.esql.qa.rest.EsqlSpecTestCase; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.*; + +/** + * Tests for FORK. We generate tests for FORK from existing CSV tests. + * We append a `| FORK (WHERE true) (WHERE true) | WHERE _fork == "fork1" | DROP _fork` suffix to existing + * CSV test cases. This will produce a query that executes multiple FORK branches but expects the same results + * as the initial CSV test case. + * For now, we skip tests that already require FORK, since multiple FORK commands are not allowed. + */ +public abstract class GenerativeForkRestTest extends EsqlSpecTestCase { + public GenerativeForkRestTest( + String fileName, + String groupName, + String testName, + Integer lineNumber, + CsvSpecReader.CsvTestCase testCase, + String instructions, + Mode mode + ) { + super(fileName, groupName, testName, lineNumber, testCase, instructions, mode); + } + + @Override + protected void doTest() throws Throwable { + String query = testCase.query + " | FORK (WHERE true) (WHERE true) | WHERE _fork == \"fork1\" | DROP _fork"; + doTest(query); + } + + @Override + protected void shouldSkipTest(String testName) throws IOException { + super.shouldSkipTest(testName); + + assumeFalse( + "Tests using FORK or RRF already are skipped since we don't support multiple FORKs", + testCase.requiredCapabilities.contains(FORK_V7.capabilityName()) || testCase.requiredCapabilities.contains(RRF.capabilityName()) + ); + + assumeFalse( + "Tests using INSIST are not supported for now", + testCase.requiredCapabilities.contains(UNMAPPED_FIELDS.capabilityName()) + ); + + assumeFalse( + "Tests using implicit_casting_date_and_date_nanos are not supported for now", + testCase.requiredCapabilities.contains(IMPLICIT_CASTING_DATE_AND_DATE_NANOS.capabilityName()) + ); + + assumeTrue("Cluster needs to support FORK", hasCapabilities(client(), List.of(FORK_V7.capabilityName()))); + } +} diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/ForkIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/ForkIT.java index 86051e7e4164d..4860740e7babc 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/ForkIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/ForkIT.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.compute.operator.DriverProfile; +import org.elasticsearch.test.junit.annotations.TestLogging; import org.elasticsearch.xpack.esql.VerificationException; import org.elasticsearch.xpack.esql.parser.ParsingException; import org.junit.Before; @@ -26,13 +27,13 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.getValuesList; import static org.hamcrest.Matchers.equalTo; -// @TestLogging(value = "org.elasticsearch.xpack.esql:TRACE,org.elasticsearch.compute:TRACE", reason = "debug") +@TestLogging(value = "org.elasticsearch.xpack.esql:TRACE,org.elasticsearch.compute:TRACE", reason = "debug") public class ForkIT extends AbstractEsqlIntegTestCase { @Before public void setupIndex() { assumeTrue("requires FORK capability", EsqlCapabilities.Cap.FORK.isEnabled()); - createAndPopulateIndex(); + createAndPopulateIndices(); } public void testSimple() { @@ -706,6 +707,52 @@ public void testWithLookUpAfterFork() { } } + public void testWithUnionTypesBeforeFork() { + var query = """ + FROM test,test-other + | EVAL x = id::keyword + | EVAL id = id::keyword + | EVAL content = content::keyword + | FORK (WHERE x == "2") + (WHERE x == "1") + | SORT _fork, x, content + | KEEP content, id, x, _fork + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("content", "id", "x", "_fork")); + Iterable> expectedValues = List.of( + List.of("This is a brown dog", "2", "2", "fork1"), + List.of("This is a brown dog", "2", "2", "fork1"), + List.of("This is a brown fox", "1", "1", "fork2"), + List.of("This is a brown fox", "1", "1", "fork2") + ); + assertValues(resp.values(), expectedValues); + } + } + + public void testWithUnionTypesInBranches() { + var query = """ + FROM test,test-other + | EVAL content = content::keyword + | FORK (EVAL x = id::keyword | WHERE x == "2" | EVAL id = x::integer) + (EVAL x = "a" | WHERE id::keyword == "1" | EVAL id = id::integer) + | SORT _fork, x + | KEEP content, id, x, _fork + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("content", "id", "x", "_fork")); + Iterable> expectedValues = List.of( + List.of("This is a brown dog", 2, "2", "fork1"), + List.of("This is a brown dog", 2, "2", "fork1"), + List.of("This is a brown fox", 1, "a", "fork2"), + List.of("This is a brown fox", 1, "a", "fork2") + ); + assertValues(resp.values(), expectedValues); + } + } + public void testWithEvalWithConflictingTypes() { var query = """ FROM test @@ -833,7 +880,7 @@ public void testProfile() { } } - private void createAndPopulateIndex() { + private void createAndPopulateIndices() { var indexName = "test"; var client = client().admin().indices(); var createRequest = client.prepareCreate(indexName) @@ -867,6 +914,20 @@ private void createAndPopulateIndex() { .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .get(); ensureYellow(lookupIndex); + + var otherTestIndex = "test-other"; + + createRequest = client.prepareCreate(otherTestIndex) + .setSettings(Settings.builder().put("index.number_of_shards", 1)) + .setMapping("id", "type=keyword", "content", "type=keyword"); + assertAcked(createRequest); + client().prepareBulk() + .add(new IndexRequest(otherTestIndex).id("1").source("id", "1", "content", "This is a brown fox")) + .add(new IndexRequest(otherTestIndex).id("2").source("id", "2", "content", "This is a brown dog")) + .add(new IndexRequest(otherTestIndex).id("3").source("id", "3", "content", "This dog is really brown")) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .get(); + ensureYellow(indexName); } static Iterator> valuesFilter(Iterator> values, Predicate> filter) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index 259062cd14f57..200bb5b182588 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -1615,20 +1615,21 @@ record TypeResolutionKey(String fieldName, DataType fieldType) {} @Override public LogicalPlan apply(LogicalPlan plan) { unionFieldAttributes = new ArrayList<>(); + return plan.transformUp(LogicalPlan.class, p -> p.childrenResolved() == false ? p : doRule(p)); + } + + private LogicalPlan doRule(LogicalPlan plan) { + Holder alreadyAddedUnionFieldAttributes = new Holder<>(unionFieldAttributes.size()); // Collect field attributes from previous runs - plan.forEachUp(EsRelation.class, rel -> { + if (plan instanceof EsRelation rel) { + unionFieldAttributes.clear(); for (Attribute attr : rel.output()) { if (attr instanceof FieldAttribute fa && fa.field() instanceof MultiTypeEsField && fa.synthetic()) { unionFieldAttributes.add(fa); } } - }); - - return plan.transformUp(LogicalPlan.class, p -> p.childrenResolved() == false ? p : doRule(p)); - } + } - private LogicalPlan doRule(LogicalPlan plan) { - int alreadyAddedUnionFieldAttributes = unionFieldAttributes.size(); // See if the eval function has an unresolved MultiTypeEsField field // Replace the entire convert function with a new FieldAttribute (containing type conversion knowledge) plan = plan.transformExpressionsOnly(e -> { @@ -1637,8 +1638,9 @@ private LogicalPlan doRule(LogicalPlan plan) { } return e; }); + // If no union fields were generated, return the plan as is - if (unionFieldAttributes.size() == alreadyAddedUnionFieldAttributes) { + if (unionFieldAttributes.size() == alreadyAddedUnionFieldAttributes.get()) { return plan; }