Skip to content

Commit 78f5e4c

Browse files
committed
Add test cases
1 parent 894857e commit 78f5e4c

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

test/sparsity/test_sparse_api.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,31 @@ def test_sparse(self, compile):
267267

268268
torch.testing.assert_close(reference, sparse_result, rtol=1e-1, atol=1e-1)
269269

270+
# TODO: Remove this test once the deprecated API has been removed
271+
def test_sparse_deprecated(self):
272+
import sys
273+
import warnings
274+
275+
# We need to clear the cache to force re-importing and trigger the warning again.
276+
modules_to_clear = [
277+
"torchao.dtypes.uintx.block_sparse_layout",
278+
"torchao.dtypes",
279+
]
280+
for mod in modules_to_clear:
281+
if mod in sys.modules:
282+
del sys.modules[mod]
283+
284+
with warnings.catch_warnings(record=True) as w:
285+
warnings.simplefilter("always") # Ensure all warnings are captured
286+
self.assertTrue(
287+
any(
288+
issubclass(warning.category, DeprecationWarning)
289+
and "BlockSparseLayout" in str(warning.message)
290+
for warning in w
291+
),
292+
f"Expected deprecation warning for BlockSparseLayout, got: {[str(w.message) for w in w]}",
293+
)
294+
270295

271296
common_utils.instantiate_parametrized_tests(TestSemiStructuredSparse)
272297
common_utils.instantiate_parametrized_tests(TestQuantSemiSparse)

0 commit comments

Comments
 (0)