diff --git a/benchmarks/benchmark_aq.py b/benchmarks/benchmark_aq.py
index b391f15daf..02c9d6abd4 100644
--- a/benchmarks/benchmark_aq.py
+++ b/benchmarks/benchmark_aq.py
@@ -16,32 +16,7 @@
     _replace_with_custom_fn_if_matches_filter,
     quantize_,
 )
-
-
-class ToyLinearModel(torch.nn.Module):
-    """Single linear for m * k * n problem size"""
-
-    def __init__(
-        self, m=64, n=32, k=64, has_bias=False, dtype=torch.float, device="cuda"
-    ):
-        super().__init__()
-        self.m = m
-        self.dtype = dtype
-        self.device = device
-        self.linear = torch.nn.Linear(k, n, bias=has_bias).to(
-            dtype=self.dtype, device=self.device
-        )
-
-    def example_inputs(self):
-        return (
-            torch.randn(
-                self.m, self.linear.in_features, dtype=self.dtype, device=self.device
-            ),
-        )
-
-    def forward(self, x):
-        x = self.linear(x)
-        return x
+from torchao.testing.model_architectures import ToySingleLinearModel


 def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass):
@@ -69,14 +44,26 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
 @torch.no_grad
-def _bench_quantized_tensor_subclass_perf(api, config, M, N, K):
-    m = ToyLinearModel(
+def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
+    if kwargs is None:
+        kwargs = {}
+
+    m = ToySingleLinearModel(
         M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda"
     ).eval()
     m_bf16 = copy.deepcopy(m)
+    m_ref = copy.deepcopy(m)
+    example_inputs = m.example_inputs(batch_size=M)
+
+    api(m, **kwargs)
+
+    # reference
     example_inputs = m.example_inputs()
-    api(m, config)  # Pass both model and config
+    res = m(*example_inputs)
+    ref = m_ref(*example_inputs)
+
+    assert torch.equal(res, ref)

     # perf comparison
     from torchao.utils import benchmark_model
@@ -95,6 +82,11 @@ def _bench_quantized_tensor_subclass_perf(api, config, M, N, K):
     benchmark_model(m, WARMUP, example_inputs)
     elapsed_time = benchmark_model(m, RUNS, example_inputs)

+    torch._dynamo.reset()
+    m_bf16 = torch.compile(m_bf16, mode="max-autotune", fullgraph=True)
+    benchmark_model(m_bf16, WARMUP, example_inputs)
+    bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)
+
     print(
         f"{(M, N, K)}: elapsed time: {elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}"
     )
diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py
index 35870a5e6b..3b3323f654 100644
--- a/test/dtypes/test_affine_quantized_float.py
+++ b/test/dtypes/test_affine_quantized_float.py
@@ -38,6 +38,7 @@
     choose_qparams_affine,
 )
 from torchao.quantization.quantize_.common import KernelPreference
+from torchao.testing.model_architectures import ToyTwoLinearModel
 from torchao.utils import (
     is_sm_at_least_89,
     is_sm_at_least_90,
@@ -48,18 +49,6 @@
 torch.manual_seed(0)


-class ToyLinearModel(torch.nn.Module):
-    def __init__(self, in_features, out_features):
-        super().__init__()
-        self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)
-        self.linear2 = torch.nn.Linear(out_features, in_features, bias=False)
-
-    def forward(self, x):
-        x = self.linear1(x)
-        x = self.linear2(x)
-        return x
-
-
 class TestAffineQuantizedFloat8Compile(InductorTestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
@@ -121,8 +110,8 @@ def test_fp8_linear_variants(
             ),
         }

-        # Create a linear layer with bfloat16 dtype
-        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
+        # Create a linear layer
+        model = ToyTwoLinearModel(K, N, K, device="cuda", dtype=dtype).eval()
         quantized_model = copy.deepcopy(model)

         factory = mode_map[mode]()
@@ -179,7 +168,9 @@ def test_per_row_with_float32(self):
             AssertionError,
             match="PerRow quantization only works for bfloat16 precision",
         ):
-            model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
+            model = ToyTwoLinearModel(
+                64, 64, 64, device="cuda", dtype=torch.float32
+            ).eval()
             quantize_(
                 model,
                 Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
             )
@@ -192,7 +183,7 @@
     @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
     def test_serialization(self, mode: str):
         # Create and quantize the model
-        model = ToyLinearModel(16, 32).to(device="cuda")
+        model = ToyTwoLinearModel(16, 32, 16, device="cuda", dtype=torch.float32)

         mode_map = {
             "dynamic": partial(
@@ -224,7 +215,9 @@ def test_serialization(self, mode: str):
         # Create a new model and load the state dict
         with torch.device("meta"):
-            new_model = ToyLinearModel(16, 32)
+            new_model = ToyTwoLinearModel(
+                16, 32, 16, device="cuda", dtype=torch.float32
+            )
         if mode == "static":
             quantize_(new_model, factory)
         new_model.load_state_dict(loaded_state_dict, assign=True)
@@ -266,7 +259,9 @@
     )
     def test_fp8_weight_dimension_warning(self):
         # Create model with incompatible dimensions (not multiples of 16)
-        model = ToyLinearModel(10, 25).cuda()  # 10x25 and 25x10 weights
+        model = ToyTwoLinearModel(
+            10, 25, 10, device="cuda", dtype=torch.float32
+        )  # 10x25 and 25x10 weights

         # Set up logging capture
         with self.assertLogs(
diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index dc58470526..36849fd4c5 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -69,6 +69,7 @@
 from torchao.quantization.utils import (
     compute_error as SQNR,
 )
+from torchao.testing.model_architectures import ToyTwoLinearModel
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
     benchmark_model,
@@ -1910,30 +1911,13 @@ def test_get_model_size_aqt(self, api, test_device, test_dtype):


 class TestBenchmarkModel(unittest.TestCase):
-    class ToyLinearModel(torch.nn.Module):
-        def __init__(self, m=64, n=32, k=64):
-            super().__init__()
-            self.linear1 = torch.nn.Linear(m, n, bias=False)
-            self.linear2 = torch.nn.Linear(n, k, bias=False)
-
-        def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
-            return (
-                torch.randn(
-                    batch_size, self.linear1.in_features, dtype=dtype, device=device
-                ),
-            )
-
-        def forward(self, x):
-            x = self.linear1(x)
-            x = self.linear2(x)
-            return x
-
     def run_benchmark_model(self, device):
         # params
-        dtype = torch.bfloat16
-        m = self.ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to(device)
+        m = ToyTwoLinearModel(
+            1024, 1024, 1024, device=device, dtype=torch.bfloat16
+        ).eval()
         m_bf16 = copy.deepcopy(m)
-        example_inputs = m.example_inputs(dtype=dtype, device=device)
+        example_inputs = m.example_inputs()
         m_bf16 = torch.compile(m_bf16, mode="max-autotune")
         num_runs = 1
         return benchmark_model(m_bf16, num_runs, example_inputs)
diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
index d7953f4ec3..83794c4551 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
+++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -25,6 +25,7 @@
 from torchao.quantization.quantize_.common import KernelPreference
 from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
 from torchao.quantization.utils import compute_error
+from torchao.testing.model_architectures import ToyTwoLinearModel
 from torchao.testing.utils import TorchAOIntegrationTestCase
 from torchao.utils import (
     _is_fbgemm_gpu_genai_available,
@@ -38,18 +39,6 @@
 torch._dynamo.config.cache_size_limit = 128


-class ToyLinearModel(torch.nn.Module):
-    def __init__(self, in_features, out_features):
-        super().__init__()
-        self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)
-        self.linear2 = torch.nn.Linear(out_features, in_features, bias=False)
-
-    def forward(self, x):
-        x = self.linear1(x)
-        x = self.linear2(x)
-        return x
-
-
 class ToyConvModel(torch.nn.Module):
     def __init__(
         self, dim, in_channels, out_channels, kernel_size, bias, padding, dtype, device
@@ -145,7 +134,7 @@ def test_fp8_linear_variants(
         input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")

         # Create a linear layer with bfloat16 dtype
-        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
+        model = ToyTwoLinearModel(K, N, K, device="cuda", dtype=dtype).eval()

         quantized_model = copy.deepcopy(model)

@@ -333,7 +322,7 @@ def test_kernel_preference_numerical_equivalence(self, granularity, sizes):
         dtype = torch.bfloat16
         input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
         # Create a linear layer with bfloat16 dtype
-        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
+        model = ToyTwoLinearModel(K, N, K, device="cuda", dtype=dtype).eval()

         # reference kernel preference and results
         # we are using KerenelPreference.TORCH as the reference
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index e1c6471b17..054e2979a3 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -64,6 +64,7 @@
 )
 from torchao.quantization.quant_primitives import MappingType
 from torchao.quantization.utils import compute_error
+from torchao.testing.model_architectures import ToyTwoLinearModel
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
     is_sm_at_least_89,
@@ -127,25 +128,6 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
         return model


-class ToyLinearModel(torch.nn.Module):
-    def __init__(self, m=64, n=32, k=64, bias=False):
-        super().__init__()
-        self.linear1 = torch.nn.Linear(m, n, bias=bias).to(torch.float)
-        self.linear2 = torch.nn.Linear(n, k, bias=bias).to(torch.float)
-
-    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
-        return (
-            torch.randn(
-                batch_size, self.linear1.in_features, dtype=dtype, device=device
-            ),
-        )
-
-    def forward(self, x):
-        x = self.linear1(x)
-        x = self.linear2(x)
-        return x
-
-
 def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass):
     def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
         """
@@ -172,8 +154,12 @@ class TestQuantFlow(TestCase):
         ["xpu"] if torch.xpu.is_available() else []
     )

+    def setUp(self):
+        self.device = "cpu"
+        self.dtype = torch.float32
+
     def test_dynamic_quant_gpu_singleline(self):
-        m = ToyLinearModel().eval()
+        m = ToyTwoLinearModel(64, 32, 64, device=self.device, dtype=self.dtype).eval()
         example_inputs = m.example_inputs()
         quantize_(m, Int8DynamicActivationInt8WeightConfig())
         m(*example_inputs)
@@ -187,7 +173,7 @@ def test_dynamic_quant_gpu_singleline(self):
     @unittest.skip("skipping for now due to torch.compile error")
     def test_dynamic_quant_gpu_unified_api_unified_impl(self):
         quantizer = XNNPackDynamicQuantizer()
-        m = ToyLinearModel().eval()
+        m = ToyTwoLinearModel(64, 32, 64, device=self.device, dtype=self.dtype).eval()
         example_inputs = m.example_inputs()
         m = quantizer.prepare(m)
         m = quantizer.convert(m)
@@ -204,7 +190,7 @@ def test_dynamic_quant_gpu_unified_api_unified_impl(self):
     )
     def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
         quantizer = TorchCompileDynamicQuantizer()
-        m = ToyLinearModel().eval()
+        m = ToyTwoLinearModel(64, 32, 64, device=self.device, dtype=self.dtype).eval()
         example_inputs = m.example_inputs()
         m = quantizer.quantize(m)
         quantized = m(*example_inputs)
@@ -215,7 +201,7 @@
     @unittest.skipIf(not torch.xpu.is_available(), "Need XPU available")
     @unittest.skipIf(not torch_version_at_least("2.8.0"), "only works for torch 2.8+")
     def test_int4_wo_quant_save_load(self):
-        m = ToyLinearModel().eval().cpu()
+        m = ToyTwoLinearModel(64, 32, 64, device=self.device, dtype=self.dtype).eval()

         def api(model):
             quantize_(model, Int4WeightOnlyConfig(layout=Int4XPULayout(), version=1))
@@ -230,7 +216,7 @@ def api(model):
             f.seek(0)
             state_dict = torch.load(f)

-        m2 = ToyLinearModel().eval().cpu()
+        m2 = ToyTwoLinearModel(64, 32, 64, device=self.device, dtype=self.dtype).eval()
         api(m2)
         m2.load_state_dict(state_dict)
@@ -242,7 +228,7 @@
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_int8_wo_quant_save_load(self):
-        m = ToyLinearModel().eval().cpu()
+        m = ToyTwoLinearModel(64, 32, 64, device=self.device, dtype=self.dtype).eval()

         def api(model):
             quantize_(model, Int8WeightOnlyConfig())
@@ -257,7 +243,7 @@ def api(model):
             f.seek(0)
             state_dict = torch.load(f)

-        m2 = ToyLinearModel().eval().cpu()
+        m2 = ToyTwoLinearModel(64, 32, 64, device=self.device, dtype=self.dtype).eval()
         api(m2)
         m2.load_state_dict(state_dict)
@@ -274,7 +260,7 @@ def test_8da4w_quantizer(self):
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

         quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
-        m = ToyLinearModel().eval()
+        m = ToyTwoLinearModel(64, 32, 64, device=self.device, dtype=self.dtype).eval()
         example_inputs = m.example_inputs()
         m = quantizer.quantize(m)
         assert isinstance(m.linear1, Int8DynActInt4WeightLinear)
@@ -286,7 +272,9 @@ def test_8da4w_quantizer_linear_bias(self):
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

         quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
-        m = ToyLinearModel(bias=True).eval()
+        m = ToyTwoLinearModel(
+            64, 32, 64, device=self.device, dtype=self.dtype, has_bias=True
+        ).eval()
         example_inputs = m.example_inputs()
         m = quantizer.quantize(m)
         assert isinstance(m.linear1, Int8DynActInt4WeightLinear)
@@ -404,7 +392,7 @@ def test_eval_wrapper_llama3(self):
     )
     def test_quantized_tensor_subclass_8da4w(self, mapping_type):
         group_size = 32
-        m = ToyLinearModel().eval()
+        m = ToyTwoLinearModel(64, 32, 64, device=self.device, dtype=self.dtype).eval()
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs()
         quantize_(
@@ -440,9 +428,11 @@
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_save_load(self):
-        m = ToyLinearModel().eval().to(torch.bfloat16)
+        m = ToyTwoLinearModel(
+            64, 32, 64, device=self.device, dtype=torch.bfloat16
+        ).eval()
         m_copy = copy.deepcopy(m)
-        example_inputs = m.example_inputs(dtype=torch.bfloat16)
+        example_inputs = m.example_inputs()

         quantize_(m, Int8WeightOnlyConfig())
         ref = m(*example_inputs)
@@ -458,8 +448,10 @@ def test_quantized_tensor_subclass_save_load(self):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_int8wo_quantized_model_to_device(self):
-        m = ToyLinearModel().eval().to(torch.bfloat16)
-        example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cpu")
+        m = ToyTwoLinearModel(
+            64, 32, 64, device=self.device, dtype=torch.bfloat16
+        ).eval()
+        example_inputs = m.example_inputs()

         quantize_(m, Int8WeightOnlyConfig())
         ref = m(*example_inputs)
@@ -471,18 +463,20 @@
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_save_load_map_location(self):
-        m = ToyLinearModel().eval().to(dtype=torch.bfloat16, device="cuda")
-        example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
+        m = ToyTwoLinearModel(64, 32, 64, dtype=torch.bfloat16, device="cuda").eval()
+        example_inputs = m.example_inputs()

         quantize_(m, Int8WeightOnlyConfig())
         ref = m(*example_inputs)

         with tempfile.NamedTemporaryFile() as f:
             torch.save(m.state_dict(), f)
             f.seek(0)
-            state_dict = torch.load(f.name, map_location="cpu", mmap=True)
+            state_dict = torch.load(f.name, map_location=self.device, mmap=True)

             with torch.device("meta"):
-                m_copy = ToyLinearModel().eval()
+                m_copy = ToyTwoLinearModel(
+                    64, 32, 64, device=self.device, dtype=self.dtype
+                ).eval()
             m_copy.load_state_dict(state_dict, assign=True)
             m_copy.to(dtype=torch.bfloat16, device="cuda")
@@ -498,13 +492,15 @@ def reset_memory():
             torch.cuda.reset_peak_memory_stats()

         reset_memory()
-        m = ToyLinearModel()
+
+        m = ToyTwoLinearModel(64, 32, 64, device=self.device, dtype=self.dtype)
         quantize_(m.to(device="cuda"), Int8WeightOnlyConfig())
         memory_baseline = torch.cuda.max_memory_allocated()

         del m
         reset_memory()
-        m = ToyLinearModel()
+
+        m = ToyTwoLinearModel(64, 32, 64, device=self.device, dtype=self.dtype)
         quantize_(m, Int8WeightOnlyConfig(), device="cuda")
         memory_streaming = torch.cuda.max_memory_allocated()

         for param in m.parameters():
             assert param.is_cuda
         self.assertLess(memory_streaming, memory_baseline)

-    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
+    @common_utils.parametrize("dtype", [torch.float32, torch.bfloat16, torch.half])
     @common_utils.parametrize("x_dim", [2, 3])
     @common_utils.parametrize("use_hqq", [True, False])
     def test_int4wo_cpu(self, dtype, x_dim, use_hqq):
-        device = "cpu"
-        m = ToyLinearModel().eval().to(dtype).to(device)
-        example_inputs = m.example_inputs(dtype=dtype, device=device)
+        m = ToyTwoLinearModel(64, 32, 64, device=self.device, dtype=dtype).eval()
+        example_inputs = m.example_inputs()
         if x_dim == 3:
             example_inputs = (example_inputs[0].unsqueeze(0),)
@@ -611,8 +606,9 @@ def test_module_fqn_to_config_default(self):
         config1 = Int4WeightOnlyConfig(group_size=32, version=1)
         config2 = Int8WeightOnlyConfig()
         config = ModuleFqnToConfig({"_default": config1, "linear2": config2})
-        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
-        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+
+        model = ToyTwoLinearModel(64, 32, 64, device="cuda", dtype=torch.bfloat16)
+        example_inputs = model.example_inputs()
         quantize_(model, config, filter_fn=None)
         model(*example_inputs)
         assert isinstance(model.linear1.weight, AffineQuantizedTensor)
@@ -625,9 +621,10 @@ def test_module_fqn_to_config_module_name(self):
         config1 = Int4WeightOnlyConfig(group_size=32, version=1)
         config2 = Int8WeightOnlyConfig()
         config = ModuleFqnToConfig({"linear1": config1, "linear2": config2})
-        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
-        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        model = ToyTwoLinearModel(64, 32, 64, device="cuda", dtype=torch.bfloat16)
+        example_inputs = model.example_inputs()
         quantize_(model, config, filter_fn=None)
+
         model(*example_inputs)
         assert isinstance(model.linear1.weight, AffineQuantizedTensor)
         assert isinstance(model.linear1.weight._layout, TensorCoreTiledLayout)
@@ -639,11 +636,13 @@ def test_module_fqn_to_config_regex_basic(self):
         config1 = Int4WeightOnlyConfig(
             group_size=32, int4_packing_format="tile_packed_to_4d"
         )
-        config = ModuleFqnToConfig({"re:linear.": config1})
-        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
-        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        config = ModuleFqnToConfig({"re:linear.*": config1})
+        model = ToyTwoLinearModel(
+            64, 32, 64, device=self.device, dtype=torch.bfloat16
+        ).eval()
         quantize_(model, config, filter_fn=None)
-        model(*example_inputs)
+        model(*model.example_inputs())
+
         assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
         assert isinstance(model.linear2.weight, Int4TilePackedTo4dTensor)
@@ -656,11 +655,11 @@ def test_module_fqn_to_config_regex_precedence(self):
             group_size=32, int4_packing_format="tile_packed_to_4d"
         )
         config2 = IntxWeightOnlyConfig()
-        config = ModuleFqnToConfig({"linear1": config1, "re:linear.": config2})
-        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
-        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        config = ModuleFqnToConfig({"linear1": config1, "re:linear.*": config2})
+        model = ToyTwoLinearModel(64, 32, 64, device="cuda", dtype=torch.bfloat16)
         quantize_(model, config, filter_fn=None)
-        model(*example_inputs)
+        model(*model.example_inputs())
+
         assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
         assert isinstance(model.linear2.weight, IntxUnpackedToInt8Tensor)
@@ -675,11 +674,14 @@ def test_module_fqn_to_config_regex_precedence2(self):
             group_size=32, int4_packing_format="tile_packed_to_4d"
         )
         config2 = IntxWeightOnlyConfig()
-        config = ModuleFqnToConfig({"re:linear.": config2, "linear1": config1})
-        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
-        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+
+        config = ModuleFqnToConfig({"re:linear.*": config2, "linear1": config1})
+        model = ToyTwoLinearModel(
+            64, 32, 64, device="cuda", dtype=torch.bfloat16
+        ).eval()
         quantize_(model, config, filter_fn=None)
-        model(*example_inputs)
+        model(*model.example_inputs())
+
         assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
         assert isinstance(model.linear2.weight, IntxUnpackedToInt8Tensor)
@@ -763,9 +765,10 @@ def test_module_fqn_to_config_embedding_linear(self):
     def test_module_fqn_to_config_skip(self):
         config1 = Int4WeightOnlyConfig(group_size=32, version=1)
         config = ModuleFqnToConfig({"_default": config1, "linear2": None})
-        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
-        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        model = ToyTwoLinearModel(64, 32, 64, device="cuda", dtype=torch.bfloat16)
+        example_inputs = model.example_inputs()
         quantize_(model, config, filter_fn=None)
+
         model(*example_inputs)
         assert isinstance(model.linear1.weight, AffineQuantizedTensor)
         assert isinstance(model.linear1.weight._layout, TensorCoreTiledLayout)
@@ -774,10 +777,10 @@
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_int4wo_cuda_serialization(self):
         config = Int4WeightOnlyConfig(group_size=32, version=1)
-        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
+        model = ToyTwoLinearModel(64, 32, 64, device="cuda", dtype=torch.bfloat16)
         # quantize in cuda
         quantize_(model, config)
-        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        example_inputs = model.example_inputs()
         model(*example_inputs)
         with tempfile.NamedTemporaryFile() as ckpt:
             # save checkpoint in cuda
@@ -898,7 +901,9 @@ def test_quantize_param_fqn_regex(self):
         assert isinstance(model.experts.gate_up_proj, Float8Tensor)

     def test_quantize_fqn_precedence_param_over_module(self):
-        model = ToyLinearModel().to(torch.bfloat16).cuda().eval()
+        model = ToyTwoLinearModel(
+            64, 32, 64, dtype=torch.bfloat16, device="cuda"
+        ).eval()

         quant_config = FqnToConfig(
             {
@@ -913,7 +918,9 @@ def test_quantize_fqn_precedence_param_over_module(self):
         assert model.linear1.weight.scale.numel() == 1

     def test_quantize_fqn_precedence_param_over_module_regex(self):
-        model = ToyLinearModel().to(torch.bfloat16).cuda().eval()
+        model = ToyTwoLinearModel(
+            64, 32, 64, dtype=torch.bfloat16, device="cuda"
+        ).eval()

         quant_config = FqnToConfig(
             {
@@ -928,7 +935,9 @@ def test_quantize_fqn_precedence_param_over_module_regex(self):
         assert model.linear1.weight.scale.numel() == 1

     def test_quantize_fqn_precedence_param_regex_over_module_regex(self):
-        model = ToyLinearModel().to(torch.bfloat16).cuda().eval()
+        model = ToyTwoLinearModel(
+            64, 32, 64, dtype=torch.bfloat16, device="cuda"
+        ).eval()

         quant_config = FqnToConfig(
             {
@@ -943,7 +952,9 @@ def test_quantize_fqn_precedence_param_regex_over_module_regex(self):
         assert model.linear1.weight.scale.numel() == 1

     def test_quantize_fqn_precedence_module_over_param_regex(self):
-        model = ToyLinearModel().to(torch.bfloat16).cuda().eval()
+        model = ToyTwoLinearModel(
+            64, 32, 64, dtype=torch.bfloat16, device="cuda"
+        ).eval()

         quant_config = FqnToConfig(
             {
@@ -959,7 +970,9 @@ def test_quantize_fqn_precedence_module_over_param_regex(self):
         assert not isinstance(model.linear2.weight, Float8Tensor)

     def test_quantize_fqn_precedence_param_over_default(self):
-        model = ToyLinearModel().to(torch.bfloat16).cuda().eval()
+        model = ToyTwoLinearModel(
+            64, 32, 64, dtype=torch.bfloat16, device="cuda"
+        ).eval()

         quant_config = FqnToConfig(
             {
@@ -975,7 +988,9 @@ def test_quantize_fqn_precedence_param_over_default(self):
         assert not isinstance(model.linear2.weight, Float8Tensor)

     def test_quantize_fqn_precedence_param_regex_over_default(self):
-        model = ToyLinearModel().to(torch.bfloat16).cuda().eval()
+        model = ToyTwoLinearModel(
+            64, 32, 64, dtype=torch.bfloat16, device="cuda"
+        ).eval()

         quant_config = FqnToConfig(
             {
@@ -990,7 +1005,9 @@ def test_quantize_fqn_precedence_param_regex_over_default(self):
         assert not isinstance(model.linear1.weight, Float8Tensor)

     def test_quantize_model_same_module_different_param(self):
-        model = ToyLinearModel().to(torch.bfloat16).cuda().eval()
+        model = ToyTwoLinearModel(
+            64, 32, 64, dtype=torch.bfloat16, device="cuda"
+        ).eval()
         model.linear1.register_parameter(
             "weight2", torch.nn.Parameter(model.linear1.weight.clone())
         )
@@ -1016,7 +1033,9 @@ def test_quantize_model_same_module_different_param(self):
         assert model.linear1.weight2.scale.numel() == 32

     def test_quantize_model_same_module_different_param_regex(self):
-        model = ToyLinearModel().to(torch.bfloat16).cuda().eval()
+        model = ToyTwoLinearModel(
+            64, 32, 64, dtype=torch.bfloat16, device="cuda"
+        ).eval()

         quant_config = FqnToConfig(
             {
                 "re:.*weight": Float8DynamicActivationFloat8WeightConfig(
@@ -1108,13 +1127,13 @@ def reset_memory():
         quant_config = FqnToConfig({"_default": Int8WeightOnlyConfig()})

         reset_memory()
-        m = ToyLinearModel()
+        m = ToyTwoLinearModel(64, 32, 64, dtype=torch.bfloat16, device="cuda").eval()
         quantize_(m.to(device="cuda"), quant_config, filter_fn=None)
         memory_baseline = torch.cuda.max_memory_allocated()

         del m
         reset_memory()
-        m = ToyLinearModel()
+        m = ToyTwoLinearModel(64, 32, 64, dtype=torch.bfloat16, device="cuda").eval()
         quantize_(m, quant_config, device="cuda", filter_fn=None)
         memory_streaming = torch.cuda.max_memory_allocated()
diff --git a/test/sparsity/test_fast_sparse_training.py b/test/sparsity/test_fast_sparse_training.py
index 424306f897..7448e8181b 100644
--- a/test/sparsity/test_fast_sparse_training.py
+++ b/test/sparsity/test_fast_sparse_training.py
@@ -15,33 +15,20 @@
     swap_linear_with_semi_sparse_linear,
     swap_semi_sparse_linear_with_linear,
 )
+from torchao.testing.model_architectures import ToyTwoLinearModel
 from torchao.utils import is_fbcode


-class ToyModel(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.linear1 = nn.Linear(128, 256, bias=False)
-        self.linear2 = nn.Linear(256, 128, bias=False)
-
-    def forward(self, x):
-        x = self.linear1(x)
-        x = torch.nn.functional.relu(x)
-        x = self.linear2(x)
-        return x
-
-
 class TestRuntimeSemiStructuredSparsity(TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
-    @unittest.skip("Temporarily skipping to unpin nightlies")
     def test_runtime_weight_sparsification(self):
         # need this import inside to not break 2.2 tests
         from torch.sparse import SparseSemiStructuredTensorCUSPARSELT

         input = torch.rand((128, 128)).half().cuda()
         grad = torch.rand((128, 128)).half().cuda()
-        model = ToyModel().half().cuda()
+        model = ToyTwoLinearModel(128, 256, 128, device="cuda", dtype=torch.float16)
         model_c = copy.deepcopy(model)

         for name, mod in model.named_modules():
@@ -82,14 +69,13 @@ def test_runtime_weight_sparsification(self):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
-    @unittest.skip("Temporarily skipping to unpin nightlies")
     def test_runtime_weight_sparsification_compile(self):
         # need this import inside to not break 2.2 tests
         from torch.sparse import SparseSemiStructuredTensorCUSPARSELT

         input = torch.rand((128, 128)).half().cuda()
         grad = torch.rand((128, 128)).half().cuda()
-        model = ToyModel().half().cuda()
+        model = ToyTwoLinearModel(128, 256, 128, device="cuda", dtype=torch.float16)
         model_c = copy.deepcopy(model)

         for name, mod in model.named_modules():
diff --git a/torchao/testing/model_architectures.py b/torchao/testing/model_architectures.py
index 8f41a8464c..8bc8bdff56 100644
--- a/torchao/testing/model_architectures.py
+++ b/torchao/testing/model_architectures.py
@@ -11,14 +11,73 @@
 import torch.nn.functional as F


-# TODO: Refactor torchao and tests to use these models
-class ToyLinearModel(torch.nn.Module):
-    def __init__(self, k=64, n=32, dtype=torch.bfloat16):
+class ToySingleLinearModel(torch.nn.Module):
+    """Single linear for input_dim*output_dim problem size"""
+
+    def __init__(
+        self,
+        input_dim,
+        output_dim,
+        dtype,
+        device,
+        has_bias=False,
+    ):
         super().__init__()
-        self.linear1 = torch.nn.Linear(k, n, bias=False).to(dtype)
+        self.dtype = dtype
+        self.device = device
+        self.linear1 = torch.nn.Linear(
+            input_dim, output_dim, bias=has_bias, dtype=dtype, device=device
+        )
+
+    def example_inputs(self, batch_size=1):
+        return (
+            torch.randn(
+                batch_size,
+                self.linear1.in_features,
+                dtype=self.dtype,
+                device=self.device,
+            ),
+        )
+
+    def forward(self, x):
+        x = self.linear1(x)
+        return x
+
+
+class ToyTwoLinearModel(torch.nn.Module):
+    def __init__(
+        self,
+        input_dim,
+        hidden_dim,
+        output_dim,
+        dtype,
+        device,
+        has_bias=False,
+    ):
+        super().__init__()
+        self.dtype = dtype
+        self.device = device
+        self.linear1 = torch.nn.Linear(
+            input_dim, hidden_dim, bias=has_bias, dtype=dtype, device=device
+        )
+        self.linear2 = torch.nn.Linear(
+            hidden_dim, output_dim, bias=has_bias, dtype=dtype, device=device
+        )
+
+    # Note: tinygemm kernel only uses bfloat16 inputs
+    def example_inputs(self, batch_size=1):
+        return (
+            torch.randn(
+                batch_size,
+                self.linear1.in_features,
+                dtype=self.dtype,
+                device=self.device,
+            ),
+        )

     def forward(self, x):
         x = self.linear1(x)
+        x = self.linear2(x)
         return x
@@ -179,7 +238,7 @@ def create_model_and_input_data(
         m, k, n (int): dimensions of the model and input data
     """
     if model_type == "linear":
-        model = ToyLinearModel(k, n, high_precision_dtype).to(device)
+        model = ToySingleLinearModel(k, n, device=device, dtype=high_precision_dtype)
         input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype)
     elif "ln_linear" in model_type:
         # Extract activation type from model_type string
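
Usage sketch (not part of the diff above): a minimal example of how the shared ToyTwoLinearModel helper introduced in torchao/testing/model_architectures.py is constructed and driven, mirroring the updated tests. The quantize_ and Int8WeightOnlyConfig imports are assumed to come from torchao's public quantization API; treat this as an illustrative sketch rather than part of the change.

import torch

from torchao.quantization import Int8WeightOnlyConfig, quantize_
from torchao.testing.model_architectures import ToyTwoLinearModel

# Build the two-linear toy model: input_dim=64, hidden_dim=32, output_dim=64.
# dtype and device are required arguments of the new helper.
model = ToyTwoLinearModel(64, 32, 64, dtype=torch.bfloat16, device="cpu").eval()

# example_inputs() returns a 1-tuple of random activations shaped (batch_size, linear1.in_features).
example_inputs = model.example_inputs(batch_size=4)

# Quantize in place (weight-only int8, as several tests above do) and run a forward pass.
quantize_(model, Int8WeightOnlyConfig())
out = model(*example_inputs)
print(out.shape)  # torch.Size([4, 64])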