Skip to content

Commit 4ab39d2

Browse files
committed
ggml : update cuBLAS + normalize variable names
1 parent 45a8213 commit 4ab39d2

File tree

2 files changed

+149
-168
lines changed

2 files changed

+149
-168
lines changed

ggml-cuda.cu

Lines changed: 32 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -81,49 +81,39 @@ typedef struct {
8181
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
8282

8383
static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
84+
static const int qk = QK4_0;
85+
8486
const block_q4_0 * x = (const block_q4_0 *) vx;
8587

8688
const int i = blockIdx.x;
8789

8890
const float d = x[i].d;
8991

90-
const uint8_t * pp = x[i].qs;
91-
92-
for (int l = 0; l < QK4_0; l += 2) {
93-
const uint8_t vi = pp[l/2];
94-
95-
const int8_t vi0 = vi & 0xf;
96-
const int8_t vi1 = vi >> 4;
92+
for (int j = 0; j < qk/2; ++j) {
93+
const int x0 = (x[i].qs[j] & 0xf) - 8;
94+
const int x1 = (x[i].qs[j] >> 4) - 8;
9795

98-
const float v0 = (vi0 - 8)*d;
99-
const float v1 = (vi1 - 8)*d;
100-
101-
y[i*QK4_0 + l + 0] = v0;
102-
y[i*QK4_0 + l + 1] = v1;
96+
y[i*qk + j + 0 ] = x0*d;
97+
y[i*qk + j + qk/2] = x1*d;
10398
}
10499
}
105100

106101
static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
102+
static const int qk = QK4_1;
103+
107104
const block_q4_1 * x = (const block_q4_1 *) vx;
108105

109106
const int i = blockIdx.x;
110107

111108
const float d = x[i].d;
112109
const float m = x[i].m;
113110

114-
const uint8_t * pp = x[i].qs;
115-
116-
for (int l = 0; l < QK4_1; l += 2) {
117-
const uint8_t vi = pp[l/2];
118-
119-
const int8_t vi0 = vi & 0xf;
120-
const int8_t vi1 = vi >> 4;
111+
for (int j = 0; j < qk/2; ++j) {
112+
const int x0 = (x[i].qs[j] & 0xf);
113+
const int x1 = (x[i].qs[j] >> 4);
121114

122-
const float v0 = vi0*d + m;
123-
const float v1 = vi1*d + m;
124-
125-
y[i*QK4_1 + l + 0] = v0;
126-
y[i*QK4_1 + l + 1] = v1;
115+
y[i*qk + j + 0 ] = x0*d + m;
116+
y[i*qk + j + qk/2] = x1*d + m;
127117
}
128118
}
129119

@@ -151,61 +141,51 @@ static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
151141
}
152142

153143
static __global__ void dequantize_block_q5_0(const void * vx, float * y) {
144+
static const int qk = QK5_0;
145+
154146
const block_q5_0 * x = (const block_q5_0 *) vx;
155147

156148
const int i = blockIdx.x;
157149

158150
const float d = x[i].d;
159151

160-
const uint8_t * pp = x[i].qs;
161-
162152
uint32_t qh;
163153
memcpy(&qh, x[i].qh, sizeof(qh));
164154

165-
for (int l = 0; l < QK5_0; l += 2) {
166-
const uint8_t vi = pp[l/2];
167-
168-
const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
169-
const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
155+
for (int j = 0; j < qk/2; ++j) {
156+
const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
157+
const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
170158

171-
const int8_t vi0 = ((vi & 0xf) | vh0);
172-
const int8_t vi1 = ((vi >> 4) | vh1);
159+
const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
160+
const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
173161

174-
const float v0 = (vi0 - 16)*d;
175-
const float v1 = (vi1 - 16)*d;
176-
177-
y[i*QK5_0 + l + 0] = v0;
178-
y[i*QK5_0 + l + 1] = v1;
162+
y[i*qk + j + 0 ] = x0*d;
163+
y[i*qk + j + qk/2] = x1*d;
179164
}
180165
}
181166

182167
static __global__ void dequantize_block_q5_1(const void * vx, float * y) {
168+
static const int qk = QK5_1;
169+
183170
const block_q5_1 * x = (const block_q5_1 *) vx;
184171

185172
const int i = blockIdx.x;
186173

187174
const float d = x[i].d;
188175
const float m = x[i].m;
189176

190-
const uint8_t * pp = x[i].qs;
191-
192177
uint32_t qh;
193178
memcpy(&qh, x[i].qh, sizeof(qh));
194179

195-
for (int l = 0; l < QK5_1; l += 2) {
196-
const uint8_t vi = pp[l/2];
197-
198-
const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
199-
const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
200-
201-
const int8_t vi0 = (vi & 0xf) | vh0;
202-
const int8_t vi1 = (vi >> 4) | vh1;
180+
for (int j = 0; j < qk/2; ++j) {
181+
const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
182+
const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
203183

204-
const float v0 = vi0*d + m;
205-
const float v1 = vi1*d + m;
184+
const int x0 = (x[i].qs[j] & 0xf) | xh_0;
185+
const int x1 = (x[i].qs[j] >> 4) | xh_1;
206186

207-
y[i*QK5_1 + l + 0] = v0;
208-
y[i*QK5_1 + l + 1] = v1;
187+
y[i*qk + j + 0 ] = x0*d + m;
188+
y[i*qk + j + qk/2] = x1*d + m;
209189
}
210190
}
211191

0 commit comments

Comments
 (0)