Add support for BitnetForCausalLM (new model / new datatype) #7931
@@ -659,6 +659,24 @@ static inline __m128i packNibbles( __m256i bytes ) {
}
#endif //__loongarch_asx

void quantize_row_i8_s(const float * x, void * y, int64_t n, float* act_scales) {
    int8_t* dst = (int8_t*)y;
    double min = 0.00001;
    double max = min;
    for (int i = 0; i < n; ++i) {
        max = MAX(max, (double)fabs((double)x[i]));
    }
    float s = 127 / max;
    act_scales[0] = s;
    float temp;
    for (int i = 0; i < n; ++i) {
        temp = round((double)(x[i] * s));
        if (temp > 127) temp = 127;
        if (temp < -128) temp = -128;
        dst[i] = (int8_t)(temp);
    }
}

// reference implementation for deterministic creation of model files
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
    static const int qk = QK4_0;
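For readers skimming the diff, here is a small, hedged sketch of how quantize_row_i8_s above behaves when called on its own (the prototype is copied from this diff; the driver around it is hypothetical): activations are scaled by 127 / absmax, rounded and clamped to int8, and can be approximately recovered by dividing by the scale written to act_scales[0].

```c
// Hypothetical round-trip of the per-tensor int8 activation quantization
// above; only quantize_row_i8_s itself comes from this PR.
#include <stdint.h>
#include <stdio.h>

void quantize_row_i8_s(const float * x, void * y, int64_t n, float * act_scales);

int main(void) {
    float x[8] = {0.9f, -0.3f, 0.0f, 1.2f, -1.2f, 0.05f, 0.7f, -0.9f};
    int8_t q[8];
    float scale;  // receives 127 / max(|x|)
    quantize_row_i8_s(x, q, 8, &scale);
    for (int i = 0; i < 8; ++i) {
        printf("%+.3f -> %4d -> %+.3f\n", x[i], q[i], q[i] / scale);
    }
    return 0;
}
```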
@@ -3306,6 +3324,53 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
    return nrow * row_size;
}

size_t quantize_i2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    // 2 bits per weight
    UNUSED(quant_weights);

    size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row);

    int n = nrow * n_per_row;

    // f32 -> q8
    double i2_scale = 0;
    for (int i=0; i<n; i++) {
        if (fabs((double)(src[i])) > 1e-6) {
            i2_scale = (double)src[i];

Review comment: This is only taking the last non-zero value of the tensor as the scale, if I understand correctly? The other quants use the absmax, so this looks a bit weird. Does it work as expected? If so, how or why? Should it be the absmean of the non-zero values instead?

Reply: Thanks for the reminder!

Review comment: This would slightly complicate the eventual Numpy implementation in [...]. (It's dequantization that needs to be fast for good inference speed.)

Reply: I got it. I will change it to absmax to make it more reproducible.

Review comment: @Eddie-Wang1120, actually (I only noticed it now), in section 2 of the BitNet 1.58b paper, they specifically say they use absmean (see https://arxiv.org/html/2402.17764v1#S2). But if it's applied twice (e.g. on pre-quantized weights), then maybe the mean shouldn't include the zero values.

Reply: And an interesting fact is, if we don't pre-quantize the weights to {-1, 0, +1} * scale, the generated tokens will be wrong. That's why I put the absmean quantization (weight pre-quantization) in convert-hf-to-gguf.py; otherwise we'd get a meaningless fp32/fp16 gguf model.
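As a reference for the absmean discussion above, here is a minimal sketch of the scaling described in the BitNet b1.58 paper (illustrative only, with hypothetical names; it is not the code this PR uses): the scale is the mean absolute value of the tensor, and each weight is divided by it, rounded, and clipped to {-1, 0, +1}.

```c
// Illustrative absmean ternarization per the BitNet b1.58 paper:
// gamma = mean(|W|), W~ = RoundClip(W / (gamma + eps), -1, 1).
#include <math.h>
#include <stddef.h>
#include <stdint.h>

static void absmean_ternarize(const float * w, int8_t * q, size_t n) {
    double gamma = 0.0;
    for (size_t i = 0; i < n; ++i) {
        gamma += fabs((double) w[i]);
    }
    gamma /= (double) n;

    for (size_t i = 0; i < n; ++i) {
        double v = round((double) w[i] / (gamma + 1e-8));
        if (v >  1.0) v =  1.0;
        if (v < -1.0) v = -1.0;
        q[i] = (int8_t) v;
    }
}
```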
        }
    }

    uint8_t* q8 = (uint8_t*)dst;
    for (int i=0; i<n; i++) {
        if (fabs((double)(src[i])) < 1e-6) {
            q8[i] = 0;
            continue;
        }
        q8[i] = (double)src[i] * i2_scale > 0 ? 1 : 3;
    }

    // q8 -> 0, 1, 3
    //       |  |  |
    //       0, 1,-1

    uint8_t* i2_weight = (uint8_t*)dst;
    for (int i=0; i<n; i++) {
        int group_idx = i / 4;
        int group_pos = i % 4;
        uint8_t temp = (q8[i] << (6 - 2 * group_pos));
        q8[i] = 0;
        i2_weight[group_idx] |= temp;
    }
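To make the byte layout produced by the loop above concrete, here is a small standalone sketch (not part of the PR): four 2-bit codes are packed per byte with element 0 in the top two bits, and unpacking simply reverses the 0 -> 0, 1 -> +1, 3 -> -1 mapping.

```c
// Sketch of the 2-bit packing used above: codes {0,1,3} stand for {0,+1,-1},
// and element i%4 goes to bit position 6 - 2*(i%4) of byte i/4.
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const uint8_t codes[4] = {1, 0, 3, 1};            // +1, 0, -1, +1
    uint8_t packed = 0;
    for (int pos = 0; pos < 4; ++pos) {
        packed |= (uint8_t)(codes[pos] << (6 - 2 * pos));
    }
    printf("packed byte: 0x%02X\n", packed);          // 0x4D

    static const int8_t decode[4] = {0, 1, 0, -1};    // code 2 is unused
    for (int pos = 0; pos < 4; ++pos) {
        int8_t w = decode[(packed >> (6 - 2 * pos)) & 3];
        printf("element %d -> %+d\n", pos, w);
    }
    return 0;
}
```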

Review comment: I wonder, maybe this could be made even more compact? Instead of fitting only 4 ternary values per byte, it would be possible to fit 5 of them (because 3⁵ = 243, which is smaller than 256). To avoid using modulo when dequantizing, assuming multiplication by 3 is fast (it can be turned into an addition and a bit shift), maybe storing an inverted value would work. Not sure what speed difference it would have compared to bit shifts and masks, though. Here's an example program verifying that multiplication can be an alternative to modulo by 3:

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main() {
    char s1[6] = {0};
    char s2[6] = {0};
    for (uint8_t i = 0; i < 243; ++i) {
        uint8_t n = i;
        // extract with modulo
        for (int j = 5; j-- > 0;) {
            s1[j] = (n % 3) + '0';
            n /= 3;
        }
        // invert the value
        uint8_t q = (((uint16_t) i) * 256) / 243;
        if (q != 0) {
            // otherwise it's always one smaller than the original
            q += 1;
        }
        // extract with multiplication
        for (int j = 0; j < 5; ++j) {
            uint16_t m = q * 3;
            s2[j] = (m >> 8) + '0';
            q = m & 0xFF;
        }
        printf("%s, %s: %s\n", s1, s2, strcmp(s1, s2) == 0 ? "\033[1;32mPASS\033[0m" : "\033[1;31mFAIL\033[0m");
    }
    return 0;
}

Compile and run:

$ gcc ternary-packing.c -o ternary-packing
$ ./ternary-packing

Output:

$ ./ternary-packing
00000, 00000: PASS
00001, 00001: PASS
00002, 00002: PASS
[... all 243 five-digit base-3 combinations print PASS ...]
22221, 22221: PASS
22222, 22222: PASS

Reply: That is a very good thought!
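For completeness, a hedged sketch of the packing side of that 5-per-byte idea (not part of this PR): five ternary digits are combined into a base-3 value in [0, 243), which is then "inverted" the same way as in the verification program above so that unpacking can use multiplication by 3 instead of modulo.

```c
// Sketch: pack five ternary digits (most significant first) into the
// "inverted" byte form that the verification program above decodes.
#include <stdint.h>

static uint8_t pack5_ternary(const uint8_t d[5]) {
    uint16_t v = 0;
    for (int j = 0; j < 5; ++j) {
        v = (uint16_t)(v * 3 + d[j]);   // base-3 value in [0, 243)
    }
    uint8_t q = (uint8_t)((v * 256) / 243);
    if (q != 0) {
        q += 1;                         // same correction as in the decoder
    }
    return q;
}
```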
    float* scale_ptr = (float*)((char*)i2_weight + n / 4);
    for (int i=0; i<8; i++) {
        scale_ptr[i] = i2_scale;

Review comment: Why is the same scale stored 8 times?

Reply: I noticed that there is an alignment restriction for gguf, which is 32 bytes, so I stored the float32 scale 8 times to fill the alignment. It can still work if I change it to scale_ptr[0] = i2_scale.
    }

    // 32B for scale
    return nrow * row_size / 4 + 32;

Review comment: Regarding the tensor-wide scales, even though the paper suggests using them, I wonder if using block scales would work too, so that it works better with the existing [...]. When (or if) packing 5 values per byte, since the row sizes from the bitnet models are usually not multiples of 5 (e.g. 1536, 2048), and since the 3B model uses a hidden_size of 3200 which isn't a multiple of 256, using blocks of 128 elements could work: two groups of 64 elements, each group having 12 bytes with 5 elements per byte plus 1 more byte with 4 elements, so [...]. If packing only 4 ternary values per byte (as in [...]).

Reply: Well, personally, I still think the tensor-wide scale is a better way for bitnet. At least for the 2-bit compaction, 2.25 bpw means around a 10% model size waste, and that's kind of beyond what is acceptable.

Review comment: Since 2 bits is already wasting 20% of the tensor size compared to the 1.6 bpw ideal for ternary, maybe there could be a way to still make this a block-wise quant (e.g. 4 elements per block of 1 byte) and have a row-wise/tensor-wise scale by somehow encoding it in unused bits in the first few blocks of each row? Might be a bad idea though, but I don't know yet why (maybe the overhead of recovering the scale?). (This would also require asserting a minimal row size in the quantize function.) Because 4 ternary values fit in 7 bits (3^4 == 81 < 128), and you're already using a lookup table to expand the packed bits into 4 bytes, this could let the SIMD vec_dot stay pretty much identical to how it is now, except maybe it could include the scaling in its result? Not sure yet how to pack the scale in. Anyway, at least this gives some ideas to try eventually.

Reply: Got your idea. Assume we have 128 contiguous values: we can compact them into (128 / 4) * 7 = 224 bits, and the other 32 bits will be a float32 scale, so it's still 2 bpw. The block size would be 28 chars (224 bits) + 1 float32 (32 bits) == 32 bytes. One thing that worries me a little is that we need to do some shifting to align the weights so that we can index into the lookup table; it may slow down the kernel, but it deserves a try.

Review comment: Note that the unit for block sizes is elements, while type sizes are in bytes. To keep the alignment, my suggestion was actually to keep using 8 bits per 4 elements (so that alignment remains easy), but also use the top bit of the first 16 or 32 bytes to store the scale. Only the lower (or upper? doesn't matter) 7 bits of those bytes would store 4 elements, using the fact that 3^4 == 81 < 128 == 2^7. To go for maximum compactness, the same idea can be applied to 5 elements per byte to achieve [...].
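To make the size arithmetic in this thread easier to follow, here is a short hedged sketch that just prints the effective bits per weight of the layouts being discussed (these are the hypothetical alternatives from the comments above, not what the PR currently implements; a single float32 scale per block of 128 elements is assumed).

```c
// Hedged sketch: bits per weight for the candidate block layouts discussed
// above, assuming one float32 scale per block of 128 elements.
#include <math.h>
#include <stdio.h>

static double bpw(int elems, int bits) { return (double) bits / (double) elems; }

int main(void) {
    printf("4 elements per byte + f32 scale   : %.3f bpw\n", bpw(128, 128 / 4 * 8 + 32));
    printf("7 bits per 4 elements + f32 scale : %.3f bpw\n", bpw(128, 128 / 4 * 7 + 32));
    printf("2 x (12 bytes of 5/byte + 1 byte) : %.3f bpw\n", bpw(128, 2 * 13 * 8 + 32));
    printf("ideal ternary (log2 3)            : %.3f bpw\n", log2(3.0));
    return 0;
}
```

The 2.250 and 2.000 figures match the numbers quoted in the thread; the 1.875 figure is only what the two-group 5-per-byte layout would give under the assumed single f32 scale per 128 elements.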
}
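Putting quantize_i2_s together, the buffer it writes is n/4 bytes of packed 2-bit weights followed by 32 bytes that hold the float32 tensor scale (repeated 8 times purely for alignment, per the discussion above). A hedged sketch of reading that scale back (the helper name is hypothetical):

```c
// Hedged sketch: recover the tensor-wide scale that quantize_i2_s stores
// after the n/4 packed weight bytes; copies 1..7 are alignment padding only.
#include <stdint.h>

static float i2_s_read_scale(const void * data, int64_t n_elements) {
    const uint8_t * packed = (const uint8_t *) data;
    const float * scale_ptr = (const float *)(packed + n_elements / 4);
    return scale_ptr[0];
}
```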

// ====================== "True" 2-bit (de)-quantization

void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) {

@@ -3726,6 +3791,85 @@ static inline __m128i get_scale_shuffle(int i) {
}
#endif

//====================================== I2 ===============================================
void ggml_vec_dot_i2_i8_s(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
    const uint8_t * restrict x = vx;
    const int8_t  * restrict y = vy;

    UNUSED(bs);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(nrc);

    // TODO
    // #if defined(__AVX2__)
    // __m256i accu = _mm256_setzero_si256();

    // for (int i=0; i<n/32; i++) {
    //     const int8_t* w0 = (const int8_t *)(i2s_i8s + x[i*8 + 0]);
    //     const int8_t* w1 = (const int8_t *)(i2s_i8s + x[i*8 + 1]);
    //     const int8_t* w2 = (const int8_t *)(i2s_i8s + x[i*8 + 2]);
    //     const int8_t* w3 = (const int8_t *)(i2s_i8s + x[i*8 + 3]);
    //     const int8_t* w4 = (const int8_t *)(i2s_i8s + x[i*8 + 4]);
    //     const int8_t* w5 = (const int8_t *)(i2s_i8s + x[i*8 + 5]);
    //     const int8_t* w6 = (const int8_t *)(i2s_i8s + x[i*8 + 6]);
    //     const int8_t* w7 = (const int8_t *)(i2s_i8s + x[i*8 + 7]);

    //     __m256i xq8 = _mm256_set_epi8(
    //         w0[0], w0[1], w0[2], w0[3],
    //         w1[0], w1[1], w1[2], w1[3],
    //         w2[0], w2[1], w2[2], w2[3],
    //         w3[0], w3[1], w3[2], w3[3],
    //         w4[0], w4[1], w4[2], w4[3],
    //         w5[0], w5[1], w5[2], w5[3],
    //         w6[0], w6[1], w6[2], w6[3],
    //         w7[0], w7[1], w7[2], w7[3]
    //     );

    //     __m256i yq8 = _mm256_loadu_si256((const __m256i*)(y + i*32));

    //     __m128i hxq8 = _mm256_castsi256_si128(xq8);
    //     __m128i lxq8 = _mm256_extractf128_si256(xq8, 1);
    //     __m128i hyq8 = _mm256_castsi256_si128(yq8);
    //     __m128i lyq8 = _mm256_extractf128_si256(yq8, 1);

    //     __m256i hxq16 = _mm256_cvtepi8_epi16(hxq8);
    //     __m256i lxq16 = _mm256_cvtepi8_epi16(lxq8);
    //     __m256i hyq16 = _mm256_cvtepi8_epi16(hyq8);
    //     __m256i lyq16 = _mm256_cvtepi8_epi16(lyq8);

    //     __m256i hzq16 = _mm256_sign_epi16(hyq16, hxq16);
    //     __m256i lzq16 = _mm256_sign_epi16(lyq16, lxq16);

    //     __m256i hhzq32 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(hzq16));
    //     __m256i hlzq32 = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(hzq16, 1));
    //     __m256i llzq32 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(lzq16));
    //     __m256i lhzq32 = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(lzq16, 1));

    //     accu = _mm256_add_epi32(accu, hhzq32);
    //     accu = _mm256_add_epi32(accu, hlzq32);
    //     accu = _mm256_add_epi32(accu, llzq32);
    //     accu = _mm256_add_epi32(accu, lhzq32);
    // }

    // int sumi = hsum_i32_8(accu);
    // *s = (float)sumi;
    // #else

    int sumi = 0;

    for (int i = 0; i < n / 4; i++) {
        const int8_t* weight = (const int8_t *)(i2s_i8s + x[i]);
        sumi += (int)y[i*4+0] * weight[0];
        sumi += (int)y[i*4+1] * weight[1];
        sumi += (int)y[i*4+2] * weight[2];
        sumi += (int)y[i*4+3] * weight[3];
    }
    *s = (float)sumi;
    // #endif
}
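The i2s_i8s table used by the scalar loop above is defined elsewhere in the PR and not shown in this hunk. As an illustration only (this is an assumption about its semantics, inferred from the packing code earlier in the diff), a table with the behavior the loop relies on, where each packed byte expands to four int8 weights, could be generated like this:

```c
// Hypothetical generator for a 256-entry expansion table: entry b holds the
// four int8 weights encoded by byte b (element 0 in the top two bits;
// 2-bit codes 0 -> 0, 1 -> +1, 3 -> -1, matching quantize_i2_s above).
#include <stdint.h>
#include <stdio.h>

int main(void) {
    static const int8_t decode[4] = {0, 1, 0, -1};
    printf("static const int8_t i2s_i8s[256][4] = {\n");
    for (int b = 0; b < 256; ++b) {
        printf("    {%2d, %2d, %2d, %2d},\n",
               decode[(b >> 6) & 3], decode[(b >> 4) & 3],
               decode[(b >> 2) & 3], decode[(b >> 0) & 3]);
    }
    printf("};\n");
    return 0;
}
```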

void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;
@@ -14367,6 +14511,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:
        case GGML_TYPE_I64:
        case GGML_TYPE_I2_S:
            // nothing to validate
            break;
        default: