Skip to content

Commit f07690c

Browse files
authored
vulkan: use fp32 in coopmat2 q4_k dequant function (#12309)
1 parent 891c639 commit f07690c

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
178178

179179
uvec4 v = bl128.block.q4k[0];
180180

181-
const f16vec2 loadd = unpackFloat2x16(v.x);
181+
const vec2 loadd = vec2(unpackFloat2x16(v.x));
182182

183183
uint32_t sc;
184184
uint32_t mbyte;
@@ -199,15 +199,15 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
199199
sc &= 0x3F;
200200
mbyte &= 0x3F;
201201

202-
const float16_t d = loadd.x * float16_t(sc);
203-
const float16_t m = loadd.y * float16_t(mbyte);
202+
const float d = loadd.x * float(sc);
203+
const float m = loadd.y * float(mbyte);
204204

205205
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
206206
qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
207207

208-
float16_t ret = d * float16_t(qs) - m;
208+
float ret = d * float(qs) - m;
209209

210-
return ret;
210+
return float16_t(ret);
211211
}
212212

213213
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K {

0 commit comments

Comments
 (0)