Skip to content

Commit 4c4cb30

Browse files
ikawrakow and Kawrakow authored
IQ3_S: a much better alternative to Q3_K (#5676)
* iq4_nl: squash commits for easier rebase * Basics (quantize, dequantize) * CUDA dequantize and dot product * Slightly faster CUDA dot product (120 t/s) * Switch to 6-bit scales * Scalar dot product * AVX2 dot product * ARM_NEON dot product * Works on metal, but still slow * Slightly better Metal dot product * Another small Metal improvement * Metal dot product is getting there * Faster CUDA dot product * Add 1/8 ffn_down layers as Q5_K when no imatrix has been provided * Report the actual bpw * Add _xs mix that is 4.05 bpw for non-MoE models * Remove IQ4_XS for now, slightly adjust kvalues_iq4nl * AVX2 dot product uses Q8_0 instead of Q8_K * Add to test-backend-ops * Minor fix * Also use Q5_K for attn_output in MoE models * Fixes after merging latest master * Switching to blocks of 32 * AVX2 for blocks of 32 * Scalar dot product for blocks of 32 * ARM_NEON dot product for blocks of 32 * Metal kernels for blocks of 32 * Slightly faster Metal kernels * Resurrecting iq3_xs After all the experimentation, nothing was better than this. 
* Minor PPL improvement via a block scale fudge factor * Minor improvement via 3 neighbours * iq3_xs: working scalar and AVX2 dot products * iq3_xs: ARM_NEON dot product - works but extremely slow (10 t/s) * iq3_xs: working Metal implementation * Adding IQ3_M - IQ3_XS mix with mostly Q4_K * iq3_xs: a 3.4375 bpw variant * iq3_xs: make CUDA work for new version * iq3_xs: make scalar and AVX2 work for new version * iq3_s: make ARM_NEON work with new version * iq3_xs: make new version work on metal. Performance is very similar to Q3_K_S * iq3_xs: tiny Metal speed improvement * iq3_xs: tiny Metal speed improvement * Fix stupid warning * Q3_K_XS now uses a mix of IQ3_XS and IQ3_XXS * iq3_xs: rename to iq3_s * iq3_s: make tests pass * Move Q3_K_XS mix to 3.25 bpw * Attempt to fix failing tests * Another attempt to fix the Windows builds * Attempt to fix ROCm * ROCm again * iq3_s: partial fix for QK_K = 64 * iq3_s: make it work on metal for QK_K = 64 Pleasant surprise: the coding was super-block size independent, so all it took was to delete some QK_K == 256 guards. * Will this fix ROCm? --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 525213d commit 4c4cb30

12 files changed

+1211
-84
lines changed

examples/quantize/quantize.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
2727
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
2828
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
2929
{ "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", },
30+
{ "IQ3_S", LLAMA_FTYPE_MOSTLY_IQ3_S, " 3.44 bpw quantization", },
31+
{ "IQ3_M", LLAMA_FTYPE_MOSTLY_IQ3_M, " 3.66 bpw quantization mix", },
3032
{ "Q3_K", LLAMA_FTYPE_MOSTLY_Q3_K_M, "alias for Q3_K_M" },
3133
{ "Q3_K_XS",LLAMA_FTYPE_MOSTLY_Q3_K_XS,"3-bit extra small quantization" , },
3234
{ "Q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S, " 2.75G, +0.5551 ppl @ LLaMA-v1-7B", },

ggml-cuda.cu

Lines changed: 170 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@
172172
#endif
173173

174174
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
175+
// 4-lane vector of uint8_t (ext_vector_type vector extension); unsigned counterpart of int8x4_t.
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
175176
static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
176177
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
177178
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
@@ -196,6 +197,18 @@ static __device__ __forceinline__ int __vsub4(const int a, const int b) {
196197
return __vsubss4(a, b);
197198
}
198199

200+
// Per-byte equality compare of two packed 32-bit words, mirroring the semantics of
// CUDA's __vcmpeq4 SIMD intrinsic: every byte of the result is 0xff where the
// corresponding bytes of a and b are equal and 0x00 where they differ.
static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
    // A byte of (a ^ b) is zero exactly when the two input bytes match.
    const unsigned int diff = a ^ b;
    unsigned int mask = 0;
#pragma unroll
    for (int shift = 0; shift < 32; shift += 8) {
        if (((diff >> shift) & 0xffu) == 0) {
            mask |= 0xffu << shift;
        }
    }
    return mask;
}
211+
199212
static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
200213
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
201214
c = __builtin_amdgcn_sdot4(a, b, c, false);
@@ -518,6 +531,17 @@ typedef struct {
518531
} block_iq3_xxs;
519532
static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
520533

534+
// IQ3_S super-block layout: QK_K weights stored as 9-bit codebook indices (qs + qh),
// one sign bit per weight, and a 4-bit scale per 32-weight sub-block.
#define QR3_XS 8
535+
#define QI3_XS (QK_K / (4*QR3_XS)) // evaluates to QK_K/32, i.e. one unit per 32-weight sub-block
536+
typedef struct {
537+
half d; // super-block scale; combined with the 4-bit sub-block scale when dequantizing
538+
uint8_t qs[QK_K/4]; // low 8 bits of each codebook index; one index covers 4 weights
539+
uint8_t qh[QK_K/32]; // 9th (high) bit of the indices: one byte per 32 weights, one bit per index
540+
uint8_t signs[QK_K/8]; // one sign bit per weight; a set bit negates the codebook magnitude
541+
uint8_t scales[QK_K/64]; // 4-bit sub-block scales, two per byte (low nibble = even sub-block)
542+
} block_iq3_s;
543+
// 2 bytes for d plus 16 (qs) + 2 (qh) + 8 (signs) + 1 (scales) = 27 bytes per 64 weights; guards against padding.
static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 27*(QK_K/64), "wrong iq3_s block size/padding");
544+
521545
#define QR1_S 8
522546
#define QI1_S (QK_K / (4*QR1_S))
523547
typedef struct {
@@ -1700,6 +1724,74 @@ static const __device__ uint32_t iq3xxs_grid[256] = {
17001724
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
17011725
};
17021726

1727+
// Codebook used by the IQ3_S kernels: 512 entries selected by a 9-bit index
// (low 8 bits taken from block_iq3_s::qs, the 9th bit from block_iq3_s::qh).
// Each 32-bit entry packs four byte-sized magnitudes: the dequantize kernel reads an
// entry through a uint8_t pointer, the dot-product kernel feeds the whole word to __dp4a
// after applying the signs.
static const __device__ uint32_t iq3xs_grid[512] = {
1728+
0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
1729+
0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
1730+
0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
1731+
0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
1732+
0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
1733+
0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
1734+
0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
1735+
0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
1736+
0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
1737+
0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
1738+
0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
1739+
0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
1740+
0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
1741+
0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
1742+
0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
1743+
0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
1744+
0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
1745+
0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
1746+
0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
1747+
0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
1748+
0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
1749+
0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
1750+
0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
1751+
0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
1752+
0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
1753+
0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
1754+
0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
1755+
0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
1756+
0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
1757+
0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
1758+
0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
1759+
0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
1760+
0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
1761+
0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
1762+
0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
1763+
0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
1764+
0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
1765+
0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
1766+
0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
1767+
0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
1768+
0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
1769+
0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
1770+
0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
1771+
0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
1772+
0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
1773+
0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
1774+
0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
1775+
0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
1776+
0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
1777+
0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
1778+
0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
1779+
0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
1780+
0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
1781+
0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
1782+
0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
1783+
0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
1784+
0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
1785+
0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
1786+
0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
1787+
0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
1788+
0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
1789+
0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
1790+
0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
1791+
0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
1792+
};
1793+
1794+
17031795
static const __device__ uint64_t iq1s_grid[512] = {
17041796
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
17051797
0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
@@ -1973,6 +2065,32 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
19732065

19742066
}
19752067

2068+
// Dequantizes IQ3_S data into dst_t values.
//   vx : array of block_iq3_s super-blocks
//   yy : output buffer; super-block blockIdx.x writes to yy[blockIdx.x*QK_K ...]
// Thread layout: thread t handles 8 consecutive outputs, namely run (t/8) of
// 32-weight sub-block (t%8); the host launcher uses 32 threads per block.
// Only QK_K == 256 is implemented; any other configuration asserts.
template<typename dst_t>
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
#if QK_K == 256
    const int sb = blockIdx.x;
    const block_iq3_s & blk = ((const block_iq3_s *) vx)[sb];

    const int tid   = threadIdx.x;
    const int sub   = tid%8; // 0...7: 32-weight sub-block
    const int group = tid/8; // 0...3: run of 8 weights inside the sub-block

    dst_t * out = yy + sb*QK_K + 32*sub + 8*group;

    // Build the two 9-bit codebook indices: low byte from qs, bit 8 from the matching qh bit.
    const int qh     = blk.qh[sub];
    const int idx_lo = blk.qs[8*sub + 2*group + 0] | (((qh >> (2*group + 0)) & 1) << 8);
    const int idx_hi = blk.qs[8*sub + 2*group + 1] | (((qh >> (2*group + 1)) & 1) << 8);
    const uint8_t * mags_lo = (const uint8_t *)(iq3xs_grid + idx_lo);
    const uint8_t * mags_hi = (const uint8_t *)(iq3xs_grid + idx_hi);

    // Effective scale: d * (0.5 + 4-bit sub-block scale) * 0.5.
    const int   nibble = (blk.scales[sub/2] >> 4*(sub%2)) & 0xf;
    const float d      = (float)blk.d * (0.5f + nibble) * 0.5f;

    // One sign byte covers the 8 weights written by this thread.
    const uint8_t sign_bits = blk.signs[4*sub + group];
    for (int j = 0; j < 4; ++j) {
        const float s_lo = (sign_bits & kmask_iq2xs[j+0]) ? -1.f : 1.f;
        const float s_hi = (sign_bits & kmask_iq2xs[j+4]) ? -1.f : 1.f;
        out[j+0] = d * mags_lo[j] * s_lo;
        out[j+4] = d * mags_hi[j] * s_hi;
    }
#else
    assert(false);
#endif
}
2093+
19762094
template<typename dst_t>
19772095
static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
19782096

@@ -4717,6 +4835,41 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
47174835
#endif
47184836
}
47194837

4838+
// Integer dot product of one 32-weight IQ3_S sub-block with the matching q8_1 block.
//   vbq   : pointer to the block_iq3_s super-block
//   bq8_1 : q8_1 blocks of the activation row; block [iqs] pairs with sub-block iqs
//   iqs   : index of the 32-weight sub-block inside the super-block
// Returns the scaled partial dot product. Magnitudes come from iq3xs_grid; the sign
// bytes are expanded arithmetically with __vcmpeq4 (no sign lookup table is used).
// Requires __dp4a support (MIN_CC_DP4A) and QK_K == 256; otherwise asserts and returns 0.
static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A && QK_K == 256 // lowest compute capability for integer intrinsics
    const block_iq3_s * bq3 = (const block_iq3_s *) vbq;

    const int sub = iqs;
    const uint8_t * idx8 = bq3->qs    + 8*sub;
    const uint8_t * sgn8 = bq3->signs + 4*sub;
    const int     * q8   = (const int *) bq8_1[sub].qs; // 8 ints = 32 int8 activations
    const int       qh   = bq3->qh[sub];

    int acc = 0;
    for (int l = 0; l < 4; ++l) {
        // 9-bit codebook indices: low byte from qs, bit 8 from the matching qh bit.
        const uint32_t mag_lo = iq3xs_grid[idx8[2*l+0] | (((qh >> (2*l+0)) & 1) << 8)];
        const uint32_t mag_hi = iq3xs_grid[idx8[2*l+1] | (((qh >> (2*l+1)) & 1) << 8)];

        // Spread each sign nibble over four bytes: byte k becomes 0xff when bit k is set.
        const uint32_t neg_lo = __vcmpeq4(((sgn8[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
        const uint32_t neg_hi = __vcmpeq4(((sgn8[l] >>  4) * 0x01010101) & 0x08040201, 0x08040201);

        // (m ^ s) - s negates exactly the bytes whose mask byte is 0xff.
        const int signed_lo = __vsub4(mag_lo ^ neg_lo, neg_lo);
        const int signed_hi = __vsub4(mag_hi ^ neg_hi, neg_hi);

        acc = __dp4a(signed_lo, q8[2*l+0], acc);
        acc = __dp4a(signed_hi, q8[2*l+1], acc);
    }

    const float d = (float)bq3->d * (0.5f + ((bq3->scales[sub/2] >> 4*(sub%2)) & 0xf)) * __low2float(bq8_1[sub].ds) * 0.5f;
    return d * acc;
#else
    assert(false);
    return 0.f;
#endif
}
4871+
4872+
47204873
static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
47214874
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
47224875
#if QK_K == 256
@@ -6849,6 +7002,12 @@ static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k,
68497002
dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
68507003
}
68517004

7005+
// Host-side launcher for dequantize_block_iq3_s.
//   vx     : IQ3_S-quantized input (array of super-blocks)
//   y      : output buffer for the dequantized values
//   k      : number of values to convert; only whole QK_K-sized super-blocks are processed
//   stream : CUDA stream the kernel is enqueued on
// Launches one 32-thread block per super-block.
template<typename dst_t>
static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int num_superblocks = k / QK_K;
    dequantize_block_iq3_s<<<num_superblocks, 32, 0, stream>>>(vx, y);
}
7010+
68527011
template<typename dst_t>
68537012
static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
68547013
const int nb = k / QK_K;
@@ -6904,6 +7063,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
69047063
return dequantize_row_iq1_s_cuda;
69057064
case GGML_TYPE_IQ4_NL:
69067065
return dequantize_row_iq4_nl_cuda;
7066+
case GGML_TYPE_IQ3_S:
7067+
return dequantize_row_iq3_s_cuda;
69077068
case GGML_TYPE_F32:
69087069
return convert_unary_cuda<float>;
69097070
default:
@@ -6943,6 +7104,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
69437104
return dequantize_row_iq1_s_cuda;
69447105
case GGML_TYPE_IQ4_NL:
69457106
return dequantize_row_iq4_nl_cuda;
7107+
case GGML_TYPE_IQ3_S:
7108+
return dequantize_row_iq3_s_cuda;
69467109
case GGML_TYPE_F16:
69477110
return convert_unary_cuda<half>;
69487111
default:
@@ -8688,6 +8851,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
86888851
case GGML_TYPE_IQ3_XXS:
86898852
case GGML_TYPE_IQ1_S:
86908853
case GGML_TYPE_IQ4_NL:
8854+
case GGML_TYPE_IQ3_S:
86918855
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
86928856
default:
86938857
GGML_ASSERT(false);
@@ -8713,6 +8877,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
87138877
case GGML_TYPE_IQ3_XXS:
87148878
case GGML_TYPE_IQ1_S:
87158879
case GGML_TYPE_IQ4_NL:
8880+
case GGML_TYPE_IQ3_S:
87168881
return max_compute_capability >= CC_VOLTA ? 128 : 64;
87178882
case GGML_TYPE_Q6_K:
87188883
return 64;
@@ -8818,6 +8983,10 @@ static void ggml_cuda_op_mul_mat_vec_q(
88188983
mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
88198984
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
88208985
break;
8986+
case GGML_TYPE_IQ3_S:
8987+
mul_mat_vec_q_cuda<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
8988+
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
8989+
break;
88218990
default:
88228991
GGML_ASSERT(false);
88238992
break;
@@ -11541,7 +11710,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
1154111710
}
1154211711
ggml_type a_type = a->type;
1154311712
if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
11544-
a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL) {
11713+
a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S) {
1154511714
if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
1154611715
return false;
1154711716
}

0 commit comments

Comments
 (0)