State dict serialization #51

base: main

Conversation
This looks really solid; I'm impressed with how quickly this came together. A couple of nits, but nothing major.
Could you please test the speed increase in the e2e test_models tests and report it in this PR?
edit: Ah, I just remembered we talked about landing this in stages, so I think some of the necessary plumbing won't exist.
torchstore/state_dict_utils.py
Outdated

```python
    size: int  # Size in bytes


def generate_tensor_blob(state_dict: Dict[str, Any]):
```
Can we use flatten_state_dict instead of making this recursive?
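For reference, a minimal sketch of how `flatten_state_dict` could replace the recursion. This assumes the helpers in torch's private `torch.distributed.checkpoint._nested_dict` module, so the import path and signatures should be verified against the torch version in use:

```python
# Sketch only: these helpers live in a private torch module, so treat the
# import path and signatures as assumptions to verify locally.
import torch
from torch.distributed.checkpoint._nested_dict import (
    flatten_state_dict,
    unflatten_state_dict,
)

nested = {"model": {"layer1": {"weight": torch.randn(2, 2)}}}

# flatten_state_dict returns a flat {key: tensor} dict plus a mapping that
# records each flat key's original nested path.
flat, mapping = flatten_state_dict(nested)
for key, value in flat.items():  # plain iteration, no recursion needed
    print(key, value.shape)

# The mapping restores the original nesting afterwards.
restored = unflatten_state_dict(flat, mapping)
```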
What do you think about making this a class method of a "TorchStoreStateDict", or similar?
Then we can do things like:

```python
torchstore_sd = TorchStoreStateDict.from_state_dict(original_state_dict)
torchstore_sd.to_state_dict()
```

and also store any necessary data as objects in the state dict.
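Roughly this shape, as a sketch (fields and method bodies here are illustrative, not the final API):

```python
from dataclasses import dataclass, field
from typing import Any, Dict

import torch


@dataclass
class TorchStoreStateDict:
    """Sketch: a state dict whose tensors are packed into a single byte blob."""

    # Flat {key: TensorReference-or-plain-value} mapping.
    flattened_state_dict: Dict[str, Any] = field(default_factory=dict)
    # All tensor bytes, concatenated; TensorReferences index into this.
    tensor_blob: torch.Tensor = field(
        default_factory=lambda: torch.empty(0, dtype=torch.uint8)
    )

    @classmethod
    def from_state_dict(cls, state_dict: Dict[str, Any]) -> "TorchStoreStateDict":
        """Flatten the dict, swap tensors for references, pack bytes."""
        ...

    def to_state_dict(self) -> Dict[str, Any]:
        """Reverse: rebuild tensors from the blob and restore nesting."""
        ...
```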
Done.
torchstore/state_dict_utils.py
Outdated

```python
        return modified_state_dict, torch.empty(0, dtype=torch.uint8)

    # Calculate total size and update offsets
    current_offset = 0
```
I think we have a `_state_dict_size` function in state dict utils.
`_state_dict_size` calculates an approximate size; it returns `size << 20`.
I believe the largest deltas on this PR are:
- Implementing as a class
- Managing DTensor

My recommendation for DTensor is to first convert it to a tensor slice, and store all additional metadata in the state dict. (This is actually my advice for dealing with DTensor in general, so we can reduce the number of branches in the codebase.)
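For illustration, a rough sketch of that conversion using the private `_compute_local_shape_and_global_offset` helper this PR already imports (its exact signature varies across torch versions, so verify it locally; the metadata keys here are made up):

```python
from typing import Any, Dict, Tuple

import torch
from torch.distributed.tensor import DTensor
# Private helper; signature assumed from recent torch versions.
from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset


def dtensor_to_local_and_metadata(dtensor: DTensor) -> Tuple[torch.Tensor, Dict[str, Any]]:
    """Sketch: reduce a DTensor to a plain local tensor plus slice metadata."""
    mesh = dtensor.device_mesh
    local_shape, global_offset = _compute_local_shape_and_global_offset(
        dtensor.shape,          # global shape
        mesh.shape,             # mesh dimensions
        mesh.get_coordinate(),  # this rank's coordinate in the mesh
        dtensor.placements,
    )
    metadata = {
        "global_shape": tuple(dtensor.shape),
        "local_shape": tuple(local_shape),
        "global_offset": tuple(global_offset),
    }
    # The local tensor is an ordinary torch.Tensor, so downstream code can
    # stay DTensor-free; the metadata travels in the state dict.
    return dtensor.to_local(), metadata
```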
Updated to a class representation using the flattened state dict, which makes a lot of sense because list iteration is much simpler than recursion. Also added DTensor support.
Haven't gone through the code but have a general question in mind.
Hi @casteryh, for getting a DTensor with a different sharding plan, the current interface in torchstore is to specify the desired sharding plan via an in-place tensor. Right now in this PR, it only supports …
Overall LGTM.
Still have some questions:

Currently, this is not integrated with torchstore.put / torchstore.get, right? For example, suppose I have a state dict `sd = {"a": t}` where `t` is a DTensor sharded across two ranks.

- On each rank, if I do `ts_sd = TorchStoreStateDict.from_state_dict(sd)`, then `ts_sd` will no longer contain DTensors, right?
- Consequently, if I do a `ts.put("state_dict_key", ts_sd)` on both ranks, then torchstore is supposed to detect that `ts_sd` is a `TorchStoreStateDict` and handle the sharding logic accordingly, right? <- My understanding is that this part is not done yet.
tests/test_state_dict.py
Outdated

```python
    assert torchstore_state_dict.flattened_state_dict == {}
    assert len(torchstore_state_dict.tensor_blob) == 0
```
this seems to be testing implementation details as opposed to behaviors.
Good point. Removed.
tests/test_state_dict.py
Outdated

```python
    assert len(torchstore_state_dict.tensor_blob) == 0

    reconstructed = torchstore_state_dict.to_state_dict()
```
same here
Removed.
tests/test_state_dict.py
Outdated

```python
    scalar_dict = {"scalar": torch.tensor(3.14159)}
    torchstore_state_dict = TorchStoreStateDict.from_state_dict(scalar_dict)

    # Check flattened state dict has TensorReference
    scalar_ref = torchstore_state_dict.flattened_state_dict["scalar"]
```
same
Done.
```python
    # Create DTensor from local tensor
    local_tensor = torch.randn(4, 6, dtype=torch.float32)
    dtensor = DTensor.from_local(local_tensor, device_mesh, [Replicate()])
```
Can you add a test for a sharded DTensor (with world size > 1)? I'm actually also confused about the expected behavior in this case.
That DTensor put-then-get functionality will be added in the next PR, where we integrate the state_dict functionality into torchstore. This PR only does the serialization and deserialization part.
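For the follow-up, a hypothetical sketch of the shape such a sharded test could take (it assumes the multi-process harness the other DTensor tests use, and that `TorchStoreStateDict` lives in `torchstore.state_dict_utils`; the final expectation is exactly the open question above, not settled behavior):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

from torchstore.state_dict_utils import TorchStoreStateDict  # assumed location


def _test_sharded_dtensor_roundtrip():
    # Runs inside an initialized 2-rank process group (e.g. gloo on CPU).
    device_mesh = init_device_mesh("cpu", (2,))
    local_tensor = torch.randn(4, 6, dtype=torch.float32)
    # Shard dim 0: each rank holds one 4x6 slice of the 8x6 global tensor.
    dtensor = DTensor.from_local(local_tensor, device_mesh, [Shard(0)])

    ts_sd = TorchStoreStateDict.from_state_dict({"weight": dtensor})
    reconstructed = ts_sd.to_state_dict()

    # Open question: what should reconstructed["weight"] be on each rank —
    # a DTensor again, or a local tensor plus slice metadata? At minimum,
    # the local shard's bytes must survive the blob round trip.
```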
```python
from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset


def create_tensor_slice_from_dtensor(dtensor: DTensor) -> "TensorSlice":
```
Suggested change:

```diff
-def create_tensor_slice_from_dtensor(dtensor: DTensor) -> "TensorSlice":
+from torchstore.transport.pipe import TensorSlice
+
+def create_tensor_slice_from_dtensor(dtensor: DTensor) -> TensorSlice:
```
See the response to the other comment about import ordering.
```python
    Returns:
        TensorSlice containing the distributed tensor metadata
    """
    from torchstore.transport.pipe import TensorSlice
```
Is there a particular reason to avoid importing this at the file level?
If not, move the import to the top of the file.
So there's a circular dependency:

```
pipe.TensorSlice
      ^
      |
dtensor_util.create_tensor_slice_from_dtensor
      ^
      |
pipe.Request.from_dtensor
```

Maybe we should move the `TensorSlice` definition into the `dtensor_util.py` module?
Would `from __future__ import annotations` fix this?
If not, then just leave it as is.
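For reference, `from __future__ import annotations` only defers evaluation of the annotation; constructing a `TensorSlice` at runtime still needs an import. The usual pattern pairs it with `TYPE_CHECKING`. A sketch of what `dtensor_util.py` could look like under that pattern:

```python
# dtensor_util.py (sketch): break the import cycle by importing TensorSlice
# only for type checking; with postponed evaluation, annotations stay strings.
from __future__ import annotations

from typing import TYPE_CHECKING

from torch.distributed.tensor import DTensor

if TYPE_CHECKING:
    # Never executed at runtime, so pipe.py can import this module freely.
    from torchstore.transport.pipe import TensorSlice


def create_tensor_slice_from_dtensor(dtensor: DTensor) -> TensorSlice:
    # Constructing the object still needs a runtime import, deferred to
    # call time so module import order no longer matters.
    from torchstore.transport.pipe import TensorSlice

    ...
```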
Synced with Yuxuan and Lucas. We will do DTensor put and get in the next PR. This PR only makes sure that DTensor can be serialized and deserialized properly.
Maybe try `from __future__ import annotations`.
If it doesn't work, then don't bother.
This PR creates two functions for state_dict serialization and deserialization.

`generate_tensor_blob` recursively looks for tensors in the state_dict and serializes them into a single tensor blob, replacing each tensor in the state_dict with a `TensorReference`. A `TensorReference` holds the metadata (offset, shape, and dtype) of the original tensor.

`reconstruct_state_dict_from_tensor_blob` does the reverse operation of `generate_tensor_blob`: it takes the tensor blob and a state_dict containing only tensor references, and replaces each `TensorReference` in the state_dict with the tensor reconstructed from the blob and its metadata.
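As a concrete illustration of the scheme described above, a minimal sketch of the two functions for a flat state dict of dense tensors (the real implementation also handles nesting and DTensor; names mirror the PR, bodies are illustrative):

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple

import torch


@dataclass
class TensorReference:
    offset: int         # byte offset into the blob
    size: int           # size in bytes
    shape: torch.Size   # original tensor shape
    dtype: torch.dtype  # original tensor dtype


def generate_tensor_blob(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], torch.Tensor]:
    """Replace tensors with TensorReferences and pack their bytes into one blob."""
    refs: Dict[str, Any] = {}
    chunks = []
    offset = 0
    for key, value in state_dict.items():  # flat dict for simplicity
        if isinstance(value, torch.Tensor):
            # Reinterpret the tensor's bytes as a 1-D uint8 tensor.
            data = value.detach().contiguous().reshape(-1).view(torch.uint8)
            refs[key] = TensorReference(offset, data.numel(), value.shape, value.dtype)
            chunks.append(data)
            offset += data.numel()
        else:
            refs[key] = value  # non-tensor values pass through unchanged
    blob = torch.cat(chunks) if chunks else torch.empty(0, dtype=torch.uint8)
    return refs, blob


def reconstruct_state_dict_from_tensor_blob(
    blob: torch.Tensor, state_dict: Dict[str, Any]
) -> Dict[str, Any]:
    """Reverse operation: swap each TensorReference back for a real tensor."""
    out: Dict[str, Any] = {}
    for key, value in state_dict.items():
        if isinstance(value, TensorReference):
            # Clone so the byte view starts at storage offset 0 (alignment).
            raw = blob[value.offset : value.offset + value.size].clone()
            out[key] = raw.view(value.dtype).reshape(value.shape)
        else:
            out[key] = value
    return out


# Round trip: the reconstructed tensors match the originals byte for byte.
sd = {"w": torch.randn(4, 6), "scalar": torch.tensor(3.14159), "step": 3}
refs, blob = generate_tensor_blob(sd)
restored = reconstruct_state_dict_from_tensor_blob(blob, refs)
assert torch.equal(restored["w"], sd["w"])
assert torch.equal(restored["scalar"], sd["scalar"])
```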