Commit 0c73c32
Mikhail Zolotukhin
[RFC] Add LoopNest class that implements Schedule's API in a different way.
LoopNest is my attempt to simplify our core abstraction. The main idea
behind this change is to merge two classes: `TensorExprNode` and `For`
(derived from `Stmt`). Currently they represent basically the same
thing, but in a slightly different way. `TensorExprNode` attaches some
metadata and provides a different way for traversing through
siblings/parents/children. `For` represents the same structure, but
without any metadata. Once a kernel is lowered to `For` statements, they
are immediately consumed by a codegen, which either lowers them to LLVM IR
or prints them as a CUDA string.
This PR adds some functionality to `For` statements (and to other types
of statements as well) and implements `SplitWithTail` and
`ComputeInline` using only those. The implementation is just a proof of
concept: it doesn't cover all corner cases, but they are trivial to add.
As a demo, I added a test where we create a simple tensor expression,
split one of its axes, and then lower it to a Stmt. The demo shows
that we're producing exactly the same result as the existing implementation.
For reference, below is the output of the test ("Root stmt" is produced
by the new implementation; "Ref stmt" is the product of the existing one):
```
[ RUN ] TensorExprTest.LoopNest_LLVM
Root stmt:
for (int n = 0; n < N; n++) {
for (int i = 0; i < 1024; i++) {
for (int j_outer = 0; j_outer < ((256 - 0) / 17); j_outer++) {
for (int j_inner = 0; j_inner < 17; j_inner++) {
g[(((n * (1024 * 256)) + (i * 256)) + (((j_outer * 17) + j_inner) * 1))] = (((A[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + ((j_outer * 17) + j_inner))] + B[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + ((j_outer * 17) + j_inner))]) + C[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + ((j_outer * 17) + j_inner))]) + D[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + ((j_outer * 17) + j_inner))]);
}
}
for (int j_tail = 0; j_tail < ((256 - 0) % 17); j_tail++) {
g[(((n * (1024 * 256)) + (i * 256)) + ((j_tail + (((256 - 0) / 17) * 17)) * 1))] = (((A[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + (j_tail + (((256 - 0) / 17) * 17)))] + B[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + (j_tail + (((256 - 0) / 17) * 17)))]) + C[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + (j_tail + (((256 - 0) / 17) * 17)))]) + D[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + (j_tail + (((256 - 0) / 17) * 17)))]);
}
}
}
Ref stmt:
for (int n = 0; n < N; n++) {
for (int i = 0; i < 1024; i++) {
for (int j_outer = 0; j_outer < ((256 - 0) / 17); j_outer++) {
for (int j_inner = 0; j_inner < 17; j_inner++) {
g[(((n * (1024 * 256)) + (i * 256)) + (((j_outer * 17) + j_inner) * 1))] = (((A[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + ((j_outer * 17) + j_inner))] + B[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + ((j_outer * 17) + j_inner))]) + C[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + ((j_outer * 17) + j_inner))]) + D[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + ((j_outer * 17) + j_inner))]);
}
}
for (int j_tail = 0; j_tail < ((256 - 0) % 17); j_tail++) {
g[(((n * (1024 * 256)) + (i * 256)) + ((j_tail + (((256 - 0) / 17) * 17)) * 1))] = (((A[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + (j_tail + (((256 - 0) / 17) * 17)))] + B[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + (j_tail + (((256 - 0) / 17) * 17)))]) + C[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + (j_tail + (((256 - 0) / 17) * 17)))]) + D[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + (j_tail + (((256 - 0) / 17) * 17)))]);
}
}
}
[ OK ] TensorExprTest.LoopNest_LLVM (3 ms)
```
1 parent af20070 commit 0c73c32
File tree
5 files changed
+223
-4
lines changed
- test/cpp/tensorexpr
- torch/csrc/jit/tensorexpr
5 files changed
+223
-4
lines changed
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
545 | 545 | | |
546 | 546 | | |
547 | 547 | | |
| 548 | + | |
| 549 | + | |
| 550 | + | |
| 551 | + | |
| 552 | + | |
| 553 | + | |
| 554 | + | |
| 555 | + | |
| 556 | + | |
| 557 | + | |
| 558 | + | |
| 559 | + | |
| 560 | + | |
| 561 | + | |
| 562 | + | |
| 563 | + | |
| 564 | + | |
| 565 | + | |
| 566 | + | |
| 567 | + | |
| 568 | + | |
| 569 | + | |
| 570 | + | |
| 571 | + | |
| 572 | + | |
| 573 | + | |
| 574 | + | |
| 575 | + | |
| 576 | + | |
| 577 | + | |
| 578 | + | |
| 579 | + | |
| 580 | + | |
| 581 | + | |
| 582 | + | |
| 583 | + | |
| 584 | + | |
| 585 | + | |
| 586 | + | |
| 587 | + | |
| 588 | + | |
| 589 | + | |
| 590 | + | |
| 591 | + | |
| 592 | + | |
| 593 | + | |
| 594 | + | |
| 595 | + | |
| 596 | + | |
| 597 | + | |
| 598 | + | |
| 599 | + | |
| 600 | + | |
| 601 | + | |
| 602 | + | |
| 603 | + | |
| 604 | + | |
| 605 | + | |
| 606 | + | |
| 607 | + | |
| 608 | + | |
| 609 | + | |
548 | 610 | | |
549 | 611 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
160 | 160 | | |
161 | 161 | | |
162 | 162 | | |
163 | | - | |
| 163 | + | |
| 164 | + | |
164 | 165 | | |
165 | 166 | | |
166 | 167 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
873 | 873 | | |
874 | 874 | | |
875 | 875 | | |
| 876 | + | |
| 877 | + | |
| 878 | + | |
| 879 | + | |
| 880 | + | |
| 881 | + | |
| 882 | + | |
| 883 | + | |
| 884 | + | |
| 885 | + | |
| 886 | + | |
| 887 | + | |
| 888 | + | |
| 889 | + | |
| 890 | + | |
| 891 | + | |
| 892 | + | |
| 893 | + | |
| 894 | + | |
| 895 | + | |
| 896 | + | |
| 897 | + | |
| 898 | + | |
| 899 | + | |
| 900 | + | |
| 901 | + | |
| 902 | + | |
| 903 | + | |
| 904 | + | |
| 905 | + | |
| 906 | + | |
| 907 | + | |
| 908 | + | |
| 909 | + | |
| 910 | + | |
| 911 | + | |
| 912 | + | |
| 913 | + | |
| 914 | + | |
| 915 | + | |
| 916 | + | |
| 917 | + | |
| 918 | + | |
| 919 | + | |
| 920 | + | |
| 921 | + | |
| 922 | + | |
| 923 | + | |
| 924 | + | |
| 925 | + | |
| 926 | + | |
| 927 | + | |
| 928 | + | |
| 929 | + | |
| 930 | + | |
| 931 | + | |
| 932 | + | |
| 933 | + | |
| 934 | + | |
| 935 | + | |
| 936 | + | |
| 937 | + | |
| 938 | + | |
| 939 | + | |
| 940 | + | |
| 941 | + | |
| 942 | + | |
| 943 | + | |
| 944 | + | |
| 945 | + | |
| 946 | + | |
| 947 | + | |
| 948 | + | |
| 949 | + | |
| 950 | + | |
| 951 | + | |
| 952 | + | |
| 953 | + | |
| 954 | + | |
| 955 | + | |
| 956 | + | |
| 957 | + | |
| 958 | + | |
| 959 | + | |
| 960 | + | |
| 961 | + | |
| 962 | + | |
| 963 | + | |
| 964 | + | |
| 965 | + | |
| 966 | + | |
| 967 | + | |
| 968 | + | |
| 969 | + | |
| 970 | + | |
| 971 | + | |
| 972 | + | |
| 973 | + | |
| 974 | + | |
| 975 | + | |
| 976 | + | |
| 977 | + | |
| 978 | + | |
876 | 979 | | |
877 | 980 | | |
878 | 981 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
671 | 671 | | |
672 | 672 | | |
673 | 673 | | |
| 674 | + | |
| 675 | + | |
| 676 | + | |
| 677 | + | |
| 678 | + | |
| 679 | + | |
| 680 | + | |
| 681 | + | |
| 682 | + | |
| 683 | + | |
| 684 | + | |
| 685 | + | |
| 686 | + | |
| 687 | + | |
| 688 | + | |
| 689 | + | |
| 690 | + | |
| 691 | + | |
| 692 | + | |
| 693 | + | |
| 694 | + | |
| 695 | + | |
| 696 | + | |
674 | 697 | | |
675 | 698 | | |
676 | 699 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
16 | 16 | | |
17 | 17 | | |
18 | 18 | | |
| 19 | + | |
| 20 | + | |
19 | 21 | | |
20 | 22 | | |
21 | 23 | | |
| |||
84 | 86 | | |
85 | 87 | | |
86 | 88 | | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
87 | 101 | | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
88 | 107 | | |
89 | | - | |
| 108 | + | |
90 | 109 | | |
91 | 110 | | |
92 | 111 | | |
| |||
358 | 377 | | |
359 | 378 | | |
360 | 379 | | |
361 | | - | |
| 380 | + | |
362 | 381 | | |
| 382 | + | |
| 383 | + | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
363 | 388 | | |
364 | 389 | | |
365 | 390 | | |
| |||
370 | 395 | | |
371 | 396 | | |
372 | 397 | | |
373 | | - | |
374 | 398 | | |
375 | 399 | | |
| 400 | + | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
| 405 | + | |
376 | 406 | | |
377 | 407 | | |
378 | 408 | | |
| |||
0 commit comments