@@ -208,6 +208,7 @@ typedef struct {
208
208
static_assert (sizeof (block_q6_K) == sizeof(ggml_fp16_t ) + 13*QK_K/16, "wrong q6_K block size/padding");
209
209
210
210
#define WARP_SIZE 32
211
+ #define MATRIX_ROW_PADDING 256 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
211
212
212
213
#define CUDA_ADD_BLOCK_SIZE 256
213
214
#define CUDA_MUL_BLOCK_SIZE 256
@@ -1174,16 +1175,12 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
1174
1175
static __global__ void quantize_q8_1 (const float * __restrict__ x, void * __restrict__ vy, const int k) {
1175
1176
const int i = blockDim .x *blockIdx .x + threadIdx .x ;
1176
1177
1177
- if (i >= k) {
1178
- return ;
1179
- }
1180
-
1181
1178
block_q8_1 * y = (block_q8_1 *) vy;
1182
1179
1183
- const int ib = i / QK8_0 ; // block index
1184
- const int iqs = i % QK8_0 ; // quant index
1180
+ const int ib = i / QK8_1 ; // block index
1181
+ const int iqs = i % QK8_1 ; // quant index
1185
1182
1186
- const float xi = x[i];
1183
+ const float xi = i < k ? x[i] : 0 . 0f ;
1187
1184
float amax = fabsf (xi);
1188
1185
float sum = xi;
1189
1186
@@ -2359,8 +2356,10 @@ inline void ggml_cuda_op_mul_mat_vec(
2359
2356
#endif
2360
2357
2361
2358
if (use_mul_mat_vec_q) {
2359
+ int64_t padded_row_size = ne00 + MATRIX_ROW_PADDING - 1 ;
2360
+ padded_row_size -= padded_row_size % MATRIX_ROW_PADDING;
2362
2361
size_t as;
2363
- void * src1_q8_1 = ggml_cuda_pool_malloc (ne00 *sizeof (block_q8_1)/QK8_1, &as);
2362
+ void * src1_q8_1 = ggml_cuda_pool_malloc (padded_row_size *sizeof (block_q8_1)/QK8_1, &as);
2364
2363
quantize_row_q8_1_cuda (src1_ddf_i, src1_q8_1, ne00, cudaStream_main);
2365
2364
2366
2365
switch (src0->type ) {
@@ -3105,7 +3104,11 @@ void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
3105
3104
3106
3105
void ggml_cuda_transform_tensor (void * data, struct ggml_tensor * tensor) {
3107
3106
int nrows = ggml_nrows (tensor);
3107
+
3108
+ const int64_t ne0 = tensor->ne [0 ];
3109
+
3108
3110
const size_t nb1 = tensor->nb [1 ];
3111
+
3109
3112
ggml_backend backend = tensor->backend ;
3110
3113
struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu ;
3111
3114
memset (extra, 0 , sizeof (*extra));
@@ -3134,11 +3137,24 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
3134
3137
int64_t nrows_split = row_high - row_low;
3135
3138
3136
3139
const size_t offset_split = row_low*nb1;
3137
- const size_t size = ggml_nbytes_split (tensor, nrows_split);
3140
+ size_t size = ggml_nbytes_split (tensor, nrows_split);
3141
+ const size_t original_size = size;
3142
+
3143
+ // pad last row to a multiple of 256 elements to avoid out-of-bounds memory accesses
3144
+ if (ne0 % MATRIX_ROW_PADDING != 0 ) {
3145
+ size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING)
3146
+ * ggml_type_size (tensor->type )/ggml_blck_size (tensor->type );
3147
+ }
3138
3148
3139
- void * buf;
3149
+ char * buf;
3140
3150
CUDA_CHECK (cudaMalloc (&buf, size));
3141
- void * buf_host = (char *)data + offset_split;
3151
+ char * buf_host = (char *)data + offset_split;
3152
+
3153
+ // set padding to 0 to avoid possible NaN values
3154
+ if (size > original_size) {
3155
+ CUDA_CHECK (cudaMemset (buf + original_size, 0 , size - original_size));
3156
+ }
3157
+
3142
3158
3143
3159
cudaMemcpy (buf, buf_host, size, cudaMemcpyHostToDevice);
3144
3160
0 commit comments