     state_dict_device,
     use_et_backend,
 )
+from qops import LinearInt8 as WeightOnlyInt8Linear, QuantizedEmbedding
 
-from qops import LinearInt8 as WeightOnlyInt8Linear
 
 #########################################################################
 ###                 torchchat quantization API                        ###
@@ -489,9 +489,9 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
             setattr(
                 module,
                 name,
-                QuantizedGroupEmbedding(
+                QuantizedEmbedding(
                     device=device,
-                    vocab_size=child.weight.shape[0],
+                    num_embeddings=child.weight.shape[0],
                     embedding_dim=child.weight.shape[1],
                     bitwidth=bitwidth,
                     groupsize=groupsize,
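
This hunk only renames the wrapper (QuantizedGroupEmbedding becomes qops.QuantizedEmbedding) and the vocab_size keyword (now num_embeddings); the call shape is otherwise unchanged. A minimal sketch of the updated call site, assuming qops.QuantizedEmbedding keeps the keyword interface shown above (the swap_embedding helper and its defaults are hypothetical, not part of the change):

```python
import torch.nn as nn

from qops import QuantizedEmbedding  # import added by this change


def swap_embedding(module: nn.Module, name: str, child: nn.Embedding, device="cpu"):
    # Mirror of the setattr call in replace_embedding_weight_only_grouped_int8_per_channel:
    # replace a float nn.Embedding child with the group-quantized version.
    setattr(
        module,
        name,
        QuantizedEmbedding(
            device=device,
            num_embeddings=child.weight.shape[0],  # renamed from vocab_size
            embedding_dim=child.weight.shape[1],
            bitwidth=8,      # 8, or 4 for packed nibbles, per the class removed below
            groupsize=None,  # None/0 meant "one scale group per row" in the old code
        ),
    )
```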
@@ -586,116 +586,6 @@ def quantized_model(self) -> nn.Module:
         return self.model_
 
 
-class QuantizedGroupEmbedding(torch.nn.Module):
-    def __init__(
-        self,
-        device,
-        vocab_size: int,
-        embedding_dim: int,
-        bitwidth: int,
-        groupsize: Optional[int] = None,
-        *,
-        dtype=torch.half,
-    ) -> None:
-        super().__init__()
-        if groupsize is None or groupsize == 0:
-            groupsize = embedding_dim
-        self.groupsize = groupsize
-        self.dtype = dtype
-        self.bitwidth = bitwidth
-
-        if use_et_backend():
-            self.forward = self.et_forward
-        else:
-            self.forward = self.aoti_forward
-
-        if bitwidth == 8:
-            self.register_buffer(
-                "weight",
-                torch.empty(
-                    (vocab_size, embedding_dim), dtype=torch.int8, device=device
-                ),
-            )
-        elif bitwidth == 4:  # packed
-            self.register_buffer(
-                "weight",
-                torch.empty(
-                    (vocab_size, embedding_dim // 2), dtype=torch.uint8, device=device
-                ),
-            )
-        else:
-            raise RuntimeError(
-                f"QUantized embedding does not support bitwidth={bitwidth}"
-            )
-
-        groups_per_row = (embedding_dim + groupsize - 1) // groupsize
-        if groups_per_row > 1:
-            self.register_buffer(
-                "scales",
-                torch.ones(
-                    (vocab_size, groups_per_row), dtype=torch.float16, device=device
-                ),
-            )
-        else:
-            self.register_buffer(
-                "scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
-            )
-
-    @torch.no_grad()
-    def et_forward(self, indices: torch.Tensor) -> torch.Tensor:
-        if self.bitwidth == 8:
-            return torch.ops.quantized_decomposed.embedding_byte.dtype(
-                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
-            )
-        else:
-            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
-                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
-            )
-
-    @torch.no_grad()
-    def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
-        # result_weights = self.weight.index_select(0, indices.view(-1))
-        # result_scales = self.scales.index_select(0, indices.view(-1))
-
-        if self.bitwidth == 4:
-            weight_even = self.weight.div(16, rounding_mode="trunc")
-            weight_odd = self.weight.remainder(16)
-            weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
-            weight = weight_unpacked.view(self.weight.shape[0], -1)
-            weight = weight.to(torch.int8).add(-8)
-        else:
-            weight = self.weight
-
-        scales = self.scales.view(weight.shape[0], -1)
-
-        result_weights = F.embedding(indices, weight)
-        result_scales = F.embedding(indices, scales)
-
-        rw_view = result_weights.to(dtype=result_scales.dtype).view(
-            tuple(
-                result_weights.shape[:-1]
-                + (
-                    scales.shape[1],
-                    -1,
-                )
-            )
-        )
-        rs_view = result_scales.view(
-            tuple(result_scales.shape[:-1])
-            + (
-                scales.shape[1],
-                1,
-            )
-        )
-        # print(f"rw_view {rw_view.shape}")
-        # print(f"rs_view {rs_view.shape}")
-
-        r = rw_view * rs_view
-        return r.view(indices.size() + (-1,))
-
-        # r = result_weights.to(dtype=result_scales.dtype).view(list(result_weights.shape[:-1] + (scales.shape[1], -1, )) * result_scales.view(scales.shape[-1] + (scales.shape[1], 1, ))
-
-
 #########################################################################
 #####     weight only int4 per channel groupwise quantized code    ######
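
For reference, the eager (AOTI) path of the class deleted here dequantizes each looked-up row group by group: int8 values (or unpacked int4 nibbles) are multiplied by a per-group fp16 scale. Below is a self-contained sketch of that arithmetic under illustrative shapes; the tensor names are not the qops API, and the quantized module is assumed to carry weight and scales buffers as the removed class did (presumably this now lives in qops.QuantizedEmbedding):

```python
import torch

num_embeddings, embedding_dim, groupsize = 8, 16, 4
groups_per_row = embedding_dim // groupsize  # one scale column per group

# Quantized payload: int8 rows plus fp16 per-group scales, as in the removed buffers.
weight = torch.randint(-128, 128, (num_embeddings, embedding_dim), dtype=torch.int8)
scales = torch.rand(num_embeddings, groups_per_row).to(torch.float16)

indices = torch.tensor([1, 5, 2])
w = weight[indices]  # (3, 16) int8 rows for the requested tokens
s = scales[indices]  # (3, 4) fp16 scales, one per group of `groupsize` columns

# Rescale each group of columns by its own scale, then flatten back to full rows.
dequant = (
    w.to(s.dtype).view(len(indices), groups_per_row, groupsize) * s.unsqueeze(-1)
).view(len(indices), embedding_dim)
print(dequant.shape)  # torch.Size([3, 16])

# 4-bit variant: two values packed per uint8 byte; split high/low nibbles and
# re-center, mirroring the div/remainder/stack steps in the removed aoti_forward.
packed = torch.randint(0, 256, (num_embeddings, embedding_dim // 2), dtype=torch.uint8)
nibbles = torch.stack((packed.div(16, rounding_mode="trunc"), packed.remainder(16)), dim=-1)
int4_rows = nibbles.view(num_embeddings, embedding_dim).to(torch.int8).add(-8)
```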
|
|