diff --git a/test/sparsity/test_fast_sparse_training.py b/test/sparsity/test_fast_sparse_training.py index 424306f897..a9f57bb5a5 100644 --- a/test/sparsity/test_fast_sparse_training.py +++ b/test/sparsity/test_fast_sparse_training.py @@ -15,22 +15,10 @@ swap_linear_with_semi_sparse_linear, swap_semi_sparse_linear_with_linear, ) +from torchao.testing.model_architectures import ToyTwoLinearModel from torchao.utils import is_fbcode -class ToyModel(nn.Module): - def __init__(self): - super().__init__() - self.linear1 = nn.Linear(128, 256, bias=False) - self.linear2 = nn.Linear(256, 128, bias=False) - - def forward(self, x): - x = self.linear1(x) - x = torch.nn.functional.relu(x) - x = self.linear2(x) - return x - - class TestRuntimeSemiStructuredSparsity(TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(is_fbcode(), "broken in fbcode") @@ -41,7 +29,7 @@ def test_runtime_weight_sparsification(self): input = torch.rand((128, 128)).half().cuda() grad = torch.rand((128, 128)).half().cuda() - model = ToyModel().half().cuda() + model = ToyTwoLinearModel(128, 256, 128, device="cuda", dtype=torch.float16) model_c = copy.deepcopy(model) for name, mod in model.named_modules(): @@ -89,7 +77,7 @@ def test_runtime_weight_sparsification_compile(self): input = torch.rand((128, 128)).half().cuda() grad = torch.rand((128, 128)).half().cuda() - model = ToyModel().half().cuda() + model = ToyTwoLinearModel(128, 256, 128, device="cuda", dtype=torch.float16) model_c = copy.deepcopy(model) for name, mod in model.named_modules(): diff --git a/torchao/testing/model_architectures.py b/torchao/testing/model_architectures.py index 8f41a8464c..4100a3cd76 100644 --- a/torchao/testing/model_architectures.py +++ b/torchao/testing/model_architectures.py @@ -11,14 +11,72 @@ import torch.nn.functional as F +class ToySingleLinearModel(torch.nn.Module): + def __init__( + self, + input_dim, + output_dim, + dtype, + device, + has_bias=False, + ): + super().__init__() + self.dtype = dtype + self.device = device + self.linear1 = torch.nn.Linear( + input_dim, output_dim, bias=has_bias, dtype=dtype, device=device + ) + + def example_inputs(self, batch_size=1): + return ( + torch.randn( + batch_size, + self.linear1.in_features, + dtype=self.dtype, + device=self.device, + ), + ) + + def forward(self, x): + x = self.linear1(x) + return x + + # TODO: Refactor torchao and tests to use these models -class ToyLinearModel(torch.nn.Module): - def __init__(self, k=64, n=32, dtype=torch.bfloat16): +class ToyTwoLinearModel(torch.nn.Module): + def __init__( + self, + input_dim, + hidden_dim, + output_dim, + dtype, + device, + has_bias=False, + ): super().__init__() - self.linear1 = torch.nn.Linear(k, n, bias=False).to(dtype) + self.dtype = dtype + self.device = device + self.linear1 = torch.nn.Linear( + input_dim, hidden_dim, bias=has_bias, dtype=dtype, device=device + ) + self.linear2 = torch.nn.Linear( + hidden_dim, output_dim, bias=has_bias, dtype=dtype, device=device + ) + + # Note: Tiny-GEMM kernel only uses BF16 inputs + def example_inputs(self, batch_size=1): + return ( + torch.randn( + batch_size, + self.linear1.in_features, + dtype=self.dtype, + device=self.device, + ), + ) def forward(self, x): x = self.linear1(x) + x = self.linear2(x) return x @@ -179,8 +237,8 @@ def create_model_and_input_data( m, k, n (int): dimensions of the model and input data """ if model_type == "linear": - model = ToyLinearModel(k, n, high_precision_dtype).to(device) - input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype) + model = ToySingleLinearModel(k, n, device=device, dtype=high_precision_dtype) + input_data = model.example_inputs(batch_size=m)[0] elif "ln_linear" in model_type: # Extract activation type from model_type string match = re.search(r"ln_linear_?(\w+)?", model_type)