
Commit 41d2a95

pytorchmergebot authored and xuhancn committed
Revert "Add meta func for scaled mm (pytorch#112609)"
This reverts commit 75174c3. Reverted pytorch#112609 on behalf of https://github.com/huydhn due to Sorry for reverting this change, but it is failing ROCm jobs https://hud.pytorch.org/pytorch/pytorch/commit/75174c379712433af1ff810b36e34573b3d2587e ([comment](pytorch#112609 (comment)))
1 parent f0472f6 commit 41d2a95

File tree

4 files changed: +2 additions, -95 deletions

aten/src/ATen/native/cuda/Blas.cpp (1 addition, 1 deletion)

@@ -753,7 +753,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
   TORCH_CHECK(amax.scalar_type() == kFloat, "amax must be a float scalar");
   TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
-  TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
+  TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat1.scalar_type());
   // Type restrictions imposed by CuBLASLt as of CUDA-12.1
   TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
               "Multiplication of two Float8_e5m2 matrices is not supported");

torch/_meta_registrations.py (0 additions, 50 deletions)

@@ -5179,56 +5179,6 @@ def meta__scaled_dot_product_efficient_backward(
     return grad_q, grad_k, grad_v, grad_bias
 
 
-@register_meta([aten._scaled_mm.default])
-def meta_scaled_mm(
-    self: torch.Tensor,
-    mat2: torch.Tensor,
-    bias: Optional[torch.Tensor] = None,
-    out_dtype: Optional[torch.dtype] = None,
-    scale_a: Optional[torch.Tensor] = None,
-    scale_b: Optional[torch.Tensor] = None,
-    scale_result: Optional[torch.Tensor] = None,
-    use_fast_accum: bool = False,
-):
-    def is_row_major(stride):
-        return stride[0] > stride[1] and stride[1] == 1
-
-    def is_col_major(shape, stride):
-        return stride[0] == 1 and stride[1] == shape[0]
-
-    def is_fp8_type(dtype):
-        return dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
-
-    torch._check(
-        self.dim() == 2 and mat2.dim() == 2,
-        lambda: f"Inputs must be 2D but got self.dim()={self.dim()} and mat2.dim()={mat2.dim()}",
-    )
-    torch._check(
-        is_row_major(self.stride()),
-        lambda: "self must be row_major",
-    )
-    torch._check(
-        is_col_major(mat2.shape, mat2.stride()),
-        lambda: "mat2 must be col_major",
-    )
-    torch._check(
-        self.size(1) % 16 == 0,
-        lambda: f"Expected self.size(0) to be divisible by 16, but got self.size(1)={self.size(1)}",
-    )
-    torch._check(
-        mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0,
-        lambda: f"Expected both dimensions of mat2 to be divisble by 16 but got {mat2.shape}",
-    )
-    torch._check(
-        is_fp8_type(self.dtype) and is_fp8_type(mat2.dtype),
-        lambda: f"Expected both inputs to be fp8 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}",
-    )
-    _out_dtype = out_dtype if out_dtype is not None else self.dtype
-    return torch.empty(
-        self.size(0), mat2.size(1), dtype=_out_dtype, device=self.device
-    ), torch.empty((), dtype=torch.float32, device=self.device)
-
-
 @register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out])
 @out_wrapper()
 def meta_scatter_reduce_two(self, dim, index, src, reduce, include_self=True):
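
A meta registration like the one removed above is what lets shape and dtype propagation run for torch._scaled_mm without launching a real CUDA kernel. The snippet below is a minimal illustrative sketch, not part of this commit: it assumes a CUDA-enabled build in which the meta_scaled_mm registration is still present (i.e. the pre-revert state); after this revert, tracing torch._scaled_mm through fake tensors has no meta kernel to dispatch to.

# Minimal sketch (illustration only, assumes the reverted meta registration is present
# and a CUDA-enabled build).
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    # mat1 row-major, mat2 column-major, fp8, with 16-divisible sizes,
    # satisfying the torch._check calls in the removed meta function.
    mat1 = torch.empty(32, 16, dtype=torch.float8_e4m3fn, device="cuda")
    mat2 = torch.empty(48, 16, dtype=torch.float8_e4m3fn, device="cuda").t()
    out, amax = torch._scaled_mm(mat1, mat2)  # only shapes/dtypes are computed
    print(out.shape, out.dtype, amax.dtype)   # torch.Size([32, 48]), float8_e4m3fn, float32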

torch/testing/_creation.py (0 additions, 13 deletions)

@@ -11,7 +11,6 @@
 
 _INTEGRAL_TYPES = [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
 _FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
-_FLOATING_8BIT_TYPES = [torch.float8_e4m3fn, torch.float8_e5m2]
 _COMPLEX_TYPES = [torch.complex32, torch.complex64, torch.complex128]
 _BOOLEAN_OR_INTEGRAL_TYPES = [torch.bool, *_INTEGRAL_TYPES]
 _FLOATING_OR_COMPLEX_TYPES = [*_FLOATING_TYPES, *_COMPLEX_TYPES]

@@ -218,18 +217,6 @@ def clamp(a: float, l: float, h: float) -> float:
         _uniform_random_(
             torch.view_as_real(result) if dtype in _COMPLEX_TYPES else result, low, high
         )
-    elif dtype in _FLOATING_8BIT_TYPES:
-        low, high = modify_low_high(
-            low,
-            high,
-            lowest_inclusive=torch.finfo(dtype).min,
-            highest_exclusive=torch.finfo(dtype).max,
-            default_low=-9,
-            default_high=9,
-        )
-        result = torch.empty(shape, device=device, dtype=torch.float32)
-        _uniform_random_(result, low, high)
-        result = result.to(dtype)
     else:
         raise TypeError(
             f"The requested dtype '{dtype}' is not supported by torch.testing.make_tensor()."

torch/testing/_internal/common_methods_invocations.py (1 addition, 31 deletions)

@@ -26,7 +26,7 @@
     skipCPUIfNoMklSparse,
     toleranceOverride, tol)
 from torch.testing._internal.common_cuda import (
-    SM53OrLater, SM60OrLater, SM80OrLater, SM90OrLater, with_tf32_off, TEST_CUDNN,
+    SM53OrLater, SM60OrLater, SM80OrLater, with_tf32_off, TEST_CUDNN,
     _get_torch_cuda_version, _get_torch_rocm_version,
 )
 from torch.testing._internal.common_utils import (

@@ -8176,25 +8176,6 @@ def error_inputs_triplet_margin_loss(op_info, device, **kwargs):
     yield ErrorInput(SampleInput(input, args=args, kwargs=kwargs),
                      error_type=error_type, error_regex=error_regex)
 
-def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs):
-    make_mat_e4m3 = partial(make_tensor, device=device, dtype=torch.float8_e4m3fn, requires_grad=requires_grad)
-    make_mat_e5m2 = partial(make_tensor, device=device, dtype=torch.float8_e5m2, requires_grad=requires_grad)
-    M, N, K = 15, 32, 16
-    samples = []
-    # two e4m3
-    mat1 = make_mat_e4m3((M, K))
-    mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
-    samples.append(SampleInput(mat1, mat2))
-    # mat1 e4m3 mat2 e5m2
-    mat1 = make_mat_e4m3((M, K))
-    mat2 = make_mat_e5m2((K, N)).t().contiguous().t()
-    samples.append(SampleInput(mat1, mat2))
-    # mat1 e5m2 mat2 e4m3
-    mat1 = make_mat_e5m2((M, K))
-    mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
-    samples.append(SampleInput(mat1, mat2))
-
-    yield from samples
 
 def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs):
     make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

@@ -13709,17 +13690,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
                 'TestUnaryUfuncs', device_type='cuda',
             ), ],
     ),
-    OpInfo(
-        'torch._scaled_mm',
-        sample_inputs_func=sample_inputs_scaled_mm,
-        dtypes=empty_types(),
-        dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,),
-        supports_out=True,
-        supports_forward_ad=False,
-        supports_autograd=False,
-        decorators=[skipCUDAIf(not SM90OrLater, 'Requires CUDA SM >= 9.0')],
-        skips=()
-    ),
     OpInfo(
         'nn.functional.scaled_dot_product_attention',
         op=lambda *args, **kwargs:
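
A side note on the removed sample_inputs_scaled_mm helper: the .t().contiguous().t() idiom builds a column-major second operand of the original shape, matching the is_col_major check in the removed meta function. The small sketch below is illustrative only (plain float tensors rather than fp8) and is not code from the commit:

# Illustrative sketch: why .t().contiguous().t() yields a column-major tensor
# with the original shape.
import torch

K, N = 16, 32
a = torch.randn(K, N)                        # row-major: stride (N, 1)
b = torch.randn(K, N).t().contiguous().t()   # column-major: stride (1, K)

print(a.shape, a.stride())  # torch.Size([16, 32]) (32, 1)
print(b.shape, b.stride())  # torch.Size([16, 32]) (1, 16)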
