This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Add config to enable padding on inner dims for scaled_mm inputs #145

Closed
wants to merge 7 commits
204 changes: 204 additions & 0 deletions benchmarks/bench_padding.py
@@ -0,0 +1,204 @@
from dataclasses import dataclass
from typing import Optional

import fire

import torch
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
from float8_experimental.float8_utils import pad_tensor_for_matmul
from tabulate import tabulate
from torch._inductor.utils import do_bench_using_profiling
from tqdm import tqdm

# estimating TOPs for matmuls in fp32, fp16, fp8
# assuming A * B = C, with A being M * K, B being K * N, C being M * N

# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
h100_peak_flops_float32 = 67e12
h100_peak_flops_fp16_tc = 1979e12
h100_peak_tops_float8_tc = 3958e12

dtype_to_peak_tops = {
torch.float32: h100_peak_flops_float32,
torch.float16: h100_peak_flops_fp16_tc,
torch.bfloat16: h100_peak_flops_fp16_tc,
torch.float8_e4m3fn: h100_peak_tops_float8_tc,
torch.float8_e5m2: h100_peak_tops_float8_tc,
}


def benchmark_fn_in_usec(f, *args, **kwargs):
no_args = lambda: f(*args, **kwargs)
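# do_bench_using_profiling reports milliseconds; scale by 1e3 to get microseconds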
time = do_bench_using_profiling(no_args)
return time * 1e3


def get_tops_info(tops, time, peak_tops):
time_sec = time / 1e6
tops_sec = float(tops) / time_sec
pct_top_peak = tops_sec / peak_tops
return tops_sec, pct_top_peak


def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)

a_config = ScaledMMConfig(
emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
)
b_config = ScaledMMConfig(
emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
)

a_fp8 = Float8Tensor.to_float8(A, scale_a, fp8_dtype, mm_config=a_config)
b_fp8 = Float8Tensor.to_float8(B, scale_b, fp8_dtype, mm_config=b_config)

return a_fp8 @ b_fp8


def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
# Breaks with compile due to trying to pad on fp8 dtype
# return do_fp8_matmul(A, B, fp8_dtype, out_dtype)
A_pad = pad_tensor_for_matmul(A, dims=1) # mem copy
B_pad = pad_tensor_for_matmul(B, dims=0) # mem copy

scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)

A_pad = A_pad.to(fp8_dtype) # mem copy
B_pad = B_pad.to(fp8_dtype) # mem copy

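# _scaled_mm expects the second operand in column-major layout; t().contiguous().t() produces that without changing the logical shape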
B_pad = B_pad.t().contiguous().t() # mem copy

return torch._scaled_mm(
A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype, use_fast_accum=True
)


def do_hp_matmul(A, B):
return torch.matmul(A, B)


def do_aligned_bf16_matmul(A, B):
A_pad = pad_tensor_for_matmul(A, dims=1)
B_pad = pad_tensor_for_matmul(B, dims=0)
return torch.matmul(A_pad, B_pad)


@dataclass
class Experiment_config:
M: int
K: int
N: int
output_dtype: torch.dtype
fp8_dtype: torch.dtype

def __iter__(self):
return iter((self.M, self.K, self.N, self.output_dtype, self.fp8_dtype))


def gen_configs():
shapes = [
(8193, 2501, 5008),
(65, 253, 4096),
(1023, 1029, 2512),
(4095, 511, 10000),
(2047, 3073, 8192),
(511, 769, 7504),
(127, 4097, 12288),
(32769, 15, 15024),
(9217, 8191, 20480),
(16385, 1025, 25008),
]
output_dtype = torch.bfloat16
fp8_dtype = torch.float8_e4m3fn
return [Experiment_config(*shape, output_dtype, fp8_dtype) for shape in shapes]


@torch.no_grad()
def run(compile: bool = False, n_limit: Optional[int] = None):
device = "cuda"
experiments = gen_configs()
results = []
tops_table = []
tops_headers = [
"Shape",
"Ref Dtype",
"Ref Tops",
"Aligned BF16 Tops",
"FP8 Tops",
"Ref % Peak",
"Aligned BF16 % Peak",
"FP8 % Peak",
]

for experiment in tqdm(experiments):
M, K, N, output_dtype, fp8_dtype = experiment
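# an (M x K) @ (K x N) matmul performs M * N * K multiply-adds, i.e. 2 * M * N * K FLOPs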
tops = 2 * M * N * K

A_base = torch.rand(M, K, device=device, dtype=output_dtype)
B_base = torch.rand(K, N, device=device, dtype=output_dtype)

hp_func = torch.compile(do_hp_matmul) if compile else do_hp_matmul
aligned_bf16_func = (
torch.compile(do_aligned_bf16_matmul) if compile else do_aligned_bf16_matmul
)
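# do_fp8_matmul (the Float8Tensor path) breaks under torch.compile because it pads fp8 data, so the compiled run uses the pad-first variant instead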
fp8_func = torch.compile(do_fp8_pad_first_matmul) if compile else do_fp8_matmul

ref_time = benchmark_fn_in_usec(hp_func, A_base, B_base)
aligned_bf16_time = benchmark_fn_in_usec(aligned_bf16_func, A_base, B_base)
fp8_time = benchmark_fn_in_usec(
fp8_func, A_base, B_base, fp8_dtype, output_dtype
)

ref_tops_sec, ref_pct_top_peak = get_tops_info(
tops, ref_time, dtype_to_peak_tops[output_dtype]
)
aligned_bf16_tops_sec, aligned_bf16_pct_top_peak = get_tops_info(
tops, aligned_bf16_time, dtype_to_peak_tops[torch.bfloat16]
)
fp8_tops_sec, fp8_pct_top_peak = get_tops_info(
tops, fp8_time, dtype_to_peak_tops[fp8_dtype]
)
tops_table.append(
[
f"({M}x{K}x{N})",
f"{output_dtype}",
f"{ref_tops_sec:.2E}",
f"{aligned_bf16_tops_sec:.2E}",
f"{fp8_tops_sec:.2E}",
f"{ref_pct_top_peak:.3f}",
f"{aligned_bf16_pct_top_peak:.3f}",
f"{fp8_pct_top_peak:.3f}",
]
)
results.append(
[
(M, K, N),
output_dtype,
ref_time,
aligned_bf16_time,
fp8_time,
ref_time / aligned_bf16_time,
ref_time / fp8_time,
]
)

print("TOPs".center(80, "*"))
print(tabulate(tops_table, headers=tops_headers))
print("Speed Results".center(80, "*"))
headers = [
"Shape",
"Ref Dtype",
"Ref Time",
"Aligned BF16 Time",
"FP8 Time",
"Aligned BF16 Speedup",
"FP8 Speedup",
]
print(tabulate(results, headers=headers, tablefmt="grid"))


if __name__ == "__main__":
fire.Fire(run)
6 changes: 6 additions & 0 deletions float8_experimental/config.py
@@ -23,3 +23,9 @@
# If True, use 'fnuz' float8 types for calculations.
# Currently, ROCm only supports fnuz variants.
use_fnuz_dtype = False

# If True, then prior to performing the fp8 scaled matmul we pad the inner
# dimension of a (dim 1) and b (dim 0) with zeros. This is needed for
# _scaled_mm, which has the strong constraint that for an M x K by K x N matmul,
# K and N must be multiples of 16. Padding can cause a memory spike, so it is
# off by default.
pad_inner_dim = False
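
Note (not part of the diff): a minimal usage sketch for this flag, assuming the Float8DynamicLinear.from_float entry point changed later in this PR; the conversion helper in your checkout may differ.

import torch
import float8_experimental.config as config
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

# Set before conversion: from_float reads config.pad_inner_dim when building the mm configs.
config.pad_inner_dim = True

# in_features = 253 is not a multiple of 16, so _scaled_mm would reject it without padding.
lin = torch.nn.Linear(253, 4096, bias=False, device="cuda", dtype=torch.bfloat16)
fp8_lin = Float8DynamicLinear.from_float(lin, emulate=False)
out = fp8_lin(torch.randn(65, 253, device="cuda", dtype=torch.bfloat16))
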
15 changes: 13 additions & 2 deletions float8_experimental/float8_dynamic_linear.py
@@ -88,8 +88,19 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
"bias": False,
}
new_mod = cls(**super_kwargs)
new_mod.forward_config = ScaledMMConfig(emulate, not bool(emulate))
new_mod.backward_config = ScaledMMConfig(emulate, False)

new_mod.forward_config = ScaledMMConfig(
emulate=emulate,
use_fast_accum=not bool(emulate),
fp8_output=False,
pad_inner_dim=config.pad_inner_dim,
)
new_mod.backward_config = ScaledMMConfig(
emulate=emulate,
use_fast_accum=False,
fp8_output=False,
pad_inner_dim=config.pad_inner_dim,
)
if config.enable_fsdp_fp8_all_gather:
new_mod.weight = nn.Parameter(
WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
8 changes: 6 additions & 2 deletions float8_experimental/float8_linear.py
@@ -347,6 +347,10 @@ def from_float(cls, mod, emulate: bool = False):
new_mod.create_buffers()
# Defines the behavior of the matmul in the forward and backward
# Forward we use fast_accum, backwards we do not
new_mod.forward_config = ScaledMMConfig(emulate, True if not emulate else False)
new_mod.backward_config = ScaledMMConfig(emulate, False)
new_mod.forward_config = ScaledMMConfig(
emulate, True if not emulate else False, False, config.pad_inner_dim
)
new_mod.backward_config = ScaledMMConfig(
emulate, False, False, config.pad_inner_dim
)
return new_mod
13 changes: 12 additions & 1 deletion float8_experimental/float8_ops.py
@@ -13,7 +13,8 @@
merge_mm_configs,
ScaledMMConfig,
)
from float8_experimental.float8_utils import is_row_major
from float8_experimental.float8_utils import is_row_major, pad_tensor_for_matmul

from torch.utils._pytree import tree_map

aten = torch.ops.aten
@@ -121,6 +122,16 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
a_scale = a._scale
b_data = b._data

if a._mm_config.pad_inner_dim:
assert (
b._mm_config.pad_inner_dim
), "Both mm configs must have pad_inner_dim set to True"
assert a._data.size(1) == b._data.size(
0
), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}"
a_data = pad_tensor_for_matmul(a_data, dims=1)
b_data = pad_tensor_for_matmul(b_data, dims=0)

if not is_row_major(a_data.stride()):
a_data = a_data.contiguous()
if is_row_major(b_data.stride()):
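
Note (not part of the diff): a sketch of the tensor-subclass path that reaches this branch, mirroring do_fp8_matmul from the benchmark above; shapes are chosen so the shared K dim (253) is not 16-aligned.

import torch
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig

cfg = ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True)
scale = torch.tensor([1.0], device="cuda", dtype=torch.float32)

A = torch.randn(65, 253, device="cuda", dtype=torch.bfloat16)
B = torch.randn(253, 4096, device="cuda", dtype=torch.bfloat16)

a_fp8 = Float8Tensor.to_float8(A, scale, torch.float8_e4m3fn, mm_config=cfg)
b_fp8 = Float8Tensor.to_float8(B, scale, torch.float8_e4m3fn, mm_config=cfg)

# preprocess_addmm pads K of both operands (253 -> 256) before dispatching to _scaled_mm.
out = a_fp8 @ b_fp8
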
1 change: 0 additions & 1 deletion float8_experimental/float8_python_api.py
@@ -9,7 +9,6 @@
to simplify the product code.
"""


from typing import Optional

import float8_experimental.float8_aten_api # noqa
6 changes: 4 additions & 2 deletions float8_experimental/float8_tensor.py
@@ -23,10 +23,11 @@
# emulate: whether to emulate the matmuls in fp32
# use_fast_accum: whether to use the fast-accumulation option for scaled_mm
# fp8_output: whether to output the result of the scaled_mm in fp8
# pad_inner_dim: whether to pad the inner (K) dimension of a and b with zeros; needed when K is not a multiple of 16.
ScaledMMConfig = namedtuple(
"ScaledMMConfig",
["emulate", "use_fast_accum", "fp8_output"],
defaults=[False, False, False],
["emulate", "use_fast_accum", "fp8_output", "pad_inner_dim"],
defaults=[False, False, False, False],
)


@@ -48,6 +49,7 @@ def merge_mm_configs(
emulate=a_mm_config.emulate,
use_fast_accum=a_mm_config.use_fast_accum and b_mm_config.use_fast_accum,
fp8_output=a_mm_config.fp8_output and b_mm_config.fp8_output,
pad_inner_dim=a_mm_config.pad_inner_dim and b_mm_config.pad_inner_dim,
)


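
Note (not part of the diff): padding only applies when both operands opt in; a small sketch of the merged-config behavior, assuming merge_mm_configs takes the two configs positionally.

from float8_experimental.float8_tensor import ScaledMMConfig, merge_mm_configs

a_cfg = ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=True)
b_cfg = ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)

merged = merge_mm_configs(a_cfg, b_cfg)
print(merged.pad_inner_dim)  # False: the flag survives the merge only if both configs set it
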
68 changes: 67 additions & 1 deletion float8_experimental/float8_utils.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from typing import Literal, Tuple
from typing import Iterable, Literal, Tuple, Union

import float8_experimental.config as config

@@ -179,3 +179,69 @@ def fp8_tensor_statistics(
def is_row_major(stride):
assert len(stride) == 2, "is_row_major only supports 2D tensors"
return stride[0] > stride[1] and stride[1] == 1


def _get_min_alignment(size: int, alignment_value: int) -> int:
"""
Returns the smallest multiple of `alignment_value` that is greater than or equal to `size`.

Args:
size: The size of the data to be aligned.
alignment_value: The alignment value to be used.

Returns:
int: The smallest multiple of `alignment_value` that is greater than or equal to `size`.

Usage:
```
>>> _get_min_alignment(10, 8)
16
```
"""
if size % alignment_value == 0:
return size
return (1 + (size // alignment_value)) * alignment_value


def pad_tensor_for_matmul(
tensor: torch.Tensor, dims: Union[int, Iterable[int]]
) -> torch.Tensor:
"""
Pads a 2D tensor with zeros so that the selected dimensions are multiples of 16, as required by `torch._scaled_mm`.

Args:
tensor: The 2D tensor to pad.
dims: The dimension(s) to pad to a multiple of 16: 0, 1, or an iterable of both.

Returns:
torch.Tensor: The padded tensor.

Usage:
```
>>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=0).shape
torch.Size([16, 10])
>>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=1).shape
torch.Size([10, 16])
>>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=(0, 1)).shape
torch.Size([16, 16])
```
"""
assert tensor.dim() == 2
dim1, dim2 = tensor.shape

if isinstance(dims, int):
dims = (dims,)

# Calculate aligned dimensions based on the specified dims
dim1_aligned = _get_min_alignment(dim1, 16) if 0 in dims else dim1
dim2_aligned = _get_min_alignment(dim2, 16) if 1 in dims else dim2

# Check if padding is needed for either dimension
if dim1 == dim1_aligned and dim2 == dim2_aligned:
return tensor

# Calculate padding values for both dimensions
pad_dim1 = dim1_aligned - dim1
pad_dim2 = dim2_aligned - dim2

return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1))
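
Note (not part of the diff): a quick CPU-only check of why zero-padding the shared inner dimension is safe; the padded entries contribute nothing to each dot product.

import torch
from float8_experimental.float8_utils import pad_tensor_for_matmul

A = torch.randn(10, 250)  # K = 250 is not a multiple of 16
B = torch.randn(250, 32)

A_pad = pad_tensor_for_matmul(A, dims=1)  # (10, 256)
B_pad = pad_tensor_for_matmul(B, dims=0)  # (256, 32)

# Padded rows/columns are zeros, so the product is unchanged (up to accumulation order).
torch.testing.assert_close(A @ B, A_pad @ B_pad)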