Skip to content

Commit 85f8cc3

Browse files
committed
refactor(//core/partitioning): Centralizing partitioning state
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent c9e504d commit 85f8cc3

22 files changed

+496
-315
lines changed

core/compiler.cpp

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -219,19 +219,16 @@ void AddIfBlockToGraph(
219219
return;
220220
}
221221

222-
GraphAndMapping ConstructFallbackGraph(
222+
GraphAndMapping ConstructFallbackGraph_(
223223
torch::jit::script::Module& new_mod,
224224
torch::jit::Block* block,
225-
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map,
226-
CompileSpec cfg,
225+
partitioning::PartitioningCtx* partitioning_ctx,
226+
conversion::ConversionInfo convert_info,
227227
ir::StaticParams static_params,
228-
std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
229-
auto convert_cfg = cfg.convert_info;
230-
auto partitioning_info = cfg.partitioning_info;
231-
228+
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map) {
232229
auto new_g = std::make_shared<torch::jit::Graph>();
233230

234-
auto segmented_blocks = partitioning::Partition(block, example_tensor_map, partitioning_info, fallback_nodes);
231+
auto segmented_blocks = partitioning::partition(partitioning_ctx, block, example_tensor_map);
235232

236233
// the mapping from lowering graph => fallback global graph
237234
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
@@ -240,7 +237,7 @@ GraphAndMapping ConstructFallbackGraph(
240237
}
241238

242239
for (auto& seg_block : segmented_blocks) {
243-
LOG_INFO(seg_block << "(GraphInSegmentedBlock)\n");
240+
LOG_INFO("Block segment:" << seg_block);
244241
std::ostringstream trt_engine_id;
245242
trt_engine_id << reinterpret_cast<const int*>(&seg_block);
246243

@@ -254,12 +251,12 @@ GraphAndMapping ConstructFallbackGraph(
254251
inputs.push_back(in);
255252
}
256253
// update the input ranges for each segments
257-
convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
254+
convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
258255

259256
// TODO mapping Inputs Ivalue to flatten one here
260-
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
257+
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_info, static_params);
261258
auto temp_g = std::make_shared<torch::jit::Graph>();
262-
auto device_spec = convert_cfg.engine_settings.device;
259+
auto device_spec = convert_info.engine_settings.device;
263260
auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
264261
AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
265262

@@ -272,8 +269,8 @@ GraphAndMapping ConstructFallbackGraph(
272269
// convert the 2 blocks in prim::if and get the converted graph with mappings
273270
std::vector<GraphAndMapping> graph_and_mappings;
274271
for (auto cur_block : if_node->blocks()) {
275-
graph_and_mappings.push_back(
276-
ConstructFallbackGraph(new_mod, cur_block, example_tensor_map, cfg, static_params, fallback_nodes));
272+
graph_and_mappings.push_back(ConstructFallbackGraph_(
273+
new_mod, cur_block, partitioning_ctx, convert_info, static_params, example_tensor_map));
277274
}
278275
AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);
279276

@@ -303,13 +300,32 @@ GraphAndMapping ConstructFallbackGraph(
303300
return {new_g, old_to_new_g};
304301
}
305302

303+
GraphAndMapping ConstructFallbackGraph(
304+
torch::jit::script::Module& new_mod,
305+
torch::jit::Block* block,
306+
CompileSpec cfg,
307+
ir::StaticParams static_params,
308+
ir::CollectionTypeMap first_use_types) {
309+
auto convert_info = cfg.convert_info;
310+
auto partitioning_info = cfg.partitioning_info;
311+
312+
auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
313+
auto collection_input_ivalues_map =
314+
partitioning::generateRandomInputs(partitioning_info.collection_input_spec_map, first_use_types);
315+
316+
return ConstructFallbackGraph_(
317+
new_mod, block, &partitioning_ctx, convert_info, static_params, collection_input_ivalues_map);
318+
}
319+
306320
void MapInputsAndDetermineDTypes(
307321
CompileSpec& cfg,
308322
std::shared_ptr<torch::jit::Graph>& g,
309323
ir::StaticParams& static_params,
310324
ir::CollectionTypeMap& first_use_type_map) {
311325
cfg.convert_info.collection_input_spec_map =
312326
std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
327+
cfg.partitioning_info.collection_input_spec_map =
328+
ir::CollectionInputSpecMap(cfg.convert_info.collection_input_spec_map);
313329

314330
auto collection_inputs = ir::get_collection_inputs(g, static_params);
315331
LOG_DEBUG(
@@ -434,11 +450,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
434450
(!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
435451
cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
436452
outputIsCollection)) {
437-
std::unordered_map<torch::jit::Node*, int> fallback_nodes;
438-
auto collection_input_ivalues_map =
439-
partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types);
440-
auto graph_and_mapping = ConstructFallbackGraph(
441-
new_mod, g->block(), collection_input_ivalues_map, cfg, static_params, fallback_nodes);
453+
auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), cfg, static_params, first_use_types);
442454
new_g = graph_and_mapping.first;
443455
// renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
444456
for (size_t i = 0; i < new_g->inputs().size(); ++i) {

core/partitioning/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ cc_library(
2424
"//core/ir",
2525
"//core/conversion",
2626
"//core/lowering",
27+
"//core/partitioning/partitioningctx",
2728
"//core/partitioning/partitioninginfo",
2829
"//core/partitioning/segmentedblock",
2930
] + select({

core/partitioning/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ target_include_directories(${lib_name}
3232
PUBLIC "$<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}>"
3333
)
3434

35+
add_subdirectory(partitioningctx)
3536
add_subdirectory(partitioninginfo)
3637
add_subdirectory(segmentedblock)
3738

0 commit comments

Comments
 (0)