|
31 | 31 | "QuantizationType", |
32 | 32 | "QuantizationStrategy", |
33 | 33 | "QuantizationArgs", |
34 | | - "round_to_quantized_type", |
| 34 | + "round_to_quantized_type_args", |
| 35 | + "round_to_quantized_type_dtype", |
35 | 36 | "ActivationOrdering", |
36 | 37 | "DynamicType", |
37 | 38 | ] |
@@ -392,47 +393,57 @@ def get_observer(self) -> str: |
392 | 393 | model_config = ConfigDict(extra="forbid") |
393 | 394 |
|
394 | 395 |
|
395 | | -def _round_dtype(tensor: torch.Tensor, dtype: torch.dtype): |
| 396 | +def round_to_quantized_type_dtype( |
| 397 | + tensor: torch.Tensor, dtype: torch.dtype |
| 398 | +) -> torch.Tensor: |
| 399 | + """ |
| 400 | + Rounds an input tensor to the nearest quantized representation given a dtype. |
| 401 | + The original dtype is kept post-rounding. |
| 402 | +
|
| 403 | + :param tensor: tensor to round |
| 404 | + :param dtype: dtype to use for rounding |
| 405 | + :return: rounded tensor |
| 406 | + """ |
| 407 | + original_dtype = tensor.dtype |
396 | 408 | if torch.is_floating_point(torch.tensor([], dtype=dtype)): |
397 | 409 | finfo = torch.finfo(dtype) |
398 | | - return torch.clamp(tensor, finfo.min, finfo.max).to(dtype) |
| 410 | + rounded = torch.clamp(tensor, finfo.min, finfo.max).to(dtype) |
399 | 411 | else: |
400 | 412 | iinfo = torch.iinfo(dtype) |
401 | | - return torch.round(torch.clamp(tensor, iinfo.min, iinfo.max)) |
402 | | - |
| 413 | + rounded = torch.round(torch.clamp(tensor, iinfo.min, iinfo.max)) |
403 | 414 |
|
404 | | -def _round_args(tensor: torch.Tensor, args: QuantizationArgs): |
405 | | - if args.type == QuantizationType.FLOAT: |
406 | | - if args.num_bits == 8: |
407 | | - return tensor.to(FP8_E4M3_DATA.dtype) |
408 | | - elif args.num_bits == 4: |
409 | | - return FP4_E2M1_DATA.cast_to_fp4(tensor) |
410 | | - else: |
411 | | - raise NotImplementedError("Only num_bits in (4, 8) are supported") |
412 | | - elif args.type == QuantizationType.INT: |
413 | | - return torch.round(tensor) |
414 | | - else: |
415 | | - raise ValueError(f"Invalid quantization type {args.type}") |
| 415 | + return rounded.to(original_dtype) |
416 | 416 |
|
417 | 417 |
|
418 | | -def round_to_quantized_type( |
| 418 | +def round_to_quantized_type_args( |
419 | 419 | tensor: torch.Tensor, |
420 | | - args: Optional[QuantizationArgs] = None, |
421 | | - dtype: Optional[torch.dtype] = None, |
| 420 | + args: QuantizationArgs, |
| 421 | + min: torch.Tensor, |
| 422 | + max: torch.Tensor, |
422 | 423 | ) -> torch.Tensor: |
423 | 424 | """ |
424 | | - Rounds each element of the input tensor to the nearest quantized representation, |
425 | | - keeping to original dtype. This can be done given QuantizationArgs or dtype |
| 425 | + Rounds an input tensor to the nearest quantized representation given |
| 426 | + quantization args. The original dtype is kept post-rounding. |
426 | 427 |
|
427 | 428 | :param tensor: tensor to round |
428 | | - :param args: QuantizationArgs to pull appropriate dtype from |
429 | | - :param dtype: dtype to use for rounding |
| 429 | + :param args: quantization args to use for rounding |
| 430 | + :param min: min value to use for clamping |
| 431 | + :param max: max value to use for clamping |
430 | 432 | :return: rounded tensor |
431 | 433 | """ |
| 434 | + |
432 | 435 | original_dtype = tensor.dtype |
433 | | - if dtype is not None: |
434 | | - rounded = _round_dtype(tensor=tensor, dtype=dtype) |
435 | | - elif args is not None: |
436 | | - rounded = _round_args(tensor=tensor, args=args) |
| 436 | + tensor = torch.clamp(tensor, min, max) |
| 437 | + if args.type == QuantizationType.FLOAT: |
| 438 | + if args.num_bits == 8: |
| 439 | + rounded = tensor.to(FP8_E4M3_DATA.dtype) |
| 440 | + elif args.num_bits == 4: |
| 441 | + rounded = FP4_E2M1_DATA.cast_to_fp4(tensor) |
| 442 | + else: |
| 443 | + raise NotImplementedError("Only num_bits in (4, 8) are supported") |
| 444 | + elif args.type == QuantizationType.INT: |
| 445 | + rounded = torch.round(tensor) |
| 446 | + else: |
| 447 | + raise ValueError(f"Invalid quantization type {args.type}") |
437 | 448 |
|
438 | 449 | return rounded.to(original_dtype) |
0 commit comments