@@ -8245,8 +8245,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     ggml_fp16_t * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
     ggml_fp16_t * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
     float       * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
-#else
-    float * const wdata = params->wdata;
 #endif
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -8263,15 +8261,20 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                         wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
                     }
                 }
+
+                assert(id*sizeof(ggml_fp16_t) <= params->wsize);
             }
 #else
+            float * const wdata = params->wdata;
             {
                 size_t id = 0;
                 for (int64_t i01 = 0; i01 < ne01; ++i01) {
                     for (int64_t i00 = 0; i00 < ne00; ++i00) {
                         wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
                     }
                 }
+
+                assert(id*sizeof(float) <= params->wsize);
             }
 #endif
 
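Note: the two asserts added above guard the shared work buffer. Each conversion loop counts the elements it writes into params->wdata, and the byte total must stay within the params->wsize reserved during graph planning. A minimal sketch of that pattern, with placeholder names (pack_src1_f16, a stubbed fp32-to-fp16 converter) rather than the actual ggml code:

    #include <assert.h>
    #include <stddef.h>
    #include <stdint.h>

    typedef uint16_t ggml_fp16_t; /* stand-in for ggml's half type */

    static ggml_fp16_t fp32_to_fp16_stub(float x) { (void) x; return 0; }

    /* Pack an ne01 x ne00 f32 slice into the fp16 work buffer, then verify
     * the write stayed inside the wsize bytes reserved by the graph planner. */
    static void pack_src1_f16(ggml_fp16_t * wdata, size_t wsize,
                              const float * src, int64_t ne01, int64_t ne00) {
        size_t id = 0;
        for (int64_t i01 = 0; i01 < ne01; ++i01) {
            for (int64_t i00 = 0; i00 < ne00; ++i00) {
                wdata[id++] = fp32_to_fp16_stub(src[i01*ne00 + i00]);
            }
        }
        assert(id*sizeof(ggml_fp16_t) <= wsize); /* the invariant added above */
        (void) wsize;                            /* keep -DNDEBUG builds quiet */
    }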
@@ -8537,7 +8540,10 @@ static void ggml_compute_forward_mul_mat_q_f32(
                     dequantize_row_q((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
                     id += ne00;
                 }
+
+                assert(id*sizeof(float) <= params->wsize);
             }
+
             const float * x = wdata;
 #endif
 
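The quantized path gets the same treatment: dequantize_row_q expands each row of src0 into exactly ne00 floats, so the element count is known up front and the new assert bounds the total against params->wsize. A self-contained sketch of the shape of that loop, with a stubbed dequantizer standing in for ggml's real one:

    #include <assert.h>
    #include <stdint.h>
    #include <string.h>

    /* Stub for ggml's dequantize_row_q: expand one quantized row into k floats. */
    static void dequantize_row_stub(const void * src, float * dst, int64_t k) {
        (void) src;
        memset(dst, 0, (size_t) k * sizeof(float));
    }

    /* Dequantize ne01 rows of ne00 elements into the f32 work buffer, then
     * check the write against the reserved work-buffer size, as in the diff. */
    static void dequantize_slice(const char * src0, size_t nb01,
                                 int64_t ne01, int64_t ne00,
                                 float * wdata, size_t wsize) {
        size_t id = 0;
        for (int64_t i01 = 0; i01 < ne01; ++i01) {
            dequantize_row_stub(src0 + i01*nb01, wdata + id, ne00);
            id += ne00; /* each row contributes exactly ne00 floats */
        }
        assert(id*sizeof(float) <= wsize);
        (void) wsize;
    }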
@@ -9118,7 +9124,7 @@ static void ggml_compute_forward_alibi_f32(
     //const int nb3 = src0->nb[3];
 
     assert(nb0 == sizeof(float));
-    assert(ne1 + n_past == ne0);
+    assert(ne1 + n_past == ne0); (void) n_past;
 
     // add alibi to src0 (KQ_scaled)
     const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@@ -9179,7 +9185,7 @@ static void ggml_compute_forward_alibi_f16(
     //const int nb3 = src0->nb[3];
 
     assert(nb0 == sizeof(ggml_fp16_t));
-    assert(ne1 + n_past == ne0);
+    assert(ne1 + n_past == ne0); (void) n_past;
 
     // add alibi to src0 (KQ_scaled)
     const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
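The `(void) n_past;` appended to both alibi asserts addresses release builds: with -DNDEBUG, assert() expands to nothing, so a local used only inside an assert becomes unused and triggers -Wunused-variable. Casting it to void keeps the variable formally used in both build modes. A tiny illustration (hypothetical check_alibi_shape helper, not ggml code):

    #include <assert.h>

    /* Compile with and without -DNDEBUG: without the void cast, the NDEBUG
     * build warns that n_past is unused once the assert disappears. */
    static void check_alibi_shape(int ne0, int ne1, int n_past) {
        assert(ne1 + n_past == ne0); (void) n_past;
        (void) ne0; (void) ne1; /* same idea for the other assert-only locals */
    }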
@@ -11571,12 +11577,12 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                     node->n_tasks = 1; // TODO: this actually is doing nothing
                                        //       the threads are still spinning
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_CUBLAS)
+                    // with cuBLAS, we need memory for the full 3D / 4D data of src1
+                    cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
+#else
                     // here we need memory just for single 2D matrix from src0
                     cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
-#else
-                    // with GPU, we need memory for the full 3D / 4D data
-                    cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*MAX(ggml_nelements(node->src1), ggml_nelements(node->src0));
 #endif
                 } else {
                     cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
@@ -11586,7 +11592,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 #endif
             } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
                 cur = 0;
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
                 if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                     node->n_tasks = 1;
                 }
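The last two hunks adjust how ggml_graph_compute sizes the scratch buffer (cur) for BLAS-accelerated mul_mat: per the comments in the diff, the cuBLAS path stages the entire 3D/4D src1 tensor as f16 before uploading, while the CPU BLAS path only ever materializes one 2D f32 slice of src0 at a time. A simplified sketch of that size selection, assuming a 2-byte half type (the real planner reads GGML_TYPE_SIZE and more node state):

    #include <stddef.h>
    #include <stdint.h>

    /* Work-buffer bytes for a BLAS mul_mat node, mirroring the #if above. */
    static size_t mul_mat_blas_work_size(int64_t src0_ne0, int64_t src0_ne1,
                                         int64_t src1_nelements, int use_cublas) {
        if (use_cublas) {
            /* full 3D/4D src1, converted to half precision */
            return sizeof(uint16_t) * (size_t) src1_nelements;
        }
        /* one 2D matrix from src0, held as f32 */
        return sizeof(float) * (size_t) (src0_ne0 * src0_ne1);
    }

The old code reserved f32 storage for the larger of src0 and src1; staging src1 as f16 instead halves the per-element cost while still covering the whole tensor the GPU path needs.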