From 22af7efb75025504d866e80cafa1c416cfaaac8d Mon Sep 17 00:00:00 2001 From: vasiliy Date: Tue, 9 Jul 2024 13:36:29 -0700 Subject: [PATCH 1/2] support delayed scaling of weight in float8 all-gather Summary: Adds support for delayed scaling in FSDP2 float8 all-gather. In detail: 1. add `WeightWithDelayedFloat8CastTensor`, note that we don't reuse code with the dynamic version because I'd rather not deal with plumbing optional tensors through dynamo. We can try that in a separate PR later. 2. wire `Float8Linear` to use (1) 3. add weight amax syncing back, since we need it for float8 all-gather 4. add test coverage for eager mode numerics Next up (in separate PRs) will be training run validation for numerics, and taking a look at performance. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- float8_experimental/float8_linear.py | 84 ++++++---- float8_experimental/float8_linear_utils.py | 13 +- float8_experimental/fsdp_utils.py | 182 ++++++++++++++++++++- test/test_fsdp2/test_fsdp2_common.py | 13 +- test/test_fsdp2/test_fsdp2_eager.py | 75 +++++++-- 5 files changed, 315 insertions(+), 52 deletions(-) diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 23d6f5c..5aae675 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -34,7 +34,10 @@ tensor_to_amax, ) -from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor +from float8_experimental.fsdp_utils import ( + WeightWithDelayedFloat8CastTensor, + WeightWithDynamicFloat8CastTensor, +) def _maybe_initialize_amaxes_scales_for_float8_cast( @@ -316,25 +319,28 @@ def cast_w_to_float8( self, w: torch.Tensor, is_amax_initialized: bool ) -> torch.Tensor: if self.scaling_type_w is TensorScalingType.DELAYED: - scale_fn_name = self.recipe.scale_fn_name - _maybe_initialize_amaxes_scales_for_float8_cast( - w, - self.fp8_amax_w, - self.fp8_amax_history_w, - self.fp8_scale_w, - scale_fn_name, - e4m3_dtype, - is_amax_initialized, - reduce_amax=False, - ) - - w_fp8 = Float8Tensor.to_float8( - w, - self.fp8_scale_w, - e4m3_dtype, - self.fp8_amax_w, - self.forward_config, - ) + if isinstance(self.weight, Float8Tensor): # cast by FSDP + w_fp8 = self.weight + else: + scale_fn_name = self.recipe.scale_fn_name + _maybe_initialize_amaxes_scales_for_float8_cast( + w, + self.fp8_amax_w, + self.fp8_amax_history_w, + self.fp8_scale_w, + scale_fn_name, + e4m3_dtype, + is_amax_initialized, + reduce_amax=False, + ) + + w_fp8 = Float8Tensor.to_float8( + w, + self.fp8_scale_w, + e4m3_dtype, + self.fp8_amax_w, + self.forward_config, + ) else: assert self.scaling_type_w is TensorScalingType.DYNAMIC # TODO(future): also support FSDP integration in delayed scaling path @@ -436,18 +442,36 @@ def from_float( scaling_type_dL_dY=scaling_type_dL_dY, emulate=emulate, ) - if ( - scaling_type_w == TensorScalingType.DYNAMIC - and config.enable_fsdp_fp8_all_gather - ): - new_mod.weight = torch.nn.Parameter( - WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config) - ) - else: - assert not config.enable_fsdp_fp8_all_gather, "unsupported" - new_mod.weight = mod.weight + new_mod.weight = mod.weight new_mod.bias = mod.bias # need to create buffers again when moving from meta device to # real device new_mod.create_buffers() + + # If FSDP float8 all-gather is on, wrap the weight in a float8-aware + # tensor subclass. This must happen last because: + # 1. weight needs to be on the correct device to create the buffers + # 2. 
buffers need to be already created for the delayed scaling version + # of the weight wrapper to be initialized + if config.enable_fsdp_fp8_all_gather: + if scaling_type_w is TensorScalingType.DYNAMIC: + new_mod.weight = torch.nn.Parameter( + WeightWithDynamicFloat8CastTensor( + new_mod.weight, + new_mod.forward_config, + ) + ) + else: + assert scaling_type_w is TensorScalingType.DELAYED + new_mod.weight = torch.nn.Parameter( + WeightWithDelayedFloat8CastTensor( + new_mod.weight, + new_mod.fp8_amax_w, + new_mod.fp8_amax_history_w, + new_mod.fp8_scale_w, + new_mod.forward_config, + new_mod.is_amax_initialized, + ) + ) + return new_mod diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index 13b47a3..75bcb86 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -291,11 +291,10 @@ def inner_func(): ), "Mismatched lengths of amax tensors." if dist.is_initialized(): - # Combine all the amax tensors into one tensor and reduce it - # Note: do not reduce the weight values, because FSDP already ensures - # the weight values on all ranks are the same after all-gather. all_amax_tensors = torch.cat( - fp8_amax_x_tensor_list + fp8_amax_dL_dY_tensor_list + fp8_amax_x_tensor_list + + fp8_amax_w_tensor_list + + fp8_amax_dL_dY_tensor_list ) all_reduced_amax_tensor = all_reduce( all_amax_tensors, "MAX", list(range(dist.get_world_size())) @@ -304,12 +303,14 @@ def inner_func(): all_reduced_amax_tensor = all_reduced_amax_tensor.wait() ( - reduced_fp8_amax_tensor, + reduced_fp8_amax_x_tensor, + reduced_fp8_amax_w_tensor, reduced_fp8_amax_dL_dY_tensor, ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list)) for idx, child in enumerate(fp8_layers): - child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx]) + child.fp8_amax_x.copy_(reduced_fp8_amax_x_tensor[idx]) + child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx]) child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx]) # We create two stacked tensor groups, one for the amax history and one for the current scales diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py index 41871d8..64feca1 100644 --- a/float8_experimental/fsdp_utils.py +++ b/float8_experimental/fsdp_utils.py @@ -6,15 +6,17 @@ from typing import Any, Optional, Tuple +import float8_experimental.config as config + import torch import torch.utils._pytree as pytree from float8_experimental.float8_dynamic_utils import cast_to_float8_e4m3_dynamic - from float8_experimental.float8_tensor import ( Float8Tensor, merge_mm_configs, ScaledMMConfig, ) +from float8_experimental.float8_utils import e4m3_dtype from torch._prims_common import suggest_memory_format # FSDP pads its local tensor on dim-0. 
The subclass should be preserved such @@ -110,3 +112,181 @@ def fsdp_post_all_gather( out._scale = scale return return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,) + + +class WeightWithDelayedFloat8CastTensor(torch.Tensor): + @staticmethod + def __new__( + cls, + tensor: torch.Tensor, + amax_buffer: torch.Tensor, + amax_history_buffer: torch.Tensor, + scale_buffer: torch.Tensor, + mm_config: ScaledMMConfig, + is_amax_initialized: bool, + ): + return torch.Tensor._make_wrapper_subclass( + cls, + tensor.size(), + strides=tensor.stride(), + storage_offset=tensor.storage_offset(), + memory_format=suggest_memory_format(tensor), + dtype=tensor.dtype, + layout=tensor.layout, + device=tensor.device, + pin_memory=tensor.is_pinned(), + requires_grad=tensor.requires_grad, + ) + + def __init__( + self, + tensor: torch.Tensor, + amax_buffer: torch.Tensor, + amax_history_buffer: torch.Tensor, + scale_buffer: torch.Tensor, + mm_config: ScaledMMConfig, + is_amax_initialized: bool, + ): + self._tensor = tensor + self._amax_buffer = amax_buffer + self._amax_history_buffer = amax_history_buffer + self._scale_buffer = scale_buffer + self._mm_config = mm_config + + # Note: is_amax_initialized is not a buffer to avoid data dependent + # control flow visible to dynamo + # TODO(future PR): add serialization for this flag + self.is_amax_initialized = is_amax_initialized + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + if func == torch.ops.aten.detach.default: + return WeightWithDelayedFloat8CastTensor( + args[0]._tensor, + args[0]._amax_buffer, + args[0]._amax_history_buffer, + args[0]._scale_buffer, + args[0]._mm_config, + args[0].is_amax_initialized, + ) + mm_config: Optional[ScaledMMConfig] = None + amax_buffer: Optional[torch.Tensor] = None + amax_history_buffer: Optional[torch.Tensor] = None + scale_buffer: Optional[torch.Tensor] = None + is_amax_initialized: Optional[bool] = None + + def unwrap(t): + nonlocal mm_config + if mm_config is None: + mm_config = t._mm_config + else: + mm_config = merge_mm_configs(mm_config, t._mm_config) + nonlocal amax_buffer + if amax_buffer is None: + amax_buffer = t._amax_buffer + nonlocal amax_history_buffer + if amax_history_buffer is None: + amax_history_buffer = t._amax_history_buffer + nonlocal scale_buffer + if scale_buffer is None: + scale_buffer = t._scale_buffer + nonlocal is_amax_initialized + if is_amax_initialized is None: + is_amax_initialized = t.is_amax_initialized + return t._tensor + + args, kwargs = pytree.tree_map_only( + WeightWithDelayedFloat8CastTensor, unwrap, (args, kwargs or {}) + ) + out = func(*args, **kwargs) + if func not in _ops_to_preserve_subclass: + return out + return pytree.tree_map_only( + torch.Tensor, + lambda x: WeightWithDelayedFloat8CastTensor( + x, + amax_buffer, + amax_history_buffer, + scale_buffer, + mm_config, + is_amax_initialized, + ), + out, + ) + + def __tensor_flatten__(self): + return ( + [ + "_tensor", + "_amax_buffer", + "_amax_history_buffer", + "_scale_buffer", + ], + self._mm_config, + is_amax_initialized, + ) + + @staticmethod + def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): + mm_config, is_amax_initialized = flatten_spec + return WeightWithDelayedFloat8CastTensor( + inner_tensors["_tensor"], + inner_tensors["_amax_buffer"], + inner_tensors["_amax_history_buffer"], + inner_tensors["_scale_buffer"], + mm_config, + is_amax_initialized, + ) + + def __repr__(self): + return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, 
amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._mm_config})" + + def fsdp_pre_all_gather(self, mesh): + # initialize if needed + # TODO(before land): ensure settings are consistent between Float8Linear and here + if not self.is_amax_initialized: + from float8_experimental.float8_linear import ( + _maybe_initialize_amaxes_scales_for_float8_cast, + ) + + _maybe_initialize_amaxes_scales_for_float8_cast( + self._tensor, + self._amax_buffer, + self._amax_history_buffer, + self._scale_buffer, + "max", # TODO(before land): read this from parent + e4m3_dtype, + self.is_amax_initialized, + reduce_amax=True, + ) + self.is_amax_initialized = True + + # this will: + # 1. cast the tensor to float8 using `_scale_buffer` + # 2. populate `_amax_buffer` inplace + # TODO(future PR): clean up all the casting functions and clearly + # separate dynamic vs delayed, tech debt has accumulated + float8_tensor = Float8Tensor.to_float8( + self._tensor, + self._scale_buffer, + e4m3_dtype, + self._amax_buffer, + self._mm_config, + ) + return (float8_tensor._data,), (float8_tensor._scale,) + + def fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[torch.Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, + *, + out: Optional[torch.Tensor] = None, + ): + (data,) = all_gather_outputs + (scale,) = metadata + if out is not None: + assert isinstance(out, Float8Tensor), f"{type(out)}" + out._scale = scale + return + return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,) diff --git a/test/test_fsdp2/test_fsdp2_common.py b/test/test_fsdp2/test_fsdp2_common.py index c20e8cc..368a82d 100644 --- a/test/test_fsdp2/test_fsdp2_common.py +++ b/test/test_fsdp2/test_fsdp2_common.py @@ -6,7 +6,11 @@ import torch import torch.distributed as dist import torch.nn as nn -from float8_experimental.float8_linear import Float8Linear +from float8_experimental.float8_linear import Float8Linear, TensorScalingType +from float8_experimental.float8_linear_utils import ( + linear_requires_sync, + sync_float8_amax_and_scale_history, +) def check_parity_no_mp( @@ -16,6 +20,7 @@ def check_parity_no_mp( fsdp_model: nn.Module, fsdp_optim: torch.optim.Optimizer, local_inp: torch.Tensor, + scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC, ): for iter_idx in range(10): losses: List[torch.Tensor] = [] @@ -27,8 +32,12 @@ def check_parity_no_mp( for param in model.parameters(): dist.all_reduce(param.grad) param.grad.div_(dist.get_world_size()) - # TODO(future): add amax syncing once delayed scaling is supported + + if linear_requires_sync(scaling_type_w=scaling_type_w): + sync_float8_amax_and_scale_history(model) + optim.step() + test_cls.assertEqual(losses[0], losses[1]) diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2_eager.py index 2af1dc8..2639cf2 100644 --- a/test/test_fsdp2/test_fsdp2_eager.py +++ b/test/test_fsdp2/test_fsdp2_eager.py @@ -80,11 +80,17 @@ def world_size(self) -> int: return min(torch.cuda.device_count(), 2) @skip_if_lt_x_gpu(2) - def test_transformer_parity_dynamic(self): - for enable_fsdp_fp8_all_gather in [False, True]: - self._test_transformer_parity_dynamic(enable_fsdp_fp8_all_gather) + def test_transformer_parity(self): + choices = itertools.product( + [False, True], + [TensorScalingType.DYNAMIC, TensorScalingType.DELAYED], + ) + for enable_fsdp_fp8_all_gather, scaling_type_w in choices: + self._test_transformer_parity(enable_fsdp_fp8_all_gather, scaling_type_w) - def _test_transformer_parity_dynamic(self, 
enable_fsdp_fp8_all_gather: bool): + def _test_transformer_parity( + self, enable_fsdp_fp8_all_gather: bool, scaling_type_w: TensorScalingType + ): # NOTE: Weight-tying does not compose with fp8 all-gather because the # embedding weight and output linear weight are tied but only the # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to @@ -92,9 +98,9 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool): weight_tying = not enable_fsdp_fp8_all_gather module = self.init_transformer(weight_tying=weight_tying).cuda() ref_module = copy.deepcopy(module) - swap_linear_with_float8_linear(ref_module) + swap_linear_with_float8_linear(ref_module, scaling_type_w=scaling_type_w) with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather): - swap_linear_with_float8_linear(module) + swap_linear_with_float8_linear(module, scaling_type_w=scaling_type_w) for submodule in module.modules(): if isinstance(submodule, TransformerBlock): fully_shard(submodule) @@ -104,7 +110,15 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool): local_inp = torch.randint( 0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda" ) - check_parity_no_mp(self, ref_module, ref_optim, module, optim, local_inp) + check_parity_no_mp( + self, + ref_module, + ref_optim, + module, + optim, + local_inp, + scaling_type_w=scaling_type_w, + ) @skip_if_lt_x_gpu(2) def test_transformer_memory(self): @@ -364,13 +378,21 @@ def test_fp32_fp8_single_module_parity(self): Tests numeric parity for fp32 parameters with fp8 computation with a single module/FSDP communication group. """ - for enable_fsdp_fp8_all_gather in [False, True]: + choices = itertools.product( + [False, True], + [TensorScalingType.DYNAMIC, TensorScalingType.DELAYED], + ) + for enable_fsdp_fp8_all_gather, scaling_type_w in choices: module_fp32 = self.init_single_module() ref_module = copy.deepcopy(module_fp32) - ref_module = swap_linear_with_float8_linear(ref_module) + ref_module = swap_linear_with_float8_linear( + ref_module, scaling_type_w=scaling_type_w + ) ref_module = ref_module.cuda() with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather): - module = swap_linear_with_float8_linear(module_fp32) + module = swap_linear_with_float8_linear( + module_fp32, scaling_type_w=scaling_type_w + ) fully_shard(module) ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2) optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True) @@ -382,6 +404,7 @@ def test_fp32_fp8_single_module_parity(self): module, optim, local_inp, + scaling_type_w=scaling_type_w, ) @unittest.skipIf(not TEST_CUDA, "no cuda") @@ -390,12 +413,20 @@ def test_fp32_fp8_multi_module_parity(self): Tests numeric parity for fp32 parameters with fp8 computation with multiple modules/FSDP communication groups. 
""" - for enable_fsdp_fp8_all_gather in [False, True]: + choices = itertools.product( + [False, True], + [TensorScalingType.DYNAMIC, TensorScalingType.DELAYED], + ) + for enable_fsdp_fp8_all_gather, scaling_type_w in choices: module = self.init_multi_module().cuda() ref_module = copy.deepcopy(module) - ref_module = swap_linear_with_float8_linear(ref_module) + ref_module = swap_linear_with_float8_linear( + ref_module, scaling_type_w=scaling_type_w + ) with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather): - module = swap_linear_with_float8_linear(module) + module = swap_linear_with_float8_linear( + module, scaling_type_w=scaling_type_w + ) for submodule in module: fully_shard(submodule) fully_shard(module) @@ -409,6 +440,7 @@ def test_fp32_fp8_multi_module_parity(self): module, optim, local_inp, + scaling_type_w=scaling_type_w, ) @unittest.skipIf(not TEST_CUDA, "no cuda") @@ -443,6 +475,23 @@ def test_bf16_mp_fp8_dynamic_multi_parity(self): self.get_local_inp(torch.bfloat16), ) + @unittest.skipIf(not TEST_CUDA, "no cuda") + def test_delayed_scaling_inplace_update(self): + """ + Verify that `WeightWithDelayedFloat8CastTensor` updates buffers inplace + """ + module = self.init_single_module() + with set_enable_fsdp_fp8_all_gather(True): + m_fp8 = swap_linear_with_float8_linear( + module, + scaling_type_w=TensorScalingType.DELAYED, + ) + + fp8_amax_w_old = m_fp8.fp8_amax_w.clone().detach() + dummy_mesh = None + data, scale = m_fp8.weight.fsdp_pre_all_gather(dummy_mesh) + self.assertNotEqual(fp8_amax_w_old.item(), m_fp8.fp8_amax_w.item()) + if __name__ == "__main__": run_tests() From 62f75c67c63d5fef85ed01b27c0434fae14289de Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 10 Jul 2024 15:47:16 -0700 Subject: [PATCH 2/2] Update on "support delayed scaling of weight in float8 all-gather" Summary: Adds support for delayed scaling in FSDP2 float8 all-gather. In detail: 1. add `WeightWithDelayedFloat8CastTensor`, note that we don't reuse code with the dynamic version because I'd rather not deal with plumbing optional tensors through dynamo. We can try that in a separate PR later. 2. wire `Float8Linear` to use (1) 3. add weight amax syncing back, since we need it for float8 all-gather 4. add test coverage for eager mode numerics Next up (in separate PRs) will be training run validation for numerics, and taking a look at performance. 
Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- float8_experimental/float8_linear.py | 1 - float8_experimental/fsdp_utils.py | 13 +++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 5aae675..7850738 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -343,7 +343,6 @@ def cast_w_to_float8( ) else: assert self.scaling_type_w is TensorScalingType.DYNAMIC - # TODO(future): also support FSDP integration in delayed scaling path if isinstance(self.weight, Float8Tensor): # cast by FSDP w_fp8 = self.weight else: diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py index 64feca1..365ea01 100644 --- a/float8_experimental/fsdp_utils.py +++ b/float8_experimental/fsdp_utils.py @@ -222,20 +222,21 @@ def __tensor_flatten__(self): "_amax_history_buffer", "_scale_buffer", ], - self._mm_config, - is_amax_initialized, + { + "mm_config": self._mm_config, + "is_amax_initialized": is_amax_initialized, + }, ) @staticmethod - def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): - mm_config, is_amax_initialized = flatten_spec + def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): return WeightWithDelayedFloat8CastTensor( inner_tensors["_tensor"], inner_tensors["_amax_buffer"], inner_tensors["_amax_history_buffer"], inner_tensors["_scale_buffer"], - mm_config, - is_amax_initialized, + metadata["mm_config"], + metadata["is_amax_initialized"], ) def __repr__(self):
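
For context, two short sketches of how the pieces above fit together. Neither is part of the diff in this stack; the model sizes, the training loop, and the single-process emulation below are illustrative assumptions only.

1. Enabling FSDP2 float8 all-gather with delayed scaling for the weight (mirrors the eager tests added in this stack; assumes a CUDA build and `torch.distributed` already initialized, e.g. via torchrun):

```
import torch
import torch.nn as nn
import float8_experimental.config as config
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import (
    linear_requires_sync,
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)
from torch.distributed._composable.fsdp import fully_shard

# set the flag before swapping, since Float8Linear.from_float decides at swap
# time whether to wrap the weight in WeightWithDelayedFloat8CastTensor
config.enable_fsdp_fp8_all_gather = True

# toy model and hyperparameters, for illustration only
model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096)).cuda()
swap_linear_with_float8_linear(model, scaling_type_w=TensorScalingType.DELAYED)
fully_shard(model)
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)

for _ in range(10):
    inp = torch.randn(16, 4096, device="cuda")
    model(inp).sum().backward()
    # delayed scaling needs the amax/scale sync every iteration; with this PR
    # the weight amaxes are included in that reduction again
    if linear_requires_sync(scaling_type_w=TensorScalingType.DELAYED):
        sync_float8_amax_and_scale_history(model)
    optim.step()
    optim.zero_grad()
```

2. The pre/post all-gather contract that `WeightWithDelayedFloat8CastTensor` implements, exercised in a single process (world size 1, dummy mesh) the same way `test_delayed_scaling_inplace_update` does. This is a sketch of what FSDP2 does around the weight all-gather, not the actual FSDP internals:

```
import torch
import float8_experimental.config as config
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

config.enable_fsdp_fp8_all_gather = True
m_fp8 = swap_linear_with_float8_linear(
    torch.nn.Linear(16, 16, device="cuda"),
    scaling_type_w=TensorScalingType.DELAYED,
)
weight = m_fp8.weight  # WeightWithDelayedFloat8CastTensor

# before the all-gather, FSDP asks the subclass which tensors to communicate
# (the fp8 data) and for metadata (the scale); this call also updates
# fp8_amax_w in place, which is why weight amax syncing is added back
datas, metadata = weight.fsdp_pre_all_gather(None)

# after the (here trivial) all-gather, FSDP reassembles a Float8Tensor, which
# Float8Linear.cast_w_to_float8 then uses as-is ("cast by FSDP")
fp8_weight, _ = weight.fsdp_post_all_gather(datas, metadata, torch.float32)
print(type(fp8_weight), m_fp8.fp8_amax_w)
```

Because the wrapper mutates the module's existing amax/scale buffers in place rather than allocating new ones, the usual `sync_float8_amax_and_scale_history` machinery keeps seeing the values produced during the FSDP all-gather, which is what `test_delayed_scaling_inplace_update` checks.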