support delayed scaling of weight in float8 all-gather #312
Conversation
Summary:

Adds support for delayed scaling in FSDP2 float8 all-gather. In detail:

1. add `WeightWithDelayedFloat8CastTensor`; note that we don't reuse code with the dynamic version because I'd rather not deal with plumbing optional tensors through dynamo. We can try that in a separate PR later.
2. wire `Float8Linear` to use (1)
3. add weight amax syncing back, since we need it for float8 all-gather
4. add test coverage for eager mode numerics

Next up (in separate PRs) will be training run validation for numerics, and taking a look at performance.

Test Plan:

```
./test/test_everything.sh
```

ghstack-source-id: f1707c1
Pull Request resolved: #312
what are the optional tensors?
```
 all_amax_tensors = torch.cat(
-    fp8_amax_x_tensor_list + fp8_amax_dL_dY_tensor_list
+    fp8_amax_x_tensor_list
+    + fp8_amax_w_tensor_list
+    + fp8_amax_dL_dY_tensor_list
```
should we only do this if we are using fp8 all-gather?
that could make sense, I'd love to see the data to see if this is going to matter for performance. Focusing on numerics for now; I was hoping for performance to be tackled in future PRs.
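For illustration, here is a minimal sketch of the gating idea from this thread; the function and flag names are assumptions for the example, not this PR's actual code.

```python
import torch

# Hypothetical sketch of the suggestion above: only fold the weight amax
# buffers into the batched sync when float8 all-gather (and thus delayed
# scaling of the weight) actually needs them. Names are illustrative.
def build_all_amax_tensor(
    fp8_amax_x_tensor_list,
    fp8_amax_w_tensor_list,
    fp8_amax_dL_dY_tensor_list,
    sync_weight_amax: bool,
):
    tensors = fp8_amax_x_tensor_list + fp8_amax_dL_dY_tensor_list
    if sync_weight_amax:
        # weight amaxes only matter when the weight is cast with delayed
        # scaling, e.g. for float8 all-gather
        tensors = tensors + fp8_amax_w_tensor_list
    return torch.cat(tensors)
```

Concatenating all per-tensor amax values into one tensor is presumably what keeps the cross-rank sync to a single collective instead of one per tensor, which is why gating only changes what goes into the `torch.cat`.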
```
@@ -110,3 +112,181 @@ def fsdp_post_all_gather(
        out._scale = scale
        return
    return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
```
```
class WeightWithDelayedFloat8CastTensor(torch.Tensor):
```
[no change needed] I wish there was a way to share some more code with the dynamic version
yeah, me too. Looking at the code below, really the only code which would be shared is `fsdp_post_all_gather`; everything else would have to have if/else branches for delayed vs dynamic.
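To make the sharing question concrete, here is a hedged skeleton of the state the delayed-scaling wrapper carries, based only on the attribute names visible in this diff (`_tensor`, `_amax_buffer`, `_amax_history_buffer`, `_scale_buffer`, `_mm_config`); it is a sketch, not this PR's implementation.

```python
import torch

class _DelayedCastWeightSketch(torch.Tensor):
    """Illustrative skeleton only; the real class also implements
    __torch_dispatch__, tensor flatten/unflatten, and the
    fsdp_pre/post_all_gather extension points."""

    @staticmethod
    def __new__(cls, tensor, amax_buffer, amax_history_buffer, scale_buffer, mm_config):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.shape,
            dtype=tensor.dtype,
            device=tensor.device,
            requires_grad=tensor.requires_grad,
        )

    def __init__(self, tensor, amax_buffer, amax_history_buffer, scale_buffer, mm_config):
        self._tensor = tensor                            # high-precision weight shard
        self._amax_buffer = amax_buffer                  # current-iteration weight amax
        self._amax_history_buffer = amax_history_buffer  # history the delayed scale is computed from
        self._scale_buffer = scale_buffer                # scale consumed by the float8 all-gather cast
        self._mm_config = mm_config                      # forwarded to Float8Tensor after all-gather
```

The extra amax/history/scale state is exactly what the dynamic version does not carry, which is why, as noted above, mainly `fsdp_post_all_gather` would be shareable.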
```
    def __repr__(self):
        return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._mm_config})"

    def fsdp_pre_all_gather(self, mesh):
```
I'll let @weifengpy confirm this portion.
confirming that the FSDP part looks good
```
    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        if func == torch.ops.aten.detach.default:
```
mostly just a nit, but any reason to special-case detach here? Alternatively, you could set it up so that every view op automatically propagates subclass-ness in the same way.
If this is something I wrote, I think it was just something I saw in some other subclasses. Having every view op propagate subclass-ness in the same way sounds good to me.
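As a hedged illustration of the "propagate subclass-ness for every view op" idea (not this PR's code): a wrapper subclass can unwrap all tensor arguments, run the op, and re-wrap every tensor output, so `detach`, `view`, `t`, etc. all behave the same way.

```python
import torch
import torch.utils._pytree as pytree

# Illustrative sketch only; class and attribute names are assumptions.
class WrappedWeight(torch.Tensor):
    @staticmethod
    def __new__(cls, tensor: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, tensor.shape, dtype=tensor.dtype, device=tensor.device,
            requires_grad=tensor.requires_grad,
        )

    def __init__(self, tensor: torch.Tensor):
        self._tensor = tensor

    def __repr__(self):
        return f"WrappedWeight({self._tensor})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        kwargs = kwargs or {}

        def unwrap(t):
            return t._tensor

        out = func(
            *pytree.tree_map_only(cls, unwrap, args),
            **pytree.tree_map_only(cls, unwrap, kwargs),
        )
        # Re-wrap every tensor output instead of special-casing aten.detach,
        # so detach/view/t/etc. all preserve the subclass the same way.
        # (A production version would also use
        # torch.utils._python_dispatch.return_and_correct_aliasing.)
        return pytree.tree_map_only(torch.Tensor, cls, out)

w = WrappedWeight(torch.randn(4, 4))
print(type(w.detach()), type(w.t()))  # both remain WrappedWeight
```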
stamping for the FSDP part

documenting 2 open questions (not blockers for this PR):

- should we merge `WeightWithDelayedFloat8CastTensor` and `WeightWithDynamicFloat8CastTensor` into one class and add if/else to unify the logic around `__torch_dispatch__` and `fsdp_pre_all_gather` / `fsdp_post_all_gather`? We unified `Float8Linear` already.
- compare perf between `sync_float8_amax_and_scale_history` and `precompute_float8_dynamic_scale_for_fsdp`. If they are similar, people would not need to worry about numeric problems from delayed scaling.
I'm open to it if someone is interested in doing that in a follow-up PR. I'm not sure it will be better than what we have now, though.
yes, that would be great! I think we can do this in follow-up PRs. Note that delayed scaling is theoretically faster than dynamic scaling (fewer memory reads), but performance is not optimized across the stack yet. I think it's good to have options and allow people to optimize different settings in parallel. Eventually, if there is clear data that only one of these is needed, we can delete the one that isn't.
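To spell out the "fewer memory reads" point, here is a minimal, hedged sketch of the two scale computations (not the library's exact functions): dynamic scaling has to read the full weight to get its amax every time, while delayed scaling only reads the small amax history recorded earlier.

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

def dynamic_scale(w: torch.Tensor) -> torch.Tensor:
    # reads all of w on every call to compute the current amax
    amax = w.abs().max()
    return E4M3_MAX / torch.clamp(amax, min=1e-12)

def delayed_scale(amax_history: torch.Tensor) -> torch.Tensor:
    # only reads the (tiny) recorded history; w itself is not touched
    amax = amax_history.max()
    return E4M3_MAX / torch.clamp(amax, min=1e-12)

w = torch.randn(2048, 2048)
history = torch.stack([w.abs().max()])  # filled by the amax syncing added in this PR
print(dynamic_scale(w), delayed_scale(history))
```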
@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
This pull request has been merged in de93990.
```
    def fsdp_pre_all_gather(self, mesh):
        # initialize if needed
        # TODO(before land): ensure settings are consistent between Float8Linear and here
```
do we still need to resolve this?
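A hedged sketch of what resolving that TODO might look like; the field names are assumptions for illustration, not the repo's actual config API.

```python
# Hypothetical consistency check for the TODO above; `scale_fn_name` and
# `amax_history_len` are assumed field names, not the library's real ones.
def assert_delayed_cast_settings_match(linear_cfg, wrapper_cfg) -> None:
    for field in ("scale_fn_name", "amax_history_len"):
        lhs, rhs = getattr(linear_cfg, field), getattr(wrapper_cfg, field)
        assert lhs == rhs, (
            f"Float8Linear and weight wrapper disagree on {field}: {lhs} vs {rhs}"
        )
```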
```
            self._amax_buffer,
            self._amax_history_buffer,
            self._scale_buffer,
            "max",  # TODO(before land): read this from parent
```
ditto
Stack from ghstack (oldest at bottom):

- #312 support delayed scaling of weight in float8 all-gather (this PR)
- `swap_linear_with_dynamic` from fsdp2 eager test case (#311)

Summary:

Adds support for delayed scaling in FSDP2 float8 all-gather. In detail:

1. add `WeightWithDelayedFloat8CastTensor`; note that we don't reuse code with the dynamic version because I'd rather not deal with plumbing optional tensors through dynamo. We can try that in a separate PR later.
2. wire `Float8Linear` to use (1)
3. add weight amax syncing back, since we need it for float8 all-gather
4. add test coverage for eager mode numerics

Next up (in separate PRs) will be training run validation for numerics, and taking a look at performance.

Test Plan: `./test/test_everything.sh`

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: D59685258
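For context, a hedged sketch of where the re-added amax/scale sync sits in a training step with FSDP2 float8 all-gather; the sync helper is passed in as a callable and assumed to take the model as its only argument, so treat this as pseudocode for the step ordering rather than the repo's documented API.

```python
# Hedged sketch of the step ordering only; sync_float8_amax_and_scale_history
# refers to the helper named in this PR's discussion, with an assumed
# model-only signature.
def train_step(model, optimizer, batch, sync_float8_amax_and_scale_history):
    out = model(batch)      # forward; FSDP2 all-gathers the (float8-cast) weights
    out.sum().backward()    # backward records fresh amax values
    # with delayed scaling, amax/scale syncing must run every iteration so the
    # next float8 all-gather uses an up-to-date weight scale
    sync_float8_amax_and_scale_history(model)
    optimizer.step()
    optimizer.zero_grad()
```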