
ES|QL: Add FORK generative tests #129135

Merged
merged 3 commits on Jun 10, 2025

@@ -0,0 +1,50 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.qa.single_node;

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters;

import org.elasticsearch.test.TestClustersThreadFilter;
import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.xpack.esql.CsvSpecReader;
import org.elasticsearch.xpack.esql.qa.rest.generative.GenerativeForkRestTest;
import org.junit.ClassRule;

@ThreadLeakFilters(filters = TestClustersThreadFilter.class)
public class GenerativeForkIT extends GenerativeForkRestTest {
@ClassRule
public static ElasticsearchCluster cluster = Clusters.testCluster(spec -> spec.plugin("inference-service-test"));

@Override
protected String getTestRestCluster() {
return cluster.getHttpAddresses();
}

public GenerativeForkIT(
String fileName,
String groupName,
String testName,
Integer lineNumber,
CsvSpecReader.CsvTestCase testCase,
String instructions,
Mode mode
) {
super(fileName, groupName, testName, lineNumber, testCase, instructions, mode);
}

@Override
protected boolean enableRoundingDoubleValuesOnAsserting() {
// This suite runs with more than one node and three shards in serverless
return cluster.getNumNodes() > 1;
}

@Override
protected boolean supportsSourceFieldMapping() {
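// Source field mapping assertions are only exercised when this suite runs against a single node;
// multi-node runs skip them.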
return cluster.getNumNodes() == 1;
}
}
@@ -259,15 +259,19 @@ protected boolean supportsSourceFieldMapping() throws IOException {
return true;
}

protected final void doTest() throws Throwable {
protected void doTest() throws Throwable {
doTest(testCase.query);
}

protected final void doTest(String query) throws Throwable {
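// `query` may have been rewritten by a subclass (e.g. GenerativeForkRestTest appends a FORK suffix),
// but the results are still checked against the original CSV test case's expected output.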
RequestObjectBuilder builder = new RequestObjectBuilder(randomFrom(XContentType.values()));

if (testCase.query.toUpperCase(Locale.ROOT).contains("LOOKUP_\uD83D\uDC14")) {
if (query.toUpperCase(Locale.ROOT).contains("LOOKUP_\uD83D\uDC14")) {
builder.tables(tables());
}

Map<?, ?> prevTooks = supportsTook() ? tooks() : null;
Map<String, Object> answer = runEsql(builder.query(testCase.query), testCase.assertWarnings(deduplicateExactWarnings()));
Map<String, Object> answer = runEsql(builder.query(query), testCase.assertWarnings(deduplicateExactWarnings()));

var expectedColumnsWithValues = loadCsvSpecValues(testCase.expectedResults);

@@ -0,0 +1,65 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.qa.rest.generative;

import org.elasticsearch.xpack.esql.CsvSpecReader;
import org.elasticsearch.xpack.esql.qa.rest.EsqlSpecTestCase;

import java.io.IOException;
import java.util.List;

import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.*;

/**
* Tests for FORK. We generate tests for FORK from existing CSV tests.
* We append a `| FORK (WHERE true) (WHERE true) | WHERE _fork == "fork1" | DROP _fork` suffix to existing
* CSV test cases. This will produce a query that executes multiple FORK branches but expects the same results
* as the initial CSV test case.
* For now, we skip tests that already require FORK, since multiple FORK commands are not allowed.
*/
public abstract class GenerativeForkRestTest extends EsqlSpecTestCase {
Contributor review comment: It might be obvious from the query, but maybe add a short description javadoc that summarises the intention.

public GenerativeForkRestTest(
String fileName,
String groupName,
String testName,
Integer lineNumber,
CsvSpecReader.CsvTestCase testCase,
String instructions,
Mode mode
) {
super(fileName, groupName, testName, lineNumber, testCase, instructions, mode);
}

@Override
protected void doTest() throws Throwable {
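// For a hypothetical spec query such as `FROM employees | SORT emp_no | LIMIT 5`, this produces
// `FROM employees | SORT emp_no | LIMIT 5 | FORK (WHERE true) (WHERE true) | WHERE _fork == "fork1" | DROP _fork`:
// both branches match every row, only the fork1 rows are kept and _fork is dropped, so the results
// must match those of the original CSV test case.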
String query = testCase.query + " | FORK (WHERE true) (WHERE true) | WHERE _fork == \"fork1\" | DROP _fork";
Contributor review comment: 👍

doTest(query);
}

@Override
protected void shouldSkipTest(String testName) throws IOException {
super.shouldSkipTest(testName);

assumeFalse(
"Tests using FORK or RRF already are skipped since we don't support multiple FORKs",
testCase.requiredCapabilities.contains(FORK_V7.capabilityName()) || testCase.requiredCapabilities.contains(RRF.capabilityName())
);

assumeFalse(
"Tests using INSIST are not supported for now",
testCase.requiredCapabilities.contains(UNMAPPED_FIELDS.capabilityName())
);

assumeFalse(
"Tests using implicit_casting_date_and_date_nanos are not supported for now",
testCase.requiredCapabilities.contains(IMPLICIT_CASTING_DATE_AND_DATE_NANOS.capabilityName())
);

assumeTrue("Cluster needs to support FORK", hasCapabilities(client(), List.of(FORK_V7.capabilityName())));
}
}
@@ -11,6 +11,7 @@
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.compute.operator.DriverProfile;
import org.elasticsearch.test.junit.annotations.TestLogging;
import org.elasticsearch.xpack.esql.VerificationException;
import org.elasticsearch.xpack.esql.parser.ParsingException;
import org.junit.Before;
@@ -26,13 +27,13 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.getValuesList;
import static org.hamcrest.Matchers.equalTo;

// @TestLogging(value = "org.elasticsearch.xpack.esql:TRACE,org.elasticsearch.compute:TRACE", reason = "debug")
@TestLogging(value = "org.elasticsearch.xpack.esql:TRACE,org.elasticsearch.compute:TRACE", reason = "debug")
public class ForkIT extends AbstractEsqlIntegTestCase {

@Before
public void setupIndex() {
assumeTrue("requires FORK capability", EsqlCapabilities.Cap.FORK.isEnabled());
createAndPopulateIndex();
createAndPopulateIndices();
}

public void testSimple() {
@@ -706,6 +707,52 @@ public void testWithLookUpAfterFork() {
}
}

public void testWithUnionTypesBeforeFork() {
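// `id` has conflicting mapped types across "test" and "test-other", so the ::keyword conversions
// below resolve the union-typed fields before the FORK command runs.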
var query = """
FROM test,test-other
| EVAL x = id::keyword
| EVAL id = id::keyword
| EVAL content = content::keyword
| FORK (WHERE x == "2")
(WHERE x == "1")
| SORT _fork, x, content
| KEEP content, id, x, _fork
""";

try (var resp = run(query)) {
assertColumnNames(resp.columns(), List.of("content", "id", "x", "_fork"));
Iterable<Iterable<Object>> expectedValues = List.of(
List.of("This is a brown dog", "2", "2", "fork1"),
List.of("This is a brown dog", "2", "2", "fork1"),
List.of("This is a brown fox", "1", "1", "fork2"),
List.of("This is a brown fox", "1", "1", "fork2")
);
assertValues(resp.values(), expectedValues);
}
}

public void testWithUnionTypesInBranches() {
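// Same union-typed `id` field, but here the type conversions happen inside each FORK branch
// rather than before the FORK command.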
var query = """
FROM test,test-other
| EVAL content = content::keyword
| FORK (EVAL x = id::keyword | WHERE x == "2" | EVAL id = x::integer)
(EVAL x = "a" | WHERE id::keyword == "1" | EVAL id = id::integer)
| SORT _fork, x
| KEEP content, id, x, _fork
""";

try (var resp = run(query)) {
assertColumnNames(resp.columns(), List.of("content", "id", "x", "_fork"));
Iterable<Iterable<Object>> expectedValues = List.of(
List.of("This is a brown dog", 2, "2", "fork1"),
List.of("This is a brown dog", 2, "2", "fork1"),
List.of("This is a brown fox", 1, "a", "fork2"),
List.of("This is a brown fox", 1, "a", "fork2")
);
assertValues(resp.values(), expectedValues);
}
}

public void testWithEvalWithConflictingTypes() {
var query = """
FROM test
@@ -833,7 +880,7 @@ public void testProfile() {
}
}

private void createAndPopulateIndex() {
private void createAndPopulateIndices() {
var indexName = "test";
var client = client().admin().indices();
var createRequest = client.prepareCreate(indexName)
@@ -867,6 +914,20 @@ private void createAndPopulateIndex() {
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.get();
ensureYellow(lookupIndex);

var otherTestIndex = "test-other";
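// A second index that maps `id` as a keyword, conflicting with its mapping in "test", so the
// new union-type FORK tests exercise fields with multiple mapped types.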

createRequest = client.prepareCreate(otherTestIndex)
.setSettings(Settings.builder().put("index.number_of_shards", 1))
.setMapping("id", "type=keyword", "content", "type=keyword");
assertAcked(createRequest);
client().prepareBulk()
.add(new IndexRequest(otherTestIndex).id("1").source("id", "1", "content", "This is a brown fox"))
.add(new IndexRequest(otherTestIndex).id("2").source("id", "2", "content", "This is a brown dog"))
.add(new IndexRequest(otherTestIndex).id("3").source("id", "3", "content", "This dog is really brown"))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.get();
ensureYellow(indexName);
}

static Iterator<Iterator<Object>> valuesFilter(Iterator<Iterator<Object>> values, Predicate<Iterator<Object>> filter) {
@@ -1615,20 +1615,21 @@ record TypeResolutionKey(String fieldName, DataType fieldType) {}
@Override
public LogicalPlan apply(LogicalPlan plan) {
unionFieldAttributes = new ArrayList<>();
return plan.transformUp(LogicalPlan.class, p -> p.childrenResolved() == false ? p : doRule(p));
}

private LogicalPlan doRule(LogicalPlan plan) {
Holder<Integer> alreadyAddedUnionFieldAttributes = new Holder<>(unionFieldAttributes.size());
// Collect field attributes from previous runs
plan.forEachUp(EsRelation.class, rel -> {
if (plan instanceof EsRelation rel) {
unionFieldAttributes.clear();
for (Attribute attr : rel.output()) {
if (attr instanceof FieldAttribute fa && fa.field() instanceof MultiTypeEsField && fa.synthetic()) {
unionFieldAttributes.add(fa);
}
}
});

return plan.transformUp(LogicalPlan.class, p -> p.childrenResolved() == false ? p : doRule(p));
}
}

private LogicalPlan doRule(LogicalPlan plan) {
int alreadyAddedUnionFieldAttributes = unionFieldAttributes.size();
// See if the eval function has an unresolved MultiTypeEsField field
// Replace the entire convert function with a new FieldAttribute (containing type conversion knowledge)
plan = plan.transformExpressionsOnly(e -> {
@@ -1637,8 +1638,9 @@ private LogicalPlan doRule(LogicalPlan plan) {
}
return e;
});

// If no union fields were generated, return the plan as is
if (unionFieldAttributes.size() == alreadyAddedUnionFieldAttributes) {
if (unionFieldAttributes.size() == alreadyAddedUnionFieldAttributes.get()) {
return plan;
}
