
[DTensor] Decide / Document RNG semantics #159991

@wconstab

Description

Following the work in #159933, DTensor will use a generator passed as a kwarg to random ops instead of its own built-in generator.

This exposed a question of semantics, which I want to hear opinions on before implementing and documenting the behavior.

Question
If a user passes their own RNG to DTensor random ops, should the state-offset advancement be visible to the user on that RNG object?

My intuition is yes: it is very surprising if the passed-in RNG object does not advance after being used.
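
For reference, this matches how plain (non-DTensor) random ops already behave: a generator passed to the op advances, and the advancement is visible on the object. A minimal sketch with ordinary tensors (a CPU generator is shown for brevity; CUDA generators behave the same way):

    import torch

    g = torch.Generator().manual_seed(42)
    t = torch.empty(2, 3)

    state_before = g.get_state().clone()
    t.uniform_(0.0, 1.0, generator=g)  # the op consumes randomness from g
    state_after = g.get_state()

    # The usage is visible on the passed-in generator: its state advanced.
    assert not torch.equal(state_before, state_after)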

However, implementing it this way introduces another discrepancy: DTensor already deviates from this expectation for the default RNG. Currently, DTensor keeps an alternate copy of the original default RNG and only advances that copy, so the globally visible default RNG state never reflects DTensor's random ops.
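
Concretely, the deviation is observable roughly like this (a sketch; dt is assumed to be a DTensor sharded over a CUDA device mesh, as in the test below):

    # `dt` is assumed to be a DTensor on a CUDA device mesh (see the test below).
    state_before = torch.cuda.get_rng_state()
    torch.nn.init.uniform_(dt, 0.0, 1.0)  # DTensor random op using the default RNG path
    state_after = torch.cuda.get_rng_state()

    # Today DTensor only advances its private copy of the default RNG, so the
    # globally visible CUDA RNG state does not change.
    assert torch.equal(state_before, state_after)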

Example: test_init_with_user_generator. First, the test uses the passed-in RNG to sample twice in a row in a loop, and ensures that both loop iterations match the values produced via the default RNG, which was seeded with the same value. This part already works fine after #159933 lands.

Then it runs into trouble trying to establish a sane UX if someone re-seeds the RNG, or if someone inspects the state of the passed-in RNG.

    def test_init_with_user_generator(self):
        device_mesh = self.build_device_mesh()
        torch.manual_seed(42)
        rng = torch.Generator(device="cuda").manual_seed(42)
        t1 = torch.distributed.tensor.empty(
            (2, 3), device_mesh=device_mesh, placements=[Shard(0)]
        )
        t2 = torch.distributed.tensor.empty(
            (2, 3), device_mesh=device_mesh, placements=[Shard(0)]
        )
        for i in range(2):
            print(f"{i=}")
            # the second iteration verifies that `rng`'s offset-state advances after the first usage
            torch.nn.init.uniform_(t1, 0.0, 1.0)
            torch.nn.init.uniform_(t2, 0.0, 1.0, rng)
            self.assertEqual(t1.full_tensor(), t2.full_tensor(), f"Failed at {i=}")

        # Problem area:
        # (1)
        # torch.distributed.tensor._random._rng_tracker._manual_seed(55)
        # torch.manual_seed(55)
        # (2)
        # rng.manual_seed(55)
        torch.nn.init.uniform_(t1, 0.0, 1.0)
        torch.nn.init.uniform_(t2, 0.0, 1.0, rng)
        self.assertEqual(t1.full_tensor(), t2.full_tensor())

Problems in this example:

  1. If someone wants to affect the state of DTensor's RNG after init, they have to know to call the hidden seeding API torch.distributed.tensor._random._rng_tracker._manual_seed instead of torch.manual_seed, which is what worked for the initial seeding (see the sketch after this list).

  2. If someone wants to re-seed the passed-in RNG object, I made that ineffective in [DTensor] Support user-supplied Generator for random ops #159933 so that it is consistent with the behavior of the default RNG. I also made rng.get_state() consistent with torch.cuda.get_rng_state(): both show the initial state and do not reflect any activity from DTensor.
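
A sketch of both problems together, reusing t1, t2, and rng from the test above (these calls illustrate today's behavior after #159933; they are not a proposed API):

    # Problem 1: torch.manual_seed does not reach DTensor's internal copy of the
    # default RNG; only the hidden tracker API re-seeds it.
    torch.manual_seed(55)                                           # not seen by DTensor ops
    torch.distributed.tensor._random._rng_tracker._manual_seed(55)  # hidden, but effective

    # Problem 2: re-seeding the passed-in generator is ineffective, and its
    # reported state does not reflect DTensor's use of it.
    rng.manual_seed(55)  # no effect on subsequent DTensor random ops
    rng.get_state()      # still the initial state, matching torch.cuda.get_rng_state()

    torch.nn.init.uniform_(t1, 0.0, 1.0)
    torch.nn.init.uniform_(t2, 0.0, 1.0, generator=rng)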

Proposal:
We could make DTensor update the global RNG state, so that (1) no longer requires a hidden seed API and DTensor always uses the latest RNG state from torch. We could then also easily make the passed-in RNG state get updated whenever it's used.
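
Under this proposal, the observable semantics would roughly be as follows (a sketch of the intended behavior, not an existing implementation; t1, t2, and rng are reused from the test above):

    cuda_before = torch.cuda.get_rng_state()
    gen_before = rng.get_state()

    torch.nn.init.uniform_(t1, 0.0, 1.0)                 # reads and advances the global RNG state
    torch.nn.init.uniform_(t2, 0.0, 1.0, generator=rng)  # reads and advances the passed-in RNG

    # Both states would reflect the DTensor ops, and torch.manual_seed(...) or
    # rng.manual_seed(...) would take effect for subsequent ops without any
    # hidden DTensor-specific seeding API.
    assert not torch.equal(cuda_before, torch.cuda.get_rng_state())
    assert not torch.equal(gen_before, rng.get_state())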

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @d4l3k @pragupta @svekars @sekyondaMeta @AlannaBurke @tianyu-l @XilunWu
