@@ -2199,13 +2199,15 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
2199
2199
template <typename type4x4>
2200
2200
void dequantize_q4_0 (device const block_q4_0 *xb, short il, thread type4x4 & reg) {
2201
2201
device const uint16_t * qs = ((device const uint16_t *)xb + 1 );
2202
- const float d = xb->d ;
2202
+ const float d1 = il ? (xb->d / 16 .h ) : xb->d ;
2203
+ const float d2 = d1 / 256 .f ;
2203
2204
const float md = -8 .h * xb->d ;
2204
- const ushort mask = il ? 0x00F0 : 0x000F ;
2205
+ const ushort mask0 = il ? 0x00F0 : 0x000F ;
2206
+ const ushort mask1 = mask0 << 8 ;
2205
2207
2206
2208
for (int i=0 ;i<8 ;i++) {
2207
- reg[i/2 ][2 *(i%2 )+0 ] = d * (( qs[i] & mask) >> (il ? 4 : 0 ) ) + md;
2208
- reg[i/2 ][2 *(i%2 )+1 ] = d * ((( qs[i] >> 8 ) & mask) >> (il ? 4 : 0 ) ) + md;
2209
+ reg[i/2 ][2 *(i%2 )+0 ] = d1 * (qs[i] & mask0 ) + md;
2210
+ reg[i/2 ][2 *(i%2 )+1 ] = d2 * (qs[i] & mask1 ) + md;
2209
2211
}
2210
2212
}
2211
2213
@@ -2235,13 +2237,13 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg
2235
2237
2236
2238
const int x_mv = (il ? 4 : 0 );
2237
2239
2238
- const int qh_mv = (il ? 12 : 0 );
2239
- const int qh_bk = (il ? 0 : 4 );
2240
+ const int gh_mv = (il ? 12 : 0 );
2241
+ const int gh_bk = (il ? 0 : 4 );
2240
2242
2241
2243
for (int i = 0 ; i < 8 ; i++) {
2242
2244
// extract the 5-th bits for x0 and x1
2243
- const uint8_t xh_0 = ((qh >> (qh_mv + 2 *i )) << qh_bk ) & 0x10 ;
2244
- const uint8_t xh_1 = ((qh >> (qh_mv + 2 *i+1 )) << qh_bk ) & 0x10 ;
2245
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2 *i )) << gh_bk ) & 0x10 ;
2246
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2 *i+1 )) << gh_bk ) & 0x10 ;
2245
2247
2246
2248
// combine the 4-bits from qs with the 5th bit
2247
2249
const int32_t x0 = (((qs[i] & mask) >> x_mv) | xh_0);
0 commit comments