Skip to content

Commit 165c6d1

Browse files
authored
[mlir] Add support for parsing nested PassPipelineOptions (#101118)
- Added a default parsing implementation to `PassOptions` to allow `Option`/`ListOption` to wrap PassOption objects. This is helpful when creating meta-pipelines (pass pipelines composed of pass pipelines). - Updated `ListOption` printing to enable round-tripping the output of `dump-pass-pipeline` back into `mlir-opt` for more complex structures.
1 parent 842789b commit 165c6d1

File tree

4 files changed

+131
-47
lines changed

4 files changed

+131
-47
lines changed

mlir/include/mlir/Pass/PassOptions.h

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,25 @@ class PassOptions : protected llvm::cl::SubCommand {
139139
}
140140
};
141141

142+
/// This is the parser that is used by pass options that wrap PassOptions
143+
/// instances. Like GenericOptionParser, this is a thin wrapper around
144+
/// llvm::cl::basic_parser.
145+
template <typename PassOptionsT>
146+
struct PassOptionsParser : public llvm::cl::basic_parser<PassOptionsT> {
147+
using llvm::cl::basic_parser<PassOptionsT>::basic_parser;
148+
// Parse the options object by delegating to
149+
// `PassOptionsT::parseFromString`.
150+
bool parse(llvm::cl::Option &, StringRef, StringRef arg,
151+
PassOptionsT &value) {
152+
return failed(value.parseFromString(arg));
153+
}
154+
155+
// Print the options object by delegating to `PassOptionsT::print`.
156+
static void print(llvm::raw_ostream &os, const PassOptionsT &value) {
157+
value.print(os);
158+
}
159+
};
160+
142161
/// Utility methods for printing option values.
143162
template <typename DataT>
144163
static void printValue(raw_ostream &os, GenericOptionParser<DataT> &parser,
@@ -154,19 +173,24 @@ class PassOptions : protected llvm::cl::SubCommand {
154173
}
155174

156175
public:
157-
/// The specific parser to use depending on llvm::cl parser used. This is only
158-
/// necessary because we need to provide additional methods for certain data
159-
/// type parsers.
160-
/// TODO: We should upstream the methods in GenericOptionParser to avoid the
161-
/// need to do this.
176+
/// The specific parser to use. This is necessary because we need to provide
177+
/// additional methods for certain data type parsers.
162178
template <typename DataType>
163-
using OptionParser =
179+
using OptionParser = std::conditional_t<
180+
// If the data type is derived from PassOptions, use the
181+
// PassOptionsParser.
182+
std::is_base_of_v<PassOptions, DataType>, PassOptionsParser<DataType>,
183+
// Otherwise, use GenericOptionParser where it is well formed, and fall
184+
// back to llvm::cl::parser otherwise.
185+
// TODO: We should upstream the methods in GenericOptionParser to avoid
186+
// the need to do this.
164187
std::conditional_t<std::is_base_of<llvm::cl::generic_parser_base,
165188
llvm::cl::parser<DataType>>::value,
166189
GenericOptionParser<DataType>,
167-
llvm::cl::parser<DataType>>;
190+
llvm::cl::parser<DataType>>>;
168191

169-
/// This class represents a specific pass option, with a provided data type.
192+
/// This class represents a specific pass option, with a provided
193+
/// data type.
170194
template <typename DataType, typename OptionParser = OptionParser<DataType>>
171195
class Option
172196
: public llvm::cl::opt<DataType, /*ExternalStorage=*/false, OptionParser>,
@@ -278,11 +302,12 @@ class PassOptions : protected llvm::cl::SubCommand {
278302
if ((**this).empty())
279303
return;
280304

281-
os << this->ArgStr << '=';
305+
os << this->ArgStr << "={";
282306
auto printElementFn = [&](const DataType &value) {
283307
printValue(os, this->getParser(), value);
284308
};
285309
llvm::interleave(*this, os, printElementFn, ",");
310+
os << "}";
286311
}
287312

288313
/// Copy the value from the given option into this one.
@@ -311,7 +336,7 @@ class PassOptions : protected llvm::cl::SubCommand {
311336

312337
/// Print the options held by this struct in a form that can be parsed via
313338
/// 'parseFromString'.
314-
void print(raw_ostream &os);
339+
void print(raw_ostream &os) const;
315340

316341
/// Print the help string for the options held by this struct. `descIndent` is
317342
/// the indent that the descriptions should be aligned.

mlir/lib/Pass/PassRegistry.cpp

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Pass/PassManager.h"
1313
#include "llvm/ADT/DenseMap.h"
1414
#include "llvm/ADT/ScopeExit.h"
15+
#include "llvm/ADT/StringRef.h"
1516
#include "llvm/Support/Format.h"
1617
#include "llvm/Support/ManagedStatic.h"
1718
#include "llvm/Support/MemoryBuffer.h"
@@ -185,6 +186,31 @@ const PassPipelineInfo *mlir::PassPipelineInfo::lookup(StringRef pipelineArg) {
185186
// PassOptions
186187
//===----------------------------------------------------------------------===//
187188

189+
/// Extract an argument from 'options' and update it to point after the arg.
190+
/// Returns the cleaned argument string.
191+
static StringRef extractArgAndUpdateOptions(StringRef &options,
192+
size_t argSize) {
193+
StringRef str = options.take_front(argSize).trim();
194+
options = options.drop_front(argSize).ltrim();
195+
196+
// Early exit if there's no escape sequence.
197+
if (str.size() <= 2)
198+
return str;
199+
200+
const auto escapePairs = {std::make_pair('\'', '\''),
201+
std::make_pair('"', '"'), std::make_pair('{', '}')};
202+
for (const auto &escape : escapePairs) {
203+
if (str.front() == escape.first && str.back() == escape.second) {
204+
// Drop the escape characters and trim.
205+
str = str.drop_front().drop_back().trim();
206+
// Don't process additional escape sequences.
207+
break;
208+
}
209+
}
210+
211+
return str;
212+
}
213+
188214
LogicalResult detail::pass_options::parseCommaSeparatedList(
189215
llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
190216
function_ref<LogicalResult(StringRef)> elementParseFn) {
@@ -213,13 +239,16 @@ LogicalResult detail::pass_options::parseCommaSeparatedList(
213239
size_t nextElePos = findChar(optionStr, 0, ',');
214240
while (nextElePos != StringRef::npos) {
215241
// Process the portion before the comma.
216-
if (failed(elementParseFn(optionStr.substr(0, nextElePos))))
242+
if (failed(
243+
elementParseFn(extractArgAndUpdateOptions(optionStr, nextElePos))))
217244
return failure();
218245

219-
optionStr = optionStr.substr(nextElePos + 1);
246+
// Drop the leading ','
247+
optionStr = optionStr.drop_front();
220248
nextElePos = findChar(optionStr, 0, ',');
221249
}
222-
return elementParseFn(optionStr.substr(0, nextElePos));
250+
return elementParseFn(
251+
extractArgAndUpdateOptions(optionStr, optionStr.size()));
223252
}
224253

225254
/// Out of line virtual function to provide home for the class.
@@ -239,27 +268,6 @@ void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) {
239268
/// `options` string pointing after the parsed option].
240269
static std::tuple<StringRef, StringRef, StringRef>
241270
parseNextArg(StringRef options) {
242-
// Functor used to extract an argument from 'options' and update it to point
243-
// after the arg.
244-
auto extractArgAndUpdateOptions = [&](size_t argSize) {
245-
StringRef str = options.take_front(argSize).trim();
246-
options = options.drop_front(argSize).ltrim();
247-
// Handle escape sequences
248-
if (str.size() > 2) {
249-
const auto escapePairs = {std::make_pair('\'', '\''),
250-
std::make_pair('"', '"'),
251-
std::make_pair('{', '}')};
252-
for (const auto &escape : escapePairs) {
253-
if (str.front() == escape.first && str.back() == escape.second) {
254-
// Drop the escape characters and trim.
255-
str = str.drop_front().drop_back().trim();
256-
// Don't process additional escape sequences.
257-
break;
258-
}
259-
}
260-
}
261-
return str;
262-
};
263271
// Try to process the given punctuation, properly escaping any contained
264272
// characters.
265273
auto tryProcessPunct = [&](size_t &currentPos, char punct) {
@@ -276,13 +284,13 @@ parseNextArg(StringRef options) {
276284
for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
277285
// Check for the end of the full option.
278286
if (argEndIt == optionsE || options[argEndIt] == ' ') {
279-
argName = extractArgAndUpdateOptions(argEndIt);
287+
argName = extractArgAndUpdateOptions(options, argEndIt);
280288
return std::make_tuple(argName, StringRef(), options);
281289
}
282290

283291
// Check for the end of the name and the start of the value.
284292
if (options[argEndIt] == '=') {
285-
argName = extractArgAndUpdateOptions(argEndIt);
293+
argName = extractArgAndUpdateOptions(options, argEndIt);
286294
options = options.drop_front();
287295
break;
288296
}
@@ -292,7 +300,7 @@ parseNextArg(StringRef options) {
292300
for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
293301
// Handle the end of the options string.
294302
if (argEndIt == optionsE || options[argEndIt] == ' ') {
295-
StringRef value = extractArgAndUpdateOptions(argEndIt);
303+
StringRef value = extractArgAndUpdateOptions(options, argEndIt);
296304
return std::make_tuple(argName, value, options);
297305
}
298306

@@ -344,7 +352,7 @@ LogicalResult detail::PassOptions::parseFromString(StringRef options,
344352

345353
/// Print the options held by this struct in a form that can be parsed via
346354
/// 'parseFromString'.
347-
void detail::PassOptions::print(raw_ostream &os) {
355+
void detail::PassOptions::print(raw_ostream &os) const {
348356
// If there are no options, there is nothing left to do.
349357
if (OptionsMap.empty())
350358
return;

mlir/test/Pass/pipeline-options-parsing.mlir

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,19 @@
1111
// RUN: mlir-opt %s -verify-each=false -pass-pipeline='builtin.module(builtin.module(func.func(test-options-pass{list=3}), func.func(test-options-pass{enum=one list=1,2,3,4 string="foo bar baz"})))' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_5 %s
1212
// RUN: mlir-opt %s -verify-each=false -pass-pipeline='builtin.module(builtin.module(func.func(test-options-pass{list=3}), func.func(test-options-pass{enum=one list=1,2,3,4 string={foo bar baz}})))' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_5 %s
1313
// RUN: mlir-opt %s -verify-each=false -pass-pipeline='builtin.module(builtin.module(func.func(test-options-pass{list=3}), func.func(test-options-pass{enum=one list=1,2,3,4 string=foo"bar"baz})))' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_6 %s
14+
// RUN: mlir-opt %s -verify-each=false '-test-options-super-pass-pipeline=super-list={{enum=zero list=1 string=foo},{enum=one list=2 string="bar"},{enum=two list=3 string={baz}}}' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_7 %s
15+
// RUN: mlir-opt %s -verify-each=false -pass-pipeline='builtin.module(func.func(test-options-super-pass{list={{enum=zero list={1} string=foo },{enum=one list={2} string=bar },{enum=two list={3} string=baz }}}))' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_7 %s
1416

1517
// CHECK_ERROR_1: missing closing '}' while processing pass options
1618
// CHECK_ERROR_2: no such option test-option
1719
// CHECK_ERROR_3: no such option invalid-option
1820
// CHECK_ERROR_4: 'notaninteger' value invalid for integer argument
1921
// CHECK_ERROR_5: for the --enum option: Cannot find option named 'invalid'!
2022

21-
// CHECK_1: test-options-pass{enum=zero list=1,2,3,4,5 string=nested_pipeline{arg1=10 arg2=" {} " arg3=true} string-list=a,b,c,d}
22-
// CHECK_2: test-options-pass{enum=one list=1 string= string-list=a,b}
23-
// CHECK_3: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list=3 string= }),func.func(test-options-pass{enum=one list=1,2,3,4 string= })))
24-
// CHECK_4: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list=3 string= }),func.func(test-options-pass{enum=one list=1,2,3,4 string=foobar })))
25-
// CHECK_5: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list=3 string= }),func.func(test-options-pass{enum=one list=1,2,3,4 string={foo bar baz} })))
26-
// CHECK_6: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list=3 string= }),func.func(test-options-pass{enum=one list=1,2,3,4 string=foo"bar"baz })))
23+
// CHECK_1: test-options-pass{enum=zero list={1,2,3,4,5} string=nested_pipeline{arg1=10 arg2=" {} " arg3=true} string-list={a,b,c,d}}
24+
// CHECK_2: test-options-pass{enum=one list={1} string= string-list={a,b}}
25+
// CHECK_3: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list={3} string= }),func.func(test-options-pass{enum=one list={1,2,3,4} string= })))
26+
// CHECK_4: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list={3} string= }),func.func(test-options-pass{enum=one list={1,2,3,4} string=foobar })))
27+
// CHECK_5: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list={3} string= }),func.func(test-options-pass{enum=one list={1,2,3,4} string={foo bar baz} })))
28+
// CHECK_6: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list={3} string= }),func.func(test-options-pass{enum=one list={1,2,3,4} string=foo"bar"baz })))
29+
// CHECK_7{LITERAL}: builtin.module(func.func(test-options-super-pass{list={{enum=zero list={1} string=foo },{enum=one list={2} string=bar },{enum=two list={3} string=baz }}}))

mlir/test/lib/Pass/TestPassManager.cpp

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ struct TestOptionsPass
5454
: public PassWrapper<TestOptionsPass, OperationPass<func::FuncOp>> {
5555
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOptionsPass)
5656

57-
enum Enum { One, Two };
57+
enum Enum { Zero, One, Two };
5858

5959
struct Options : public PassPipelineOptions<Options> {
6060
ListOption<int> listOption{*this, "list",
@@ -66,7 +66,15 @@ struct TestOptionsPass
6666
Option<Enum> enumOption{
6767
*this, "enum", llvm::cl::desc("Example enum option"),
6868
llvm::cl::values(clEnumValN(0, "zero", "Example zero value"),
69-
clEnumValN(1, "one", "Example one value"))};
69+
clEnumValN(1, "one", "Example one value"),
70+
clEnumValN(2, "two", "Example two value"))};
71+
72+
Options() = default;
73+
Options(const Options &rhs) { *this = rhs; }
74+
Options &operator=(const Options &rhs) {
75+
copyOptionValuesFrom(rhs);
76+
return *this;
77+
}
7078
};
7179
TestOptionsPass() = default;
7280
TestOptionsPass(const TestOptionsPass &) : PassWrapper() {}
@@ -92,7 +100,37 @@ struct TestOptionsPass
92100
Option<Enum> enumOption{
93101
*this, "enum", llvm::cl::desc("Example enum option"),
94102
llvm::cl::values(clEnumValN(0, "zero", "Example zero value"),
95-
clEnumValN(1, "one", "Example one value"))};
103+
clEnumValN(1, "one", "Example one value"),
104+
clEnumValN(2, "two", "Example two value"))};
105+
};
106+
107+
struct TestOptionsSuperPass
108+
: public PassWrapper<TestOptionsSuperPass, OperationPass<func::FuncOp>> {
109+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOptionsSuperPass)
110+
111+
struct Options : public PassPipelineOptions<Options> {
112+
ListOption<TestOptionsPass::Options> listOption{
113+
*this, "super-list",
114+
llvm::cl::desc("Example list of PassPipelineOptions option")};
115+
116+
Options() = default;
117+
};
118+
119+
TestOptionsSuperPass() = default;
120+
TestOptionsSuperPass(const TestOptionsSuperPass &) : PassWrapper() {}
121+
TestOptionsSuperPass(const Options &options) {
122+
listOption = options.listOption;
123+
}
124+
125+
void runOnOperation() final {}
126+
StringRef getArgument() const final { return "test-options-super-pass"; }
127+
StringRef getDescription() const final {
128+
return "Test options of options parsing capabilities";
129+
}
130+
131+
ListOption<TestOptionsPass::Options> listOption{
132+
*this, "list",
133+
llvm::cl::desc("Example list of PassPipelineOptions option")};
96134
};
97135

98136
/// A test pass that always aborts to enable testing the crash recovery
@@ -220,6 +258,7 @@ static void testNestedPipelineTextual(OpPassManager &pm) {
220258
namespace mlir {
221259
void registerPassManagerTestPass() {
222260
PassRegistration<TestOptionsPass>();
261+
PassRegistration<TestOptionsSuperPass>();
223262

224263
PassRegistration<TestModulePass>();
225264

@@ -248,5 +287,14 @@ void registerPassManagerTestPass() {
248287
[](OpPassManager &pm, const TestOptionsPass::Options &options) {
249288
pm.addPass(std::make_unique<TestOptionsPass>(options));
250289
});
290+
291+
PassPipelineRegistration<TestOptionsSuperPass::Options>
292+
registerOptionsSuperPassPipeline(
293+
"test-options-super-pass-pipeline",
294+
"Parses options of PassPipelineOptions using pass pipeline "
295+
"registration",
296+
[](OpPassManager &pm, const TestOptionsSuperPass::Options &options) {
297+
pm.addPass(std::make_unique<TestOptionsSuperPass>(options));
298+
});
251299
}
252300
} // namespace mlir

0 commit comments

Comments
 (0)