
Conversation

@kaiyuan-li (Contributor):

This PR creates two functions for state_dict serialization and deserialization.

  1. generate_tensor_blob recursively looks for tensors in the state_dict and serializes them into a single blob, replacing each tensor in the state_dict with a TensorReference. A TensorReference holds the metadata (offset, shape, and dtype) of the original tensor.
  2. reconstruct_state_dict_from_tensor_blob does the reverse: it takes the tensor blob and a state_dict containing only tensor references, and replaces every TensorReference with a tensor reconstructed from the blob (see the round-trip sketch below).
  3. Tests added.
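
For reference, a minimal sketch of the intended round trip (function names are from this PR; the exact signatures follow my description above):

    import torch

    original = {"layer": {"weight": torch.randn(4, 8)}}

    # Every tensor is copied into one contiguous uint8 blob; its entry in the
    # returned state_dict becomes a TensorReference (offset, shape, dtype).
    state_dict_with_refs, blob = generate_tensor_blob(original)

    # The reverse pass resolves each TensorReference against the blob.
    restored = reconstruct_state_dict_from_tensor_blob(blob, state_dict_with_refs)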

@kaiyuan-li requested a review from LucasLLC on October 6, 2025, 17:16
@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Oct 6, 2025
@LucasLLC (Contributor) left a comment:

This looks really solid, I'm impressed with how quickly this came together. A couple nits, but nothing major.

Could you please measure the speed increase in the e2e test_models tests and report it in this PR?

edit: Ah, I just remembered we talked about landing this in stages, so I think some of the necessary plumbing won't exist yet.

size: int # Size in bytes


def generate_tensor_blob(state_dict: Dict[str, Any]):
Contributor:

Can we use flatten_state_dict instead of making this recursive?
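
For example (a sketch; flatten_state_dict / unflatten_state_dict live in the private torch.distributed.checkpoint._nested_dict module, so they may change across torch versions):

    from torch.distributed.checkpoint._nested_dict import (
        flatten_state_dict,
        unflatten_state_dict,
    )

    # Keys become dotted paths like "layer.weight", so a plain loop
    # replaces the recursion:
    flattened, mapping = flatten_state_dict(state_dict)
    for fqn, value in flattened.items():
        ...  # serialize tensors, record TensorReferences

    # The inverse restores the original nesting:
    nested = unflatten_state_dict(flattened, mapping)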

Contributor:

What do you think about making this a class method of a "TorchStoreStateDict", or similar?

Then we can do things like:

torchstore_sd = TorchStoreStateDict.from_state_dict(original_state_dict)
torchstore_sd.to_state_dict()

and also store any necessary data as objects in the state dict.
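
A rough skeleton of what I have in mind (just a sketch; the field names are suggestions):

    from dataclasses import dataclass
    from typing import Any, Dict

    import torch


    @dataclass
    class TorchStoreStateDict:
        flattened_state_dict: Dict[str, Any]  # tensors replaced by TensorReference
        tensor_blob: torch.Tensor             # one contiguous uint8 buffer

        @classmethod
        def from_state_dict(cls, state_dict: Dict[str, Any]) -> "TorchStoreStateDict":
            ...  # flatten, pack tensors into the blob, record references

        def to_state_dict(self) -> Dict[str, Any]:
            ...  # rebuild tensors from blob + references, then unflatten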

@kaiyuan-li (Contributor Author):

Done.

return modified_state_dict, torch.empty(0, dtype=torch.uint8)

# Calculate total size and update offsets
current_offset = 0
Contributor:

I think we have a _state_dict_size function in state dict utils

@kaiyuan-li (Contributor Author) commented on Oct 7, 2025:

_state_dict_size only calculates an approximate size: it returns size << 20, not an exact byte count.

@LucasLLC (Contributor) left a comment:

I believe the largest deltas on this PR are:

  1. Implementing as class
  2. Managing DTensor

My recommendation for DTensor is to first convert it to a tensor slice and store all additional metadata in the state dict. (This is actually always my advice for dealing with DTensor, so we can reduce the number of branches in the codebase.)
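
Roughly like this (a sketch only; the TensorSlice fields and the metadata key are placeholders, and _compute_local_shape_and_global_offset is a private helper whose signature may vary across torch versions):

    from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset

    # Each rank keeps only its local shard as a plain tensor, plus enough
    # metadata to locate that shard within the global tensor:
    local = dtensor.to_local()
    local_shape, global_offset = _compute_local_shape_and_global_offset(
        dtensor.shape, dtensor.device_mesh, dtensor.placements
    )
    flattened[fqn] = local
    flattened[fqn + "/slice"] = TensorSlice(  # hypothetical metadata key
        global_shape=tuple(dtensor.shape),
        local_shape=local_shape,
        offsets=global_offset,
    )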

@kaiyuan-li (Contributor Author):

Updated to the class representation and switched to a flattened state dict, which makes a lot of sense because iterating a flat list is much simpler than recursion.

Also added DTensor support with TensorSlice as metadata. Please take another look.

@casteryh (Contributor) commented on Oct 8, 2025:

Haven't gone through the code yet, but I have a general question in mind.
Let's say I have a state dict consisting of DTensors, I convert it to a TorchStoreStateDict, and I do a ts.put.
On the get side, do I get back a state dict of DTensors? What if I want materialized whole tensors for each tensor in the state dict, or if the get side has different sharding patterns than the put side?

@kaiyuan-li
Copy link
Contributor Author

Hi @casteryh, for getting a DTensor with a different sharding plan, the current torchstore interface is to specify the desired sharding plan via an in-place DTensor passed to get.

Right now this PR only supports get('state_dict_key') for getting the whole state dict. @LucasLLC and I discussed this morning, and I will add a new feature so we can get a specific DTensor, like get('dtensor.fqn', inplace_dtensor), so that the DTensor can be resharded. I should be able to get that done by tomorrow.
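
Roughly, the planned usage would look like this (a sketch; this API is not in this PR yet, and get_device_mesh is a placeholder for the get side's mesh):

    import torch
    from torch.distributed.tensor import DTensor, Shard

    # The get side declares its own sharding plan by allocating the target
    # DTensor; torchstore then fills each rank's local shard in place:
    target = DTensor.from_local(
        torch.empty(2, 6, dtype=torch.float32),  # this rank's shard of a (4, 6) tensor
        get_device_mesh,
        [Shard(0)],
    )
    ts.get("dtensor.fqn", target)  # resharding happens against the stored slices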

@casteryh (Contributor) left a comment:

Overall LGTM.

Still have some questions: currently, this is not integrated with torchstore.put / torchstore.get, right?

For example, if I have a state dict sd = {"a": t} where t is a DTensor sharded across two ranks.

  • On each rank, if I do ts_sd = TorchStoreStateDict.from_state_dict(sd), then ts_sd will no longer contain DTensors, right? (Sketched below.)
  • Consequently, if I do a ts.put("state_dict_key", ts_sd) on both ranks, then torchstore is supposed to detect that ts_sd is a TorchStoreStateDict and handle the sharding logic accordingly, right? <- My understanding is this part is not done yet.
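
To make that concrete (a sketch, assuming a Shard(0) placement and the class API above; device_mesh is the 2-rank mesh from the setup):

    import torch
    from torch.distributed.tensor import DTensor, Shard

    # Run on each of the two ranks:
    t = DTensor.from_local(torch.randn(4, 6), device_mesh, [Shard(0)])  # global (8, 6)
    sd = {"a": t}
    ts_sd = TorchStoreStateDict.from_state_dict(sd)
    # ts_sd now holds this rank's local shard plus TensorSlice metadata,
    # and no DTensor; ts.put("state_dict_key", ts_sd) would still need
    # sharding-aware handling on the store side.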

Comment on lines 499 to 500
assert torchstore_state_dict.flattened_state_dict == {}
assert len(torchstore_state_dict.tensor_blob) == 0
Contributor:

This seems to be testing implementation details as opposed to behavior.

Contributor Author:

Good point. Removed.

Comment on lines 507 to 508
assert len(torchstore_state_dict.tensor_blob) == 0
reconstructed = torchstore_state_dict.to_state_dict()
Contributor:

same here

Contributor Author:

Removed.

scalar_dict = {"scalar": torch.tensor(3.14159)}
torchstore_state_dict = TorchStoreStateDict.from_state_dict(scalar_dict)
# Check flattened state dict has TensorReference
scalar_ref = torchstore_state_dict.flattened_state_dict["scalar"]
Contributor:

same

Contributor Author:

Done.


# Create DTensor from local tensor
local_tensor = torch.randn(4, 6, dtype=torch.float32)
dtensor = DTensor.from_local(local_tensor, device_mesh, [Replicate()])
Contributor:

Can you add a test for a sharded DTensor (with world size > 1)? I am actually also confused about the expected behavior in this case.

Contributor Author:

That DTensor put-then-get functionality will be added in the next PR, where we integrate the state_dict functionality into torchstore. This PR only does the serialization and deserialization part.
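
For reference, a sharded round trip at world size 2 might look like this (a sketch; the distributed test harness is omitted and a 2-rank CPU/gloo setup is assumed):

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import DTensor, Shard

    mesh = init_device_mesh("cpu", (2,))
    local = torch.randn(2, 6)  # each rank holds half of the global (4, 6) tensor
    dt = DTensor.from_local(local, mesh, [Shard(0)])

    ts_sd = TorchStoreStateDict.from_state_dict({"w": dt})
    restored = ts_sd.to_state_dict()
    # The open question above: should restored["w"] be the local shard plus
    # its TensorSlice metadata, or a DTensor again? Cross-rank reassembly
    # is the next PR's job either way.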

from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset


def create_tensor_slice_from_dtensor(dtensor: DTensor) -> "TensorSlice":
Contributor:

Suggested change:
- def create_tensor_slice_from_dtensor(dtensor: DTensor) -> "TensorSlice":
+ from torchstore.transport.pipe import TensorSlice
+ def create_tensor_slice_from_dtensor(dtensor: DTensor) -> TensorSlice:

Contributor Author:

See the response to the other comment about import ordering.

Returns:
TensorSlice containing the distributed tensor metadata
"""
from torchstore.transport.pipe import TensorSlice
Contributor:

Is there a particular reason to avoid importing this at the file level?
If not, move the import to the top of the file.

Contributor Author:

So there's a circular dependency:

pipe.Request.from_dtensor
  -> dtensor_util.create_tensor_slice_from_dtensor
    -> pipe.TensorSlice

i.e. pipe imports dtensor_util while dtensor_util needs pipe.TensorSlice. Maybe we should move the TensorSlice definition into the dtensor_util.py module?

Contributor:

Would from __future__ import annotations fix this?
If not, then just leave it as is.
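
(For context, a sketch of what the future import does and does not defer: annotations stop being evaluated at import time, but calls in the function body still run, so the deferred import would have to stay.)

    from __future__ import annotations  # PEP 563: annotations become lazy strings

    from torch.distributed.tensor import DTensor


    def create_tensor_slice_from_dtensor(dtensor: DTensor) -> TensorSlice:
        # The return annotation no longer needs TensorSlice at import time,
        # but constructing one at runtime still does:
        from torchstore.transport.pipe import TensorSlice
        ...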

@kaiyuan-li (Contributor Author) left a comment:

Synced with Yuxuan and Lucas. We will do DTensor put and get in the next PR. This PR only makes sure that DTensor can be serialized and deserialized properly.



@casteryh (Contributor) left a comment:

Maybe try from __future__ import annotations.
If it doesn't work, then don't bother.
