@@ -81,49 +81,39 @@ typedef struct {
81
81
static_assert (sizeof (block_q8_0) == sizeof(float ) + QK8_0, "wrong q8_0 block size/padding");
82
82
83
83
static __global__ void dequantize_block_q4_0 (const void * vx, float * y) {
84
+ static const int qk = QK4_0;
85
+
84
86
const block_q4_0 * x = (const block_q4_0 *) vx;
85
87
86
88
const int i = blockIdx .x ;
87
89
88
90
const float d = x[i].d ;
89
91
90
- const uint8_t * pp = x[i].qs ;
91
-
92
- for (int l = 0 ; l < QK4_0; l += 2 ) {
93
- const uint8_t vi = pp[l/2 ];
94
-
95
- const int8_t vi0 = vi & 0xf ;
96
- const int8_t vi1 = vi >> 4 ;
92
+ for (int j = 0 ; j < qk/2 ; ++j) {
93
+ const int x0 = (x[i].qs [j] & 0xf ) - 8 ;
94
+ const int x1 = (x[i].qs [j] >> 4 ) - 8 ;
97
95
98
- const float v0 = (vi0 - 8 )*d;
99
- const float v1 = (vi1 - 8 )*d;
100
-
101
- y[i*QK4_0 + l + 0 ] = v0;
102
- y[i*QK4_0 + l + 1 ] = v1;
96
+ y[i*qk + j + 0 ] = x0*d;
97
+ y[i*qk + j + qk/2 ] = x1*d;
103
98
}
104
99
}
105
100
106
101
static __global__ void dequantize_block_q4_1 (const void * vx, float * y) {
102
+ static const int qk = QK4_1;
103
+
107
104
const block_q4_1 * x = (const block_q4_1 *) vx;
108
105
109
106
const int i = blockIdx .x ;
110
107
111
108
const float d = x[i].d ;
112
109
const float m = x[i].m ;
113
110
114
- const uint8_t * pp = x[i].qs ;
115
-
116
- for (int l = 0 ; l < QK4_1; l += 2 ) {
117
- const uint8_t vi = pp[l/2 ];
118
-
119
- const int8_t vi0 = vi & 0xf ;
120
- const int8_t vi1 = vi >> 4 ;
111
+ for (int j = 0 ; j < qk/2 ; ++j) {
112
+ const int x0 = (x[i].qs [j] & 0xf );
113
+ const int x1 = (x[i].qs [j] >> 4 );
121
114
122
- const float v0 = vi0*d + m;
123
- const float v1 = vi1*d + m;
124
-
125
- y[i*QK4_1 + l + 0 ] = v0;
126
- y[i*QK4_1 + l + 1 ] = v1;
115
+ y[i*qk + j + 0 ] = x0*d + m;
116
+ y[i*qk + j + qk/2 ] = x1*d + m;
127
117
}
128
118
}
129
119
@@ -151,61 +141,51 @@ static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
151
141
}
152
142
153
143
static __global__ void dequantize_block_q5_0 (const void * vx, float * y) {
144
+ static const int qk = QK5_0;
145
+
154
146
const block_q5_0 * x = (const block_q5_0 *) vx;
155
147
156
148
const int i = blockIdx .x ;
157
149
158
150
const float d = x[i].d ;
159
151
160
- const uint8_t * pp = x[i].qs ;
161
-
162
152
uint32_t qh;
163
153
memcpy (&qh, x[i].qh , sizeof (qh));
164
154
165
- for (int l = 0 ; l < QK5_0; l += 2 ) {
166
- const uint8_t vi = pp[l/2 ];
167
-
168
- const int8_t vh0 = ((qh & (1 << (l + 0 ))) >> (l + 0 )) << 4 ;
169
- const int8_t vh1 = ((qh & (1 << (l + 1 ))) >> (l + 1 )) << 4 ;
155
+ for (int j = 0 ; j < qk/2 ; ++j) {
156
+ const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4 ;
157
+ const uint8_t xh_1 = ((qh & (1u << (j + 16 ))) >> (j + 12 ));
170
158
171
- const int8_t vi0 = ((vi & 0xf ) | vh0) ;
172
- const int8_t vi1 = ((vi >> 4 ) | vh1) ;
159
+ const int32_t x0 = ((x[i]. qs [j] & 0xf ) | xh_0) - 16 ;
160
+ const int32_t x1 = ((x[i]. qs [j] >> 4 ) | xh_1) - 16 ;
173
161
174
- const float v0 = (vi0 - 16 )*d;
175
- const float v1 = (vi1 - 16 )*d;
176
-
177
- y[i*QK5_0 + l + 0 ] = v0;
178
- y[i*QK5_0 + l + 1 ] = v1;
162
+ y[i*qk + j + 0 ] = x0*d;
163
+ y[i*qk + j + qk/2 ] = x1*d;
179
164
}
180
165
}
181
166
182
167
static __global__ void dequantize_block_q5_1 (const void * vx, float * y) {
168
+ static const int qk = QK5_1;
169
+
183
170
const block_q5_1 * x = (const block_q5_1 *) vx;
184
171
185
172
const int i = blockIdx .x ;
186
173
187
174
const float d = x[i].d ;
188
175
const float m = x[i].m ;
189
176
190
- const uint8_t * pp = x[i].qs ;
191
-
192
177
uint32_t qh;
193
178
memcpy (&qh, x[i].qh , sizeof (qh));
194
179
195
- for (int l = 0 ; l < QK5_1; l += 2 ) {
196
- const uint8_t vi = pp[l/2 ];
197
-
198
- const int8_t vh0 = ((qh & (1 << (l + 0 ))) >> (l + 0 )) << 4 ;
199
- const int8_t vh1 = ((qh & (1 << (l + 1 ))) >> (l + 1 )) << 4 ;
200
-
201
- const int8_t vi0 = (vi & 0xf ) | vh0;
202
- const int8_t vi1 = (vi >> 4 ) | vh1;
180
+ for (int j = 0 ; j < qk/2 ; ++j) {
181
+ const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4 ;
182
+ const uint8_t xh_1 = ((qh & (1u << (j + 16 ))) >> (j + 12 ));
203
183
204
- const float v0 = vi0*d + m ;
205
- const float v1 = vi1*d + m ;
184
+ const int x0 = (x[i]. qs [j] & 0xf ) | xh_0 ;
185
+ const int x1 = (x[i]. qs [j] >> 4 ) | xh_1 ;
206
186
207
- y[i*QK5_1 + l + 0 ] = v0 ;
208
- y[i*QK5_1 + l + 1 ] = v1 ;
187
+ y[i*qk + j + 0 ] = x0*d + m ;
188
+ y[i*qk + j + qk/ 2 ] = x1*d + m ;
209
189
}
210
190
}
211
191
0 commit comments