|
21 | 21 | import warnings |
22 | 22 | from collections import OrderedDict |
23 | 23 | from dataclasses import dataclass, field |
| 24 | +from functools import partial |
24 | 25 | from typing import Any, Callable, List, Optional, Tuple, Union |
25 | 26 | from typing import OrderedDict as OrderedDictType |
26 | 27 |
|
@@ -416,6 +417,19 @@ def _embedding_extra_repr(self): |
416 | 417 | return f"num_embeddings={self.weight.shape[0]}, embedding_dim={self.weight.shape[1]}, weight={_quantization_type(self.weight)}" |
417 | 418 |
|
418 | 419 |
|
| 420 | +def _module_extra_repr(self, original_extra_repr, parameter_name): |
| 421 | + module_torchao_extra_repr = [] |
| 422 | + |
| 423 | + original_extra_repr_str = original_extra_repr() |
| 424 | + if len(original_extra_repr_str) > 0: |
| 425 | + module_torchao_extra_repr.append(original_extra_repr_str) |
| 426 | + |
| 427 | + module_torchao_extra_repr.append( |
| 428 | + f"{parameter_name}={_quantization_type(getattr(self, parameter_name))}" |
| 429 | + ) |
| 430 | + return ", ".join(module_torchao_extra_repr) |
| 431 | + |
| 432 | + |
419 | 433 | def _get_linear_subclass_inserter( |
420 | 434 | constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs |
421 | 435 | ): |
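For reference, here is how the binding pattern introduced above behaves: `types.MethodType(partial(...), module)` produces a bound method, so Python passes the module itself as `self` ahead of the pre-bound keyword arguments, and the instance attribute shadows `nn.Module.extra_repr`. A minimal, self-contained sketch (the `_quantization_type` stand-in below is a placeholder, not torchao's real helper):

```python
import types
from functools import partial

import torch


def _quantization_type(weight):
    # Placeholder for torchao's helper: summarize the parameter's type/dtype.
    return f"{type(weight).__name__}(dtype={weight.dtype})"


def _module_extra_repr(self, original_extra_repr, parameter_name):
    # Same logic as the helper added in this PR: keep the module's original
    # extra_repr, then append a summary of the quantized parameter.
    parts = []
    original_extra_repr_str = original_extra_repr()
    if len(original_extra_repr_str) > 0:
        parts.append(original_extra_repr_str)
    parts.append(
        f"{parameter_name}={_quantization_type(getattr(self, parameter_name))}"
    )
    return ", ".join(parts)


linear = torch.nn.Linear(4, 2)
# Capture the original bound extra_repr *before* shadowing it, then install
# the partially-applied helper as an instance-level bound method.
linear.extra_repr = types.MethodType(
    partial(
        _module_extra_repr,
        original_extra_repr=linear.extra_repr,
        parameter_name="weight",
    ),
    linear,
)
print(linear)
# Linear(in_features=4, out_features=2, bias=True, weight=Parameter(dtype=torch.float32))
```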
@@ -1375,11 +1389,22 @@ def _int8_weight_only_transform( |
1375 | 1389 | "applying int8 weight only quant requires module to have {parameter_name} attribute" |
1376 | 1390 | + " but {module} does not have one" |
1377 | 1391 | ) |
1378 | | - new_weight = _int8_weight_only_quantize_tensor( |
| 1392 | + quantized_tensor = _int8_weight_only_quantize_tensor( |
1379 | 1393 | getattr(module, parameter_name), config |
1380 | 1394 | ) |
1381 | | - setattr(module, parameter_name, torch.nn.Parameter(new_weight, requires_grad=False)) |
1382 | | - module.extra_repr = types.MethodType(_linear_extra_repr, module) |
| 1395 | + setattr( |
| 1396 | + module, |
| 1397 | + parameter_name, |
| 1398 | + torch.nn.Parameter(quantized_tensor, requires_grad=False), |
| 1399 | + ) |
| 1400 | + module.extra_repr = types.MethodType( |
| 1401 | + partial( |
| 1402 | + _module_extra_repr, |
| 1403 | + original_extra_repr=module.extra_repr, |
| 1404 | + parameter_name=parameter_name, |
| 1405 | + ), |
| 1406 | + module, |
| 1407 | + ) |
1383 | 1408 | return module |
1384 | 1409 |
|
1385 | 1410 |
|
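With this change, printing a quantized module surfaces the quantization info even when the parameter is not named `weight`, since `parameter_name` is threaded through. A hedged usage sketch against torchao's public API (assuming a build that exposes `quantize_` and `Int8WeightOnlyConfig`; the printed repr is illustrative, as the exact tensor-subclass string depends on `_quantization_type`):

```python
import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_

model = torch.nn.Sequential(torch.nn.Linear(16, 8))
quantize_(model, Int8WeightOnlyConfig())

# Illustrative output: the original extra_repr is preserved and the
# quantized parameter's summary is appended by _module_extra_repr.
print(model[0])
# Linear(in_features=16, out_features=8, bias=True, weight=<quantized tensor summary>)
```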
@@ -1664,16 +1689,23 @@ def _float8_weight_only_transform( |
1664 | 1689 | if isinstance(module, Float8Linear): |
1665 | 1690 | module = _unwrap_float8_linear(module) |
1666 | 1691 |
|
1667 | | - new_weight = _float8_weight_only_quant_tensor( |
| 1692 | + quantized_tensor = _float8_weight_only_quant_tensor( |
1668 | 1693 | getattr(module, parameter_name), config |
1669 | 1694 | ) |
1670 | 1695 |
|
1671 | 1696 | setattr( |
1672 | 1697 | module, |
1673 | 1698 | parameter_name, |
1674 | | - torch.nn.Parameter(new_weight, requires_grad=False), |
| 1699 | + torch.nn.Parameter(quantized_tensor, requires_grad=False), |
| 1700 | + ) |
| 1701 | + module.extra_repr = types.MethodType( |
| 1702 | + partial( |
| 1703 | + _module_extra_repr, |
| 1704 | + original_extra_repr=module.extra_repr, |
| 1705 | + parameter_name=parameter_name, |
| 1706 | + ), |
| 1707 | + module, |
1675 | 1708 | ) |
1676 | | - module.extra_repr = types.MethodType(_linear_extra_repr, module) |
1677 | 1709 | return module |
1678 | 1710 |
|
1679 | 1711 |
|
@@ -1946,7 +1978,14 @@ def _float8_dynamic_activation_float8_weight_transform( |
1946 | 1978 | parameter_name, |
1947 | 1979 | torch.nn.Parameter(quantized_tensor, requires_grad=False), |
1948 | 1980 | ) |
1949 | | - module.extra_repr = types.MethodType(_linear_extra_repr, module) |
| 1981 | + module.extra_repr = types.MethodType( |
| 1982 | + partial( |
| 1983 | + _module_extra_repr, |
| 1984 | + original_extra_repr=module.extra_repr, |
| 1985 | + parameter_name=parameter_name, |
| 1986 | + ), |
| 1987 | + module, |
| 1988 | + ) |
1950 | 1989 | return module |
1951 | 1990 |
|
1952 | 1991 |
|