
[distributed][_composable/test_composability] RuntimeError: 'Attempted to flatten sharded dimension 1 #1881

@zxd1997066

Description


🐛 Describe the bug

Get wheels from https://github.com/intel/torch-xpu-ops/actions/runs/16365081838, then set up the repo and run the failing tests:

```shell
git clone -b xiangdong/dist_upstream_p2 https://github.com/zxd1997066/pytorch.git
cd pytorch
pip install pytest expecttest
pip install -r requirements.txt

pytest -v test/distributed/_composable/test_composability/test_2d_composability.py::TestFullyShard2DTraining::test_train_parity_2d_transformer
pytest -v test/distributed/_composable/test_composability/test_2d_composability.py::TestFullyShard2DTraining::test_train_parity_2d_transformer_checkpoint_resume
```
```
Traceback (most recent call last):
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/testing/_internal/common_distributed.py", line 853, in run_test
    getattr(self, test_name)()
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/testing/_internal/common_distributed.py", line 707, in wrapper
    fn()
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 3225, in wrapper
    method(*args, **kwargs)
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/testing/_internal/common_distributed.py", line 221, in wrapper
    return func(*args, **kwargs)
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/testing/_internal/distributed/checkpoint_utils.py", line 153, in wrapper
    func(self, *args, **kwargs)
  File "/home/sdp/xiangdong/xccl_upstream/pytorch/test/distributed/_composable/test_composability/test_2d_composability.py", line 312, in test_train_parity_2d_transformer_checkpoint_resume
    self.run_subtests(
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/testing/_internal/common_fsdp.py", line 1188, in run_subtests
    return run_subtests(self, *args, **kwargs)
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/testing/_internal/common_distributed.py", line 1136, in run_subtests
    test_fn(*test_args, **test_kwargs, **subtest_kwargs)
  File "/home/sdp/xiangdong/xccl_upstream/pytorch/test/distributed/_composable/test_composability/test_2d_composability.py", line 362, in _test_train_parity_2d_transformer_checkpoint_resume
    loss_no_cp1 = train_step(model_no_cp, optim_no_cp, inp)
  File "/home/sdp/xiangdong/xccl_upstream/pytorch/test/distributed/_composable/test_composability/test_2d_composability.py", line 335, in train_step
    loss = _model(_inp).sum()
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1827, in inner
    result = forward_call(*args, **kwargs)
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 223, in forward
    h = layer(h)
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1827, in inner
    result = forward_call(*args, **kwargs)
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 185, in forward
    h = x + self.attention(self.attention_norm(x))
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 150, in forward
    output = F.scaled_dot_product_attention(
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/_compile.py", line 53, in inner
    return disable_fn(*args, **kwargs)
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1004, in _fn
    return fn(*args, **kwargs)
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/distributed/tensor/_api.py", line 358, in __torch_dispatch__
    return DTensor._op_dispatcher.dispatch(
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/distributed/tensor/_dispatch.py", line 154, in dispatch
    self.sharding_propagator.propagate(op_info)
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 310, in propagate
    OutputSharding, self.propagate_op_sharding(op_info.schema)
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 46, in __call__
    return self.cache(*args, **kwargs)
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 330, in propagate_op_sharding_non_cached
    op_strategy = self.op_strategy_funcs[op_schema.op](strategy_schema)
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/distributed/tensor/_ops/_view_ops.py", line 646, in reshape_strategy
    input_tgt_placements, output_placements = propagate_shape_and_sharding(
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/distributed/tensor/_ops/_view_ops.py", line 596, in propagate_shape_and_sharding
    in_dim = get_in_dim_to_shard(cmd)
  File "/home/sdp/miniforge-pypy3/envs/xccl_ww27/lib/python3.10/site-packages/torch/distributed/tensor/_ops/_view_ops.py", line 532, in get_in_dim_to_shard
    raise RuntimeError(
RuntimeError: ('Attempted to flatten sharded dimension 1, \n\nTo execute this test, run the following from the base repo dir:\n    python test/distributed/_composable/test_composability/test_2d_composability.py TestFullyShard2DTraining.test_train_parity_2d_transformer_checkpoint_resume\n\nThis message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0', 'but only the leftmost dim of a Flatten can be sharded.')
```

Versions

https://github.com/zxd1997066/pytorch/tree/xiangdong/dist_upstream_p2

Labels

bug (Something isn't working), module: distributed (For distributed feature issue)
