Skip to content

Commit 8e2c99f

Browse files
BowenBao authored and facebook-github-bot committed
Deepcopy model to another device before export to avoid OOM (#118710)
Summary: Prior to ONNX export, the model is deepcopied to avoid modifications that may affect later performance profiling. However, this increases the memory requirement on the device. This PR modifies the script to deepcopy and export the model on another device when possible. X-link: pytorch/pytorch#118710 Approved by: https://github.com/thiagocrepaldi Reviewed By: clee2000 Differential Revision: D53296686 fbshipit-source-id: e764fcf3c4f15f4f8793623571a09fd1b4263898
1 parent e60829a commit 8e2c99f

File tree

1 file changed

+37
-2
lines changed

1 file changed

+37
-2
lines changed

userbenchmark/dynamo/dynamobench/common.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,6 +1222,31 @@ def __init__(
12221222
self.model_dir / f"{model_name}_{self._COMPILER_NAME}.onnx"
12231223
)
12241224

def _determine_deepcopy_target_device(self):
    """Choose the device where the deepcopied model should live.

    On CPU runs the copy stays on "cpu". On GPU runs a spare device
    ("cuda:1") is preferred when more than one CUDA device is present,
    so the copy does not compete for memory with the baseline model;
    otherwise it falls back to the default "cuda" device.

    NOTE(review): reads ``current_device``, assumed to be a module-level
    value defined elsewhere in this file — confirm against the full source.
    """
    if current_device == "cpu":
        return "cpu"
    # Prefer a second GPU to avoid OOM on the primary device.
    return "cuda:1" if torch.cuda.device_count() > 1 else "cuda"
def deepcopy_model_and_inputs_to_device(self, model, example_inputs, target_device):
    """Deepcopy ``model`` onto ``target_device`` and move ``example_inputs`` there.

    The baseline model is temporarily staged on CPU before the deepcopy so
    that the original and the copy never coexist on the source device
    (avoids OOM), then restored to its original device afterwards.

    Args:
        model: ``torch.nn.Module`` to copy; left on its original device.
        example_inputs: pytree of inputs; contained tensors are moved.
        target_device: device string/object for the copy and inputs.

    Returns:
        Tuple of (deepcopied model on ``target_device``, inputs moved to
        ``target_device``).
    """
    # Find a tensor to record the original device. Fall back to buffers so
    # parameterless modules (e.g. buffer-only) do not raise StopIteration.
    reference = next(model.parameters(), None)
    if reference is None:
        reference = next(model.buffers(), None)
    model_device = reference.device if reference is not None else torch.device("cpu")

    # Stage on CPU so the deepcopy does not double the device memory footprint.
    model.to("cpu")
    model_copy = copy.deepcopy(model).to(target_device)
    model.to(model_device)

    target_device_example_inputs = tree_map_only(
        torch.Tensor, lambda x: x.to(device=target_device), example_inputs
    )

    return model_copy, target_device_example_inputs
12251250
@classmethod
12261251
def _generate_onnx_model_directory(
12271252
cls, output_directory: str, compiler_name: str, model_name: str
@@ -1404,7 +1429,9 @@ def __init__(
14041429
def _export(self, model, example_inputs, output_path: str, /, **kwargs) -> None:
14051430
if self.copy_before_export:
14061431
# Deepcopy model before export to avoid modification to baseline model.
1407-
model = copy.deepcopy(model)
1432+
model, example_inputs = self.deepcopy_model_and_inputs_to_device(
1433+
model, example_inputs, self._determine_deepcopy_target_device()
1434+
)
14081435

14091436
# Hack for huggingface models (kwargs only).
14101437
if isinstance(example_inputs, dict):
@@ -1486,7 +1513,9 @@ def _export(
14861513
) -> torch.onnx.ONNXProgram:
14871514
if self.copy_before_export:
14881515
# Deepcopy model before export to avoid modification to baseline model.
1489-
model = copy.deepcopy(model)
1516+
model, example_inputs = self.deepcopy_model_and_inputs_to_device(
1517+
model, example_inputs, self._determine_deepcopy_target_device()
1518+
)
14901519

14911520
example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
14921521
options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes)
@@ -1513,6 +1542,12 @@ class OnnxModelFromDynamoAotInline(OnnxModelFromDynamo):
15131542
def _export(
15141543
self, model, example_inputs, output_path: str
15151544
) -> torch.onnx.ONNXProgram:
1545+
if self.copy_before_export:
1546+
# Deepcopy model before export to avoid modification to baseline model.
1547+
model, example_inputs = self.deepcopy_model_and_inputs_to_device(
1548+
model, example_inputs, self._determine_deepcopy_target_device()
1549+
)
1550+
15161551
example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
15171552
options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes)
15181553
onnx_program = torch.onnx.dynamo_export(

0 commit comments

Comments (0)