From f92709df27e34d59f1a057e095e693a3dae910ee Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Tue, 2 Jul 2024 16:01:26 -0700
Subject: [PATCH] [9/x]: make dynamic scaling default in Float8Linear

Summary:

1. makes dynamic scaling the default in Float8Linear, to ease migration of
   callsites which currently use Float8DynamicLinear. Fixes tests as needed.
2. updates the README to reference Float8Linear for dynamic scaling

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 README.md                                  | 31 ++++++++++++++--------
 float8_experimental/float8_linear_utils.py |  6 ++---
 test/test_compile.py                       | 16 +++++++++--
 test/test_fsdp.py                          |  9 ++++++-
 test/test_fsdp_compile.py                  | 11 ++++++--
 5 files changed, 54 insertions(+), 19 deletions(-)

diff --git a/README.md b/README.md
index 1db10096..fa093c32 100644
--- a/README.md
+++ b/README.md
@@ -27,21 +27,23 @@ pip install -e ".[dev]"
 
 # User API
 
-We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details.
+We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`x`), weights (`w`) and gradients (`dL_dY`).
 
-## float8 linear with dynamic scaling
+## float8 linear with dynamic scaling for `x`, `w` and `dL_dY`
+
+This is the most accurate recipe as every tensor is scaled dynamically.
 
 ```python
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
 )
-from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+from float8_experimental.float8_linear import Float8Linear
 
 # create model
 m = Model(...)
 
-# convert all `torch.nn.Linear` modules to `Float8DynamicLinear`
-swap_linear_with_float8_linear(m, Float8DynamicLinear)
+# convert all `torch.nn.Linear` modules to `Float8Linear`
+swap_linear_with_float8_linear(m, Float8Linear)
 
 # optional: use FSDP
 model = FSDP(model, use_orig_params=True)
@@ -54,18 +56,27 @@ m = torch.compile(m)
 
 ## float8 linear with delayed scaling
 
+This is theoretically the most performant recipe as it minimizes memory reads.
+
 ```python
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 
 # create model
 m = Model(...)
 
-# convert all `torch.nn.Linear` modules to `Float8Linear`
-swap_linear_with_float8_linear(m, Float8Linear)
+# convert all `torch.nn.Linear` modules to `Float8Linear`, specifying the
+# scaling type
+swap_linear_with_float8_linear(
+    m,
+    Float8Linear,
+    scaling_type_x=TensorScalingType.DELAYED,
+    scaling_type_w=TensorScalingType.DELAYED,
+    scaling_type_dL_dY=TensorScalingType.DELAYED,
+)
 
 # optional: use FSDP. Note that workarounds gated with config.enable_amax_init and
 # config.enable_pre_and_post_forward are needed for autocast + compile + FSDP + float8 to work
@@ -93,9 +104,7 @@ for _ in range(N_ITER):
 # 🧭 Code Organization
 
 * `float8_experimental/float8_linear.py`
-  - `Float8Linear` (main user facing entry point for delayed scaling)
-* `float8_experimental/float8_dynamic_linear.py`
-  - `Float8DynamicLinear` (main user facing entry point for dynamic scaling)
+  - `Float8Linear` (main user facing entry point, supporting dynamic and delayed scaling)
 * `float8_experimental/float8_tensor.py`
   - `Float8Tensor`, which allows `Float8Linear` to abide by the `x.dtype == x.grad.dtype` restriction
   - `ScaledMMConfig` defines the semantics for matmul in the forward and backwards pass
diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index cbf992ef..b1a17e4f 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -191,9 +191,9 @@ def swap_linear_with_float8_linear(
     skip_fqn_list: Optional[List[str]] = None,
     emulate: bool = False,
     linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
-    scaling_type_x: TensorScalingType = TensorScalingType.DELAYED,
-    scaling_type_w: TensorScalingType = TensorScalingType.DELAYED,
-    scaling_type_dL_dY: TensorScalingType = TensorScalingType.DELAYED,
+    scaling_type_x: TensorScalingType = TensorScalingType.DYNAMIC,
+    scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC,
+    scaling_type_dL_dY: TensorScalingType = TensorScalingType.DYNAMIC,
 ) -> Optional[nn.Module]:
     """
     Swaps `torch.nn.Linear` in `module` with `Float8Linear` or `Float8DynamicLinear`.
diff --git a/test/test_compile.py b/test/test_compile.py
index 5d904876..834d126f 100644
--- a/test/test_compile.py
+++ b/test/test_compile.py
@@ -299,7 +299,13 @@ def test_sync_amax_func():
     module = torch.nn.Sequential(
         nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True)
     )
-    float8_mod = swap_linear_with_float8_linear(module, Float8Linear)
+    float8_mod = swap_linear_with_float8_linear(
+        module,
+        Float8Linear,
+        scaling_type_x=TensorScalingType.DELAYED,
+        scaling_type_w=TensorScalingType.DELAYED,
+        scaling_type_dL_dY=TensorScalingType.DELAYED,
+    )
     compiled_swap_func = torch.compile(sync_float8_amax_and_scale_history, backend=cnts)
     compiled_swap_func(float8_mod)
     assert cnts.frame_count == 1, "Compiled graph should have 1 frame!"
@@ -329,7 +335,13 @@ def test_sync_amax_func_cuda_graph_success():
     my_module = nn.Sequential(
         nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True)
     ).to("cuda")
-    swap_linear_with_float8_linear(my_module, Float8Linear)
+    swap_linear_with_float8_linear(
+        my_module,
+        Float8Linear,
+        scaling_type_x=TensorScalingType.DELAYED,
+        scaling_type_w=TensorScalingType.DELAYED,
+        scaling_type_dL_dY=TensorScalingType.DELAYED,
+    )
     inpt = torch.randn(
         16, 16, device="cuda", dtype=torch.float32, requires_grad=True
     )
diff --git a/test/test_fsdp.py b/test/test_fsdp.py
index ff31ca31..031b40d8 100644
--- a/test/test_fsdp.py
+++ b/test/test_fsdp.py
@@ -23,6 +23,8 @@ import torch.nn as nn
 
 from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import (
+    linear_requires_sync,
+    LinearType,
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
@@ -130,7 +132,12 @@ def forward_backward(model, optim, is_fp8, i):
         optim.zero_grad()
         y_local = model(ref_input_local[i])
         y_local.backward(ref_grad_local[i])
-        if is_fp8:
+        if is_fp8 and linear_requires_sync(
+            LinearType.DELAYED,
+            TensorScalingType.DYNAMIC,
+            scaling_type_w,
+            TensorScalingType.DYNAMIC,
+        ):
             sync_float8_func(model)
         optim.step()
         return y_local
diff --git a/test/test_fsdp_compile.py b/test/test_fsdp_compile.py
index 389b7569..cc449341 100644
--- a/test/test_fsdp_compile.py
+++ b/test/test_fsdp_compile.py
@@ -18,7 +18,7 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from float8_experimental import config
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
@@ -49,7 +49,14 @@ def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
         nn.Linear(K, N, dtype=base_dtype),
         nn.ReLU(),
     )
-    swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate)
+    swap_linear_with_float8_linear(
+        m,
+        Float8Linear,
+        emulate=emulate,
+        scaling_type_x=TensorScalingType.DELAYED,
+        scaling_type_w=TensorScalingType.DELAYED,
+        scaling_type_dL_dY=TensorScalingType.DELAYED,
+    )
     return m
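
A minimal migration sketch for callsites currently on Float8DynamicLinear, using only the APIs touched by this patch (swap_linear_with_float8_linear, Float8Linear, TensorScalingType, sync_float8_amax_and_scale_history). The toy two-layer model, tensor shapes, and the assumption of an fp8-capable CUDA device are illustrative only, not part of the patch: with this change, swapping with Float8Linear and no extra kwargs gives dynamic scaling for x, w and dL_dY, while delayed scaling becomes opt-in and still requires the per-iteration amax/scale sync.

```python
import copy

import torch
import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)

# toy stand-in for a real model; any nn.Module with nn.Linear children works
m_ref = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 64)).cuda()

# dynamic scaling: with this patch it is the default, no scaling_type_* kwargs needed
m_dynamic = swap_linear_with_float8_linear(copy.deepcopy(m_ref), Float8Linear)

# delayed scaling: now opt-in, requested per tensor via the scaling_type_* kwargs
m_delayed = swap_linear_with_float8_linear(
    copy.deepcopy(m_ref),
    Float8Linear,
    scaling_type_x=TensorScalingType.DELAYED,
    scaling_type_w=TensorScalingType.DELAYED,
    scaling_type_dL_dY=TensorScalingType.DELAYED,
)

x = torch.randn(32, 64, device="cuda")

# dynamic scaling needs no extra bookkeeping between iterations
m_dynamic(x).sum().backward()

# delayed scaling still requires syncing amaxes and scales after each backward
m_delayed(x).sum().backward()
sync_float8_amax_and_scale_history(m_delayed)
```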