@@ -387,22 +387,22 @@ class SubclassTensorArgs:
     requires_grad: bool
 
 
-def get_block_absmax(inpt_tensor: torch.Tensor, block_size: int) -> torch.Tensor:
+def get_block_absmax(input_tensor: torch.Tensor, block_size: int) -> torch.Tensor:
     """Iterate through a flattened tensor getting the absmax scalers for each block
 
     Args:
-        inpt_tensor: Input tensor to get scalers for
+        input_tensor: Input tensor to get scalers for
         block_size: Block size for the scanning window
     Returns:
         torch.Tensor: Tensor of scalers for each block
     """
-    assert inpt_tensor.dim() == 1, "Input tensor must be flattened"
+    assert input_tensor.dim() == 1, "Input tensor must be flattened"
     assert (
-        inpt_tensor.numel() % block_size
-    ) == 0, f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {block_size}"
+        input_tensor.numel() % block_size
+    ) == 0, f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"
 
-    n_blocks = inpt_tensor.numel() // block_size
-    blocks = inpt_tensor.view(n_blocks, block_size)
+    n_blocks = input_tensor.numel() // block_size
+    blocks = input_tensor.view(n_blocks, block_size)
     block_scalers = blocks.abs().max(dim=1).values
     return block_scalers
 
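For reference, the per-block absmax computation that get_block_absmax performs is just a reshape followed by a row-wise max of absolute values. A minimal standalone sketch (the tensor name and sizes below are illustrative, not taken from the patch):

import torch

# Illustrative flattened weight; the function asserts numel() % block_size == 0.
weight = torch.randn(4096)
block_size = 64
n_blocks = weight.numel() // block_size

# Same computation as get_block_absmax: one absmax scaler per block.
blocks = weight.view(n_blocks, block_size)
block_scalers = blocks.abs().max(dim=1).values  # shape: (n_blocks,)
assert block_scalers.shape == (n_blocks,)
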
@@ -478,18 +478,18 @@ def __init__(
     @torch.no_grad()
     def from_tensor(
         cls,
-        inpt_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
         block_size: int,
         scaler_block_size: int,
     ):
-        assert inpt_tensor.dim() <= 2, f"expect input tensor dim <= 2 but got dim = {inpt_tensor.dim()}"
+        assert input_tensor.dim() <= 2, f"expect input tensor dim <= 2 but got dim = {input_tensor.dim()}"
         assert (
-            inpt_tensor.numel() % block_size == 0
-        ), f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {block_size}"
-        assert inpt_tensor.is_contiguous, "Input tensor must be contiguous!"
+            input_tensor.numel() % block_size == 0
+        ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"
+        assert input_tensor.is_contiguous, "Input tensor must be contiguous!"
         # I think I want do this
-        # assert not inpt_tensor.requires_grad, "Input tensor must not require grad"
-        device = inpt_tensor.device
+        # assert not input_tensor.requires_grad, "Input tensor must not require grad"
+        device = input_tensor.device
         # Cache the tensor on the class def
         nf4 = torch.tensor(
             [
@@ -511,27 +511,27 @@ def from_tensor(
                 1.0000,
             ],
             device=device,
-            dtype=inpt_tensor.dtype,
+            dtype=input_tensor.dtype,
         )
-        n_blocks = inpt_tensor.numel() // block_size
+        n_blocks = input_tensor.numel() // block_size
         # Double quantization
         (
             quantized_scalers,
             quantization_factor,
             scaler_mean,
         ) = cls.double_quantize_scalers(
-            inpt_tensor.flatten(), block_size, scaler_block_size
+            input_tensor.flatten(), block_size, scaler_block_size
         )
         quantized_data = cls.convert_to_norm_float_weight(
-            inpt_tensor, n_blocks, block_size, nf4
+            input_tensor, n_blocks, block_size, nf4
         )
         tensor_meta = SubclassTensorArgs(
-            inpt_tensor.size(),
-            inpt_tensor.stride(),
-            inpt_tensor.storage_offset(),
-            inpt_tensor.dtype,
-            inpt_tensor.device,
-            inpt_tensor.requires_grad,
+            input_tensor.size(),
+            input_tensor.stride(),
+            input_tensor.storage_offset(),
+            input_tensor.dtype,
+            input_tensor.device,
+            input_tensor.requires_grad,
         )
         return cls(
             tensor_meta,
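As a usage sketch of the renamed from_tensor entry point: the class name NF4Tensor and the block sizes below are assumptions for illustration only; the parameter list (input_tensor, block_size, scaler_block_size) comes from this hunk, and the weight must be contiguous and at most 2-D per the asserts above.

import torch

# Hypothetical call site; NF4Tensor, block_size=64, and scaler_block_size=256
# are illustrative assumptions, not dictated by this diff.
weight = torch.randn(512, 512, dtype=torch.bfloat16)
nf4_weight = NF4Tensor.from_tensor(
    weight.contiguous(),
    block_size=64,
    scaler_block_size=256,
)
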
@@ -547,7 +547,7 @@ def from_tensor(
 
     @staticmethod
     def double_quantize_scalers(
-        inpt_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
         block_size: int,
         scaler_block_size: int,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -557,22 +557,22 @@ def double_quantize_scalers(
         And then we calculate the absmax quantization factors for each block again. We then quantize the scalers to int8.
 
         Args:
-            inpt_tensor: Input tensor to convert to QLoRA format, typically a weight tensor
+            input_tensor: Input tensor to convert to QLoRA format, typically a weight tensor
 
         Returns:
             torch.Tensor: Tensor of per_block quantization factors stored in int8 format
                 size: (n_blocks)
             torch.Tensor: Tensor of per_scaler_block quantization factors stored in int16 format
                 size: (n_scaler_blocks)
         """
-        assert inpt_tensor.dim() == 1, "Input tensor must be flattened"
+        assert input_tensor.dim() == 1, "Input tensor must be flattened"
         assert (
-            inpt_tensor.numel() % scaler_block_size
-        ) == 0, f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {scaler_block_size}"
+            input_tensor.numel() % scaler_block_size
+        ) == 0, f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"
 
         # First round of quantization
-        # Produces: A tensor of size (n_blocks) of inpt_tensor.dtype
-        scalers_1 = get_block_absmax(inpt_tensor, block_size)
+        # Produces: A tensor of size (n_blocks) of input_tensor.dtype
+        scalers_1 = get_block_absmax(input_tensor, block_size)
         scalers_1_mean = scalers_1.mean()
         scalers_1 = scalers_1 - scalers_1_mean
         # Second round of quantization
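The docstring above summarizes double quantization: the per-block absmax scalers are centered on their mean and then quantized again, block-wise, to int8. The sketch below shows the idea in isolation; it is a simplification under assumed rounding and scaling choices, not the method's actual implementation, but its outputs line up with how dequantize_scalers later divides by the quantization factor and adds back the mean.

import torch

def quantize_scalers_sketch(scalers: torch.Tensor, scaler_block_size: int):
    """Simplified second-round quantization of per-block absmax scalers."""
    scaler_mean = scalers.mean()
    centered = scalers - scaler_mean
    n_scaler_blocks = centered.numel() // scaler_block_size
    scaler_blocks = centered.view(n_scaler_blocks, scaler_block_size)
    # One quantization factor per scaler block, chosen so that
    # dequantization is simply quantized / factor + mean.
    absmax = scaler_blocks.abs().max(dim=1, keepdim=True).values.clamp_min(1e-12)
    factors = 127.0 / absmax
    quantized = (scaler_blocks * factors).round().clamp(-128, 127).to(torch.int8)
    return quantized.flatten(), factors.squeeze(-1), scaler_mean
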
@@ -607,52 +607,52 @@ def double_quantize_scalers(
 
     def dequantize_scalers(
         self,
-        inpt_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
         quantization_factor: torch.Tensor,
         scaler_block_size: int,
     ) -> torch.Tensor:
         """Used to unpack the double quantized scalers
 
         Args;
-            inpt_tensor: Input tensor to convert to QLoRA format this is the quantized scalers in int8 format
+            input_tensor: Input tensor to convert to QLoRA format this is the quantized scalers in int8 format
             quantization_factor: Tensor of per_scaler_block quantization factors stored in inpt_weight.dtype
                 size: (n_scaler_blocks)
             scaler_block_size: Scaler block size to use for double quantization.
 
         """
-        assert inpt_tensor.dim() == 1, "Input tensor must be flattened"
+        assert input_tensor.dim() == 1, "Input tensor must be flattened"
         assert (
-            inpt_tensor.numel() % scaler_block_size
-        ) == 0, f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {scaler_block_size}"
-        n_scaler_blocks = inpt_tensor.numel() // scaler_block_size
-        inpt_tensor = inpt_tensor.view(n_scaler_blocks, scaler_block_size)
-        dequantized = (inpt_tensor / quantization_factor.unsqueeze(-1)).flatten().to(
+            input_tensor.numel() % scaler_block_size
+        ) == 0, f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"
+        n_scaler_blocks = input_tensor.numel() // scaler_block_size
+        input_tensor = input_tensor.view(n_scaler_blocks, scaler_block_size)
+        dequantized = (input_tensor / quantization_factor.unsqueeze(-1)).flatten().to(
             self.dtype
         ) + self.scaler_mean
         return dequantized
 
     @staticmethod
     def convert_to_norm_float_weight(
-        inpt_tensor: torch.Tensor, n_blocks: int, block_size: int, nf4: torch.Tensor
+        input_tensor: torch.Tensor, n_blocks: int, block_size: int, nf4: torch.Tensor
     ) -> torch.Tensor:
         """Convert a tensor to the normalized float weight format"""
-        flattened_tensor = inpt_tensor.flatten()
+        flattened_tensor = input_tensor.flatten()
         # Since we are using uint8 we will encode 2 entries per byte
-        numel = inpt_tensor.numel()
+        numel = input_tensor.numel()
         assert (
             numel % 2 == 0
         ), "Number of elements must be even just to not have to think about the end"
         # Reshape the flattened tensor into blocks of size self.block_size
         blocks = flattened_tensor.view(n_blocks, block_size)
 
         # Scale the blocks
-        scalers = get_block_absmax(inpt_tensor.flatten(), block_size)
+        scalers = get_block_absmax(input_tensor.flatten(), block_size)
         scales = scalers.unsqueeze(-1).expand(n_blocks, block_size)
         scaled_blocks = blocks / scales
 
         # Returns a flattened tensor with each element quantized to nf4 index
         # See Note: Quantize in Chunks
-        quantized_blocks = torch.empty(numel, dtype=torch.uint8, device=inpt_tensor.device)
+        quantized_blocks = torch.empty(numel, dtype=torch.uint8, device=input_tensor.device)
         flattened = scaled_blocks.flatten()
         for chunk_num in range(math.ceil(numel / CHUNK_SIZE)):
             start = chunk_num * CHUNK_SIZE
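
The "encode 2 entries per byte" comment above refers to packing pairs of 4-bit NF4 indices into single uint8 values. A small self-contained sketch of that idea; the nibble order here is illustrative, and the file's own packing code (not shown in this hunk) defines the real layout.

import torch

# Pack pairs of 4-bit codes (values 0..15) into one uint8 each.
codes = torch.randint(0, 16, (8,), dtype=torch.uint8)   # must be an even count
packed = (codes[::2] << 4) | codes[1::2]                 # high nibble, low nibble

# Unpack back into the original 4-bit codes.
unpacked = torch.stack(((packed >> 4) & 0xF, packed & 0xF), dim=-1).flatten()
assert torch.equal(unpacked, codes)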