@@ -142,10 +142,10 @@ def _fp4_packed_to_bf16(
     Output: a tensor of bfloat16 values
     """

-    # low-bits: original location 0:3
-    # high-bits: original location 4:7
-    x_low_bits = x_packed >> 4
-    x_high_bits = x_packed & 0xF
+    # high bits (packed location 4:7): second value of each pair
+    # low bits (packed location 0:3): first value of each pair
+    x_high_bits = x_packed >> 4
+    x_low_bits = x_packed & 0xF
     x = tl.interleave(x_low_bits, x_high_bits)

     # cast logic below
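For intuition, here is a plain-PyTorch sketch of what the corrected Triton lines compute: `>> 4` extracts the high nibble (the second value of each pair), `& 0xF` extracts the low nibble (the first value), and interleaving low-then-high restores the original element order. The byte values are illustrative only, and `torch.stack` plus `reshape` stands in for `tl.interleave`:

```python
import torch

# Each packed byte holds two 4-bit codes: 0x21 -> low nibble 0x1, high nibble 0x2.
x_packed = torch.tensor([0x21, 0x43], dtype=torch.uint8)
x_high_bits = x_packed >> 4   # second value of each pair
x_low_bits = x_packed & 0xF   # first value of each pair
# tl.interleave(a, b) alternates elements of a and b; emulate it on CPU.
x = torch.stack([x_low_bits, x_high_bits], dim=-1).reshape(-1)
assert x.tolist() == [0x1, 0x2, 0x3, 0x4]
```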
@@ -735,8 +735,8 @@ def unpack_uint4(uint8_data) -> torch.Tensor:
     # verified that we get a single triton kernel, but that is even slower
     # than the two kernels before this PR
     # * TODO add a microbenchmark of just the cast and profile this
-    first_elements = (uint8_data >> 4).to(torch.uint8)
-    second_elements = (uint8_data & 0b1111).to(torch.uint8)
+    first_elements = (uint8_data & 0b1111).to(torch.uint8)
+    second_elements = (uint8_data >> 4).to(torch.uint8)
     unpacked = torch.stack([first_elements, second_elements], dim=-1).view(
         up_size(shape)
     )
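A quick sanity check of the corrected order on a known byte (the value is illustrative): for `0x21`, the low nibble `0x1` must come out as the first element and the high nibble `0x2` as the second:

```python
import torch

b = torch.tensor([0x21], dtype=torch.uint8)
first_elements = (b & 0b1111).to(torch.uint8)   # low nibble
second_elements = (b >> 4).to(torch.uint8)      # high nibble
assert first_elements.item() == 0x1 and second_elements.item() == 0x2
```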
@@ -758,7 +758,7 @@ def pack_uint4(uint8_data: torch.Tensor) -> torch.Tensor:
     shape = uint8_data.shape
     assert shape[-1] % 2 == 0
     uint8_data = uint8_data.contiguous().view(-1)
-    return (uint8_data[::2] << 4 | uint8_data[1::2]).view(down_size(shape))
+    return (uint8_data[::2] | uint8_data[1::2] << 4).view(down_size(shape))


 # PyTorch implementation of fp6 packing for reference purposes
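With `pack_uint4` and `unpack_uint4` now agreeing on the nibble order, the round trip holds. A minimal sketch, inlining simplified versions of the two patched functions (flat input, no `up_size`/`down_size` shape bookkeeping):

```python
import torch

def pack_uint4_sketch(u8: torch.Tensor) -> torch.Tensor:
    # Assumes a flat, even-length tensor of values in [0, 15].
    u8 = u8.contiguous().view(-1)
    return u8[::2] | (u8[1::2] << 4)  # first -> low nibble, second -> high nibble

def unpack_uint4_sketch(u8: torch.Tensor) -> torch.Tensor:
    first = (u8 & 0b1111).to(torch.uint8)
    second = (u8 >> 4).to(torch.uint8)
    return torch.stack([first, second], dim=-1).view(-1)

data = torch.randint(0, 16, (16,), dtype=torch.uint8)
assert torch.equal(unpack_uint4_sketch(pack_uint4_sketch(data)), data)
```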
@@ -1250,8 +1250,8 @@ def convert_fp32_to_fp4_packed(x_pairs):
     Returns:
         Packed tensor with shape [...] (last dimension removed) where each
         element is an int8 containing 2 FP4 values:
-        - First value of pair → high nibble (bits 4-7)
-        - Second value of pair → low nibble (bits 0-3)
+        - First value of pair → low nibble (bits 0-3)
+        - Second value of pair → high nibble (bits 4-7)

     Example:
         Input: [128, 32, 2] containing FP32 pairs
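To make the documented layout concrete, here is a CPU emulation with random 4-bit integer codes standing in for FP4 values (the real function does the conversion in PTX on the GPU):

```python
import torch

x_pairs = torch.randint(0, 16, (128, 32, 2), dtype=torch.uint8)
packed = x_pairs[..., 0] | (x_pairs[..., 1] << 4)  # first -> low, second -> high
assert packed.shape == (128, 32)  # [128, 32, 2] -> [128, 32], last dim removed
```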
@@ -1263,10 +1263,10 @@ def convert_fp32_to_fp4_packed(x_pairs):
     asm = """
     {
         .reg .b8 byte0, byte1, byte2, byte3;
-        cvt.rn.satfinite.e2m1x2.f32 byte0, $1, $5;
-        cvt.rn.satfinite.e2m1x2.f32 byte1, $2, $6;
-        cvt.rn.satfinite.e2m1x2.f32 byte2, $3, $7;
-        cvt.rn.satfinite.e2m1x2.f32 byte3, $4, $8;
+        cvt.rn.satfinite.e2m1x2.f32 byte0, $5, $1;
+        cvt.rn.satfinite.e2m1x2.f32 byte1, $6, $2;
+        cvt.rn.satfinite.e2m1x2.f32 byte2, $7, $3;
+        cvt.rn.satfinite.e2m1x2.f32 byte3, $8, $4;
         mov.b32 $0, {byte0, byte1, byte2, byte3};
     }
     """,