@@ -8127,7 +8127,7 @@ static void ggml_compute_forward_mul_mat(
8127
8127
#endif
8128
8128
8129
8129
if (params -> type == GGML_TASK_INIT ) {
8130
- if (vec_dot_type != GGML_TYPE_F32 ) {
8130
+ if (src1 -> type != vec_dot_type ) {
8131
8131
char * wdata = params -> wdata ;
8132
8132
const size_t row_size = ne10 * GGML_TYPE_SIZE [vec_dot_type ]/GGML_BLCK_SIZE [vec_dot_type ];
8133
8133
@@ -8159,7 +8159,7 @@ static void ggml_compute_forward_mul_mat(
8159
8159
const int ir0 = dr * ith ;
8160
8160
const int ir1 = MIN (ir0 + dr , nr );
8161
8161
8162
- void * wdata = (vec_dot_type == GGML_TYPE_F32 ) ? src1 -> data : params -> wdata ;
8162
+ void * wdata = (src1 -> type == vec_dot_type ) ? src1 -> data : params -> wdata ;
8163
8163
const size_t row_size = ne00 * GGML_TYPE_SIZE [vec_dot_type ]/GGML_BLCK_SIZE [vec_dot_type ];
8164
8164
8165
8165
for (int ir = ir0 ; ir < ir1 ; ++ ir ) {
@@ -11041,6 +11041,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
11041
11041
//printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks);
11042
11042
11043
11043
size_t cur = 0 ;
11044
+ const enum ggml_type vec_dot_type = type_handling [node -> src0 -> type ].vec_dot_type ;
11044
11045
11045
11046
if (node -> src0 -> type == GGML_TYPE_F32 && node -> src1 -> type == GGML_TYPE_F32 ) {
11046
11047
cur = 0 ;
@@ -11049,7 +11050,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
11049
11050
node -> n_tasks = 1 ;
11050
11051
}
11051
11052
#endif
11052
- } else if (type_handling [ node -> src0 -> type ]. vec_dot && node -> src1 -> type == GGML_TYPE_F32 ) {
11053
+ } else if (node -> src1 -> type != vec_dot_type ) {
11053
11054
#if defined(GGML_USE_ACCELERATE ) || defined(GGML_USE_OPENBLAS ) || defined(GGML_USE_CUBLAS ) || defined(GGML_USE_CLBLAST )
11054
11055
if (ggml_compute_forward_mul_mat_use_blas (node -> src0 , node -> src1 , node )) {
11055
11056
node -> n_tasks = 1 ;
@@ -11058,8 +11059,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
11058
11059
} else
11059
11060
#endif
11060
11061
{
11061
- const enum ggml_type type_q = type_handling [node -> src0 -> type ].vec_dot_type ;
11062
- cur = GGML_TYPE_SIZE [type_q ]* ggml_nelements (node -> src1 )/GGML_BLCK_SIZE [type_q ];
11062
+ cur = GGML_TYPE_SIZE [vec_dot_type ]* ggml_nelements (node -> src1 )/GGML_BLCK_SIZE [vec_dot_type ];
11063
11063
}
11064
11064
} else {
11065
11065
GGML_ASSERT (false);
0 commit comments