Skip to content

Commit f38404b

Browse files
authored
[KQP RBO] Add expression on aggregation. (#29129)
1 parent b76d011 commit f38404b

File tree

2 files changed

+61
-22
lines changed

2 files changed

+61
-22
lines changed

ydb/core/kqp/opt/rbo/kqp_rbo_transformer.cpp

Lines changed: 54 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -321,8 +321,7 @@ TExprNode::TPtr ReplacePgOps(TExprNode::TPtr input, TExprContext &ctx) {
321321
.Seal()
322322
.Build();
323323
// clnag-format on
324-
}
325-
else if (input->IsCallable()){
324+
} else if (input->IsCallable()){
326325
TVector<TExprNode::TPtr> newChildren;
327326
for (auto c : input->Children()) {
328327
newChildren.push_back(ReplacePgOps(c, ctx));
@@ -635,38 +634,50 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
635634
auto pgAgg = GetPgCallable(lambda.Body().Ptr(), "PgAgg");
636635
if (pgAgg) {
637636
// Collect original column names processing `PgAgg` callable.
638-
TVector<TInfoUnit> originalColNames;
639-
GetAllMembers(pgAgg, originalColNames);
640-
auto pgResolvedOp = GetPgCallable(lambda.Body().Ptr(), "PgResolvedOp");
641-
642-
auto originalColName = originalColNames.front();
643-
auto renamedColName = originalColName;
637+
TInfoUnit originalColName;
638+
TInfoUnit renamedColName;
644639

645-
if (pgResolvedOp) {
640+
auto lambdaBody = lambda.Body().Ptr();
641+
auto pgResolvedOp = GetPgCallable(lambda.Body().Ptr(), "PgResolvedOp");
642+
if (pgResolvedOp && !lambdaBody->IsCallable("PgResolvedOp")) {
643+
// Aggregation on expression f(a x b).
644+
// We pull expression outside a given aggregation and rename result of a given expression with unique name
645+
// to later process result with aggregate function.
646+
// For example: (a x b) as uqique_result_name -> f(unique_result_name)
646647
auto fromPg = ctx.NewCallable(node->Pos(), "FromPg", {pgResolvedOp});
648+
647649
// clang-format off
648650
auto exprLambda = Build<TCoLambda>(ctx, node->Pos())
649651
.Args(lambda.Args())
650652
.Body(fromPg)
651653
.Done().Ptr();
652654
// clang-format on
653655

654-
// Just any unique name for expression result, physical plan should be AsSturct(`unique_name (expression))
656+
// Just any unique name for expression result, physical plan should be AsSturct(`unique_name (expression)).
655657
originalColName = TInfoUnit(GenerateUniqueColumnName("_expr_"));
656658
renamedColName = originalColName;
657659
aggFieldsExpressionsMap.push_back({originalColName, exprLambda});
658660
} else {
659-
// Rename agg column we will add a map to map same column to different renames.
661+
// Either an aggregation f(a) or expression on aggregation f(a) x b.
662+
// Here we want to get just a column name for aggregation.
663+
Y_ENSURE(pgAgg->ChildrenSize() == 3, "Invalid children size for `PgAgg`");
664+
auto toPg = pgAgg->ChildPtr(2);
665+
Y_ENSURE(toPg->IsCallable("ToPg") && toPg->ChildPtr(0)->IsCallable("Member"), "PgAgg not a member");
666+
auto member = TCoMember(toPg->ChildPtr(0));
667+
originalColName = TInfoUnit(member.Name().StringValue());
668+
renamedColName = originalColName;
669+
670+
// Aggregation columns should be unique, so we have to add rename map.
671+
// For example f(a), g(a) => map((a -> a), (a -> a_0)) -> f(a), g(a_0).
660672
if (uniqueColumnNames.count(originalColName.GetFullName())) {
661673
renamedColName = TInfoUnit(originalColName.Alias, GenerateUniqueColumnName(originalColName.ColumnName));
662674
needRenameMap = true;
663675
}
664676
aggFieldsRenamesMap.push_back({originalColName, renamedColName});
665677
}
666678
uniqueColumnNames.insert(renamedColName.GetFullName());
667-
//Y_ENSURE(!GetAtom(pgAgg->ChildPtr(1), "distinct"));
668679

669-
// Distinct for column or expression.
680+
// Distinct for column or expression f(distinct a) => distinct a as result -> f(result).
670681
if (!!GetAtom(pgAgg->ChildPtr(1), "distinct")) {
671682
const auto colName = renamedColName.GetFullName();
672683
auto distinctAggTraits =
@@ -676,12 +687,15 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
676687
distinctPreAggregate = true;
677688
}
678689

690+
// Aggregation on pg columns requires type cast from yql type to pg type.
679691
aggregationColumnsRequireCastToPgType.insert(resultColName);
680692
const TString aggFuncName = TString(pgAgg->ChildPtr(0)->Content());
693+
// Build an aggregation traits.
681694
auto aggregationTraits = BuildAggregationTraits(renamedColName.GetFullName(), resultColName, aggFuncName,
682695
aggFuncResultType, ctx, node->Pos());
683696
aggTraits.AggTraitsList.push_back(aggregationTraits);
684697

698+
// Case for distinct after aggregation.
685699
if (distinctAll) {
686700
auto distinctAggTraits =
687701
BuildAggregationTraits(resultColName, resultColName, "distinct", aggFuncResultType, ctx, node->Pos());
@@ -793,16 +807,33 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
793807
Y_ENSURE(SupportedAggregationFunctions.count(pgAgg->ChildPtr(0)->Content()),
794808
"Aggregation function " + TString(pgAgg->ChildPtr(0)->Content()) + " is not supported ");
795809

810+
// clang-format off
811+
auto newBody = Build<TCoMember>(ctx, node->Pos())
812+
.Struct(lambda.Args().Arg(0))
813+
.Name<TCoAtom>()
814+
.Value(columnName)
815+
.Build()
816+
.Done().Ptr();
817+
// clang-format on
818+
819+
auto lambdaBody = lambda.Body().Ptr();
796820
// Build a projection lambda, we do not need `PgAgg` inside.
821+
if (lambdaBody->IsCallable("PgResolvedOp")) {
822+
// Replace PgResolvedOp(PgAgg(arg)) -> PgResolvedOp(PgCast(ToPg(Member(arg, columnName))))
823+
auto toPg = ctx.NewCallable(node->Pos(), "ToPg", {newBody});
824+
auto pgType =
825+
ctx.NewCallable(node->Pos(), "PgType", {ctx.NewAtom(node->Pos(), NPg::LookupType(expectedType->GetId()).Name)});
826+
auto pgCast = ctx.NewCallable(node->Pos(), "PgCast", {toPg, pgType});
827+
828+
TNodeOnNodeOwnedMap replaces;
829+
replaces[pgAgg.Get()] = pgCast;
830+
newBody = ctx.ReplaceNodes(lambda.Body().Ptr(), replaces);
831+
}
832+
797833
// clang-format off
798834
lambda = Build<TCoLambda>(ctx, node->Pos())
799835
.Args(lambda.Args())
800-
.Body<TCoMember>()
801-
.Struct(lambda.Args().Arg(0))
802-
.Name<TCoAtom>()
803-
.Value(columnName)
804-
.Build()
805-
.Build()
836+
.Body(newBody)
806837
.Done();
807838
// clang-format on
808839
}
@@ -853,7 +884,7 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
853884
.Body(toPg)
854885
.Done();
855886
// clang-format on
856-
} else if (needPgCast || needPgCastForAgg) {
887+
} else if ((needPgCast || needPgCastForAgg)) {
857888

858889
auto pgType =
859890
ctx.NewCallable(node->Pos(), "PgType", {ctx.NewAtom(node->Pos(), NPg::LookupType(expectedType->GetId()).Name)});
@@ -961,6 +992,7 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
961992
.ColumnOrder(columnOrder)
962993
.Done().Ptr();
963994
// clang-format on
995+
964996
}
965997

966998
TExprNode::TPtr PushTakeIntoPlan(const TExprNode::TPtr &node, TExprContext &ctx, const TTypeAnnotationContext &typeCtx) {
@@ -1081,7 +1113,8 @@ IGraphTransformer::TStatus TKqpRBOCleanupTransformer::DoTransform(TExprNode::TPt
10811113

10821114
Y_UNUSED(ctx);
10831115

1084-
//YQL_CLOG(TRACE, CoreDq) << "Cleanup input plan: " << output->Dump();
1116+
YQL_CLOG(TRACE, CoreDq) << "Cleanup input plan: " << KqpExprToPrettyString(TExprBase(output), ctx) << Endl;
1117+
10851118

10861119
if (output->IsList() && output->ChildrenSize() >= 1) {
10871120
auto child_level_1 = output->Child(0);

ydb/core/kqp/ut/rbo/kqp_rbo_ut.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,12 @@ Y_UNIT_TEST_SUITE(KqpRbo) {
628628
--!syntax_pg
629629
SET TablePathPrefix = "/Root/";
630630
select sum(distinct t1.b) as sum, t1.a from t1 group by t1.a order by sum;
631-
)"
631+
)",
632+
R"(
633+
--!syntax_pg
634+
SET TablePathPrefix = "/Root/";
635+
select sum(t1.a) + 1, t1.b from t1 group by t1.b order by t1.b;
636+
)",
632637
};
633638

634639
std::vector<std::string> results = {
@@ -650,6 +655,7 @@ Y_UNIT_TEST_SUITE(KqpRbo) {
650655
R"([["4";"2";"4"];["6";"3";"4"]])",
651656
R"([["4"];["8"];["8"]])",
652657
R"([["1";"1"];["1";"3"];["2";"0"];["2";"2"];["2";"4"]])",
658+
R"([["5";"1"];["7";"2"]])"
653659
};
654660

655661
for (ui32 i = 0; i < queries.size(); ++i) {

0 commit comments

Comments
 (0)