
Commit 0e9b716

Merge pull request #1 from pytorch-labs/fmassa/compute_model
Add compute cost in optimization problem
2 parents: ee20175 + aa08b62

2 files changed: +93 -15 lines changed


autoparallel/compute_estimation.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
import torch
from torch.utils._pytree import tree_map_only
from torch.utils.flop_counter import FlopCounterMode


def _get_device_tflops(dtype):
    # for some reason the function from PyTorch is giving
    # wildly different TFlops compared to the specs. I'm
    # using hard-coded values for now that I pulled from xFormers
    # https://github.com/fairinternal/xformers/blob/main/xformers/profiler/device_limits.py
    # TODO: fix PyTorch's implementation
    # from torch._inductor.utils import get_device_tflops

    device = None
    device_name = torch.cuda.get_device_name(device)
    assert "H100" in device_name, f"Only H100 supported for now, got {device_name}"

    return {
        torch.float64: 67,
        # NOTE: NVIDIA gives all numbers "with 2:4 sparsity"
        # but we want the full GEMM numbers
        torch.float32: 989 // 2,
        torch.float16: 1979 // 2,
        torch.bfloat16: 1979 // 2,
        torch.int8: 3958 // 2,
    }[dtype]


def _get_sharded_shape(spec):
    mesh = spec.mesh
    tensor_shape = spec.tensor_meta.shape
    # TODO: take dtype into account as well
    # tensor_dtype = spec.tensor_meta.dtype
    placements = spec.placements
    # TODO: find a better heuristic other than
    # running DTensor
    new_tensor_shape = list(tensor_shape)
    for mesh_size, placement in zip(mesh.shape, placements):
        if placement.is_shard():
            dim = placement.dim
            new_tensor_shape[dim] = (
                new_tensor_shape[dim] + mesh_size - 1
            ) // mesh_size
    return new_tensor_shape


def estimate_strategy_runtime_cost(node, strategy):
    if node.op != "call_function":
        return 0
    # suppose only matmul-like ops
    if not isinstance(node.target, torch._ops.OpOverload):
        return 0

    if node.target.is_view:
        return 0

    args = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], node.args)
    kwargs = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], node.kwargs)
    fake_mode = next(arg.fake_mode for arg in args if isinstance(arg, torch._subclasses.fake_tensor.FakeTensor))
    assert len(kwargs) == 0
    args_shapes = tuple(_get_sharded_shape(spec) for spec in strategy.input_specs)

    counter = 0
    args = list(args)
    for i, arg in enumerate(args):
        if isinstance(arg, torch.Tensor):
            with fake_mode:
                args[i] = torch.empty(args_shapes[counter], device=arg.device, dtype=arg.dtype)
            counter += 1

    # TODO: maybe cache the flop_counter to avoid recreating it
    # all the time
    with FlopCounterMode(display=False) as flop_counter:
        out = node.target(*args, **kwargs)

    flops = flop_counter.get_total_flops()

    # TODO: fix this
    dtype = strategy.input_specs[0].tensor_meta.dtype

    # TODO: use PyTorch's version once it's giving correct results
    gpu_flops = _get_device_tflops(dtype) * 10 ** 12

    # suppose 50% efficiency for the operator
    factor = 1 / 0.5
    compute_time = factor * flops / gpu_flops * 1e6  # us

    return compute_time
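
For reference, a minimal standalone sketch (not part of this commit) of what estimate_strategy_runtime_cost does for a single matmul: count FLOPs on fake tensors with FlopCounterMode, then convert to microseconds with the hard-coded H100 bf16 peak (989 TFLOPS) and the 50% efficiency factor above. Shapes and numbers below are illustrative only.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.utils.flop_counter import FlopCounterMode

# Build fake tensors so no real memory is allocated (hypothetical shapes).
fake_mode = FakeTensorMode()
with fake_mode:
    a = torch.empty(8192, 8192, dtype=torch.bfloat16)
    b = torch.empty(8192, 8192, dtype=torch.bfloat16)

# Count FLOPs symbolically; for a matmul this is 2 * M * K * N.
with FlopCounterMode(display=False) as flop_counter:
    torch.mm(a, b)
flops = flop_counter.get_total_flops()  # 2 * 8192**3 ~= 1.1e12

# Convert to an estimated runtime with the same constants as the patch.
gpu_flops = 989 * 10**12                # H100 dense bf16 peak, FLOP/s
compute_time_us = (1 / 0.5) * flops / gpu_flops * 1e6
print(f"{compute_time_us:.0f} us")      # ~= 2223 us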

autoparallel/optimize_sharding.py

Lines changed: 5 additions & 15 deletions
@@ -6,6 +6,7 @@
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor.placement_types import Replicate, Shard
 from torch.utils._pytree import tree_flatten, tree_map_only
+from .compute_estimation import _get_sharded_shape, estimate_strategy_runtime_cost
 from .utils import get_placement_options


@@ -84,15 +85,16 @@ def build_ds(self):
                     "num_output_strat": len(s.strategies),
                 }
                 for ss, ssi in enumerate(s.strategies):
+                    compute_cost = estimate_strategy_runtime_cost(node, ssi)
                     for argi, xxi in enumerate(ssi.redistribute_cost):
-                        for ii, input_p in enumerate(xxi):
+                        for ii, comm_cost in enumerate(xxi):
                             va = pulp.LpVariable(
                                 f"n={node},s={s_i},arg={argi},output_p={ss},input_p={ii}",
                                 cat=pulp.LpBinary,
                             )
                             ds[(s_i, argi, ss, ii)] = {
                                 "va": va,
-                                "cost": input_p,
+                                "cost": comm_cost + compute_cost,
                                 "full_strat": ssi,
                                 "out_strat": ssi.output_specs,
                                 "inp_strat": ssi.input_specs[argi],
@@ -533,20 +535,8 @@ def add_parameter_memory_constraint(self, memory_factor_low, memory_factor_high)
             for ii in range(vv["num_output_strat"]):
                 data = self.ds[(s_i, 0, ii, 0)]
                 spec = data["inp_strat"]
-                mesh = spec.mesh
                 tensor_shape = spec.tensor_meta.shape
-                # TODO: take dtype into account as well
-                # tensor_dtype = spec.tensor_meta.dtype
-                placements = spec.placements
-                # TODO: find a better heuristic other than
-                # running DTensor
-                new_tensor_shape = list(tensor_shape)
-                for mesh_size, placement in zip(mesh.shape, placements):
-                    if placement.is_shard():
-                        dim = placement.dim
-                        new_tensor_shape[dim] = (
-                            new_tensor_shape[dim] + mesh_size - 1
-                        ) // mesh_size
+                new_tensor_shape = _get_sharded_shape(spec)
                 new_size = math.prod(new_tensor_shape)
                 old_size = math.prod(tensor_shape)
                 elms.append(data["va"] * new_size / old_size)
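
To see how these per-strategy costs are consumed, here is a toy sketch (made-up numbers, not the repository's actual model) of the pattern in build_ds: one binary PuLP variable per candidate strategy, weighted by comm_cost + compute_cost in the objective, with a constraint that exactly one strategy is picked.

import pulp

prob = pulp.LpProblem("toy_sharding", pulp.LpMinimize)

# Two candidate strategies for a single node/argument; each cost is
# communication + estimated compute time in microseconds (made up).
costs = {
    "strat0": 120.0 + 30.0,
    "strat1": 10.0 + 85.0,
}
picks = {
    name: pulp.LpVariable(f"pick_{name}", cat=pulp.LpBinary) for name in costs
}

# Exactly one strategy must be selected for this argument.
prob += pulp.lpSum(picks.values()) == 1
# Objective: minimize total estimated runtime.
prob += pulp.lpSum(costs[name] * picks[name] for name in costs)

prob.solve(pulp.PULP_CBC_CMD(msg=False))
chosen = [name for name, va in picks.items() if va.value() == 1]
print(chosen)  # ['strat1'] -> 95.0 us beats 150.0 us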
