Description
Following the work in #159933, DTensor will use a generator passed as a kwarg to random ops instead of its own built-in generator.
This exposed a question of semantics, which I want to hear opinions on before implementing and documenting the behavior.
Question
If a user passes their own RNG to DTensor random ops, should the state-offset advancement be visible to the user on that RNG object?
My intuition is, yes: it is very surprising if the passed-in RNG object does not advance after being used.
However, implementing it this way introduces another discrepancy: DTensor already deviates from this expectation for the default RNG. Currently, DTensor keeps an alternate copy of the original default RNG and advances only that copy, so the default RNG state the user sees never reflects DTensor's random ops.
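For concreteness, this is roughly what the existing default-RNG behavior looks like from the user's point of view. This is a sketch only; it assumes a CUDA device mesh and an initialized distributed environment like the test harness below, and the exact offset values are not the point.

    import torch
    import torch.distributed.tensor as dt
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Shard

    # Assumes the process group / launcher environment is already set up.
    device_mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))

    state_before = torch.cuda.get_rng_state()
    t = dt.empty((2, 3), device_mesh=device_mesh, placements=[Shard(0)])
    torch.nn.init.uniform_(t)  # DTensor advances only its internal copy of the default RNG

    # The surprising part: the user-visible default CUDA RNG state has not moved.
    assert torch.equal(state_before, torch.cuda.get_rng_state())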
Example: test_init_with_user_generator
First, the test uses the passed-in RNG to sample twice in a row in a loop, and ensures that both loop iterations match the values from the default RNG, which was seeded with the same value. This part already works after #159933 lands. Then, it runs into trouble trying to establish a sane UX if someone re-seeds the RNG, or if someone inspects the state of the passed-in RNG.
def test_init_with_user_generator(self):
    device_mesh = self.build_device_mesh()
    torch.manual_seed(42)
    rng = torch.Generator(device="cuda").manual_seed(42)
    t1 = torch.distributed.tensor.empty(
        (2, 3), device_mesh=device_mesh, placements=[Shard(0)]
    )
    t2 = torch.distributed.tensor.empty(
        (2, 3), device_mesh=device_mesh, placements=[Shard(0)]
    )
    for i in range(2):
        print(f"{i=}")
        # run a second time, to make sure that `rng`'s offset-state is advancing on the second usage
        torch.nn.init.uniform_(t1, 0.0, 1.0)
        torch.nn.init.uniform_(t2, 0.0, 1.0, rng)
        self.assertEqual(t1.full_tensor(), t2.full_tensor(), f"Failed at {i=}")

    # Problem area:
    # (1)
    # torch.distributed.tensor._random._rng_tracker._manual_seed(55)
    # torch.manual_seed(55)
    # (2)
    # rng.manual_seed(55)
    torch.nn.init.uniform_(t1, 0.0, 1.0)
    torch.nn.init.uniform_(t2, 0.0, 1.0, rng)
    self.assertEqual(t1.full_tensor(), t2.full_tensor())
Problems in this example (sketched in code after this list):
- If someone wants to affect the state of DTensor's RNG after init, they have to know to call the hidden manual_seed API torch.distributed.tensor._random._rng_tracker._manual_seed instead of calling torch.manual_seed, which worked in the first place.
- If someone wanted to re-seed the passed-in RNG object, I made that ineffective in [DTensor] Support user-supplied Generator for random ops #159933 so that it is consistent with the behavior of the default RNG. I also made it so rng.get_state() is consistent with torch.cuda.get_rng_state() (e.g. both will show the initial state, not reflecting any activity from DTensor).
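To make the two pain points concrete, here is roughly what they look like from the user's side after #159933, reusing t1, t2, and rng from the test above. This is a sketch of the current behavior, not a new API.

    # (1) After DTensor has initialized its RNG tracker, torch.manual_seed no longer reaches
    #     DTensor's copy of the default RNG; only the private tracker API does.
    torch.manual_seed(55)                                            # not seen by DTensor random ops
    torch.distributed.tensor._random._rng_tracker._manual_seed(55)   # hidden/private, but required

    # (2) Re-seeding the passed-in generator is likewise ignored by DTensor ops, and
    #     rng.get_state() never shows the offset advancement from DTensor's sampling.
    rng.manual_seed(55)                        # ineffective for DTensor's use of rng
    torch.nn.init.uniform_(t2, 0.0, 1.0, rng)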
Proposal:
We could make DTensor write its state advancement back to the global RNG, so that (1) does not require a hidden seed API and DTensor always uses the latest RNG state from torch. We could then easily also make the passed-in RNG's state get updated whenever it's used.
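Under this proposal, the "Problem area" of the test above could behave like this instead. This is a sketch of the intended semantics only, reusing t1, t2, and rng from the test; nothing here is implemented yet.

    # Proposed: DTensor writes its offset advancement back to the global RNG state and to any
    # passed-in generator, so the ordinary seeding/inspection APIs behave as users expect.
    torch.manual_seed(55)   # would be enough to re-seed DTensor's default-RNG path
    rng.manual_seed(55)     # would take effect for DTensor ops that receive rng

    torch.nn.init.uniform_(t1, 0.0, 1.0)
    torch.nn.init.uniform_(t2, 0.0, 1.0, rng)
    self.assertEqual(t1.full_tensor(), t2.full_tensor())  # both paths were re-seeded identically

    # torch.cuda.get_rng_state() and rng.get_state() would now also reflect the advanced offsets.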
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @d4l3k @pragupta @svekars @sekyondaMeta @AlannaBurke @tianyu-l @XilunWu