Skip to content

Commit 8d37db3

Browse files
committed
ggml_add: Add more checks
1 parent 0a6d5ad commit 8d37db3

File tree

1 file changed

+34
-16
lines changed

1 file changed

+34
-16
lines changed

ggml.c

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5893,27 +5893,36 @@ static void ggml_compute_forward_add_f16_f32(
58935893
const int n = ggml_nrows(src0);
58945894
const int nc = src0->ne[0];
58955895

5896-
//const size_t nb00 = src0->nb[0];
5896+
const size_t nb00 = src0->nb[0];
58975897
const size_t nb01 = src0->nb[1];
58985898

58995899
const size_t nb10 = src1->nb[0];
59005900
const size_t nb11 = src1->nb[1];
59015901

5902-
//const size_t nb0 = dst->nb[0];
5902+
const size_t nb0 = dst->nb[0];
59035903
const size_t nb1 = dst->nb[1];
59045904

59055905
GGML_ASSERT(src0->type == GGML_TYPE_F16);
59065906
GGML_ASSERT(src1->type == GGML_TYPE_F32);
59075907
GGML_ASSERT(dst->type == GGML_TYPE_F16);
59085908

5909-
for (int j = ith; j < n; j += nth) {
5910-
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
5911-
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
5912-
for (int i = 0; i < nc; i++) {
5913-
float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
5914-
dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
5909+
GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
5910+
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
5911+
5912+
if (nb10 == sizeof(float)) {
5913+
for (int j = ith; j < n; j += nth) {
5914+
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
5915+
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
5916+
for (int i = 0; i < nc; i++) {
5917+
float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
5918+
dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
5919+
}
59155920
}
59165921
}
5922+
else {
5923+
// src1 is not contiguous
5924+
GGML_ASSERT(false);
5925+
}
59175926
}
59185927

59195928
static void ggml_compute_forward_add_f16_f16(
@@ -5933,27 +5942,36 @@ static void ggml_compute_forward_add_f16_f16(
59335942
const int n = ggml_nrows(src0);
59345943
const int nc = src0->ne[0];
59355944

5936-
//const size_t nb00 = src0->nb[0];
5945+
const size_t nb00 = src0->nb[0];
59375946
const size_t nb01 = src0->nb[1];
59385947

59395948
const size_t nb10 = src1->nb[0];
59405949
const size_t nb11 = src1->nb[1];
59415950

5942-
//const size_t nb0 = dst->nb[0];
5951+
const size_t nb0 = dst->nb[0];
59435952
const size_t nb1 = dst->nb[1];
59445953

59455954
GGML_ASSERT(src0->type == GGML_TYPE_F16);
59465955
GGML_ASSERT(src1->type == GGML_TYPE_F16);
59475956
GGML_ASSERT(dst->type == GGML_TYPE_F16);
59485957

5949-
for (int j = ith; j < n; j += nth) {
5950-
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
5951-
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
5952-
for (int i = 0; i < nc; i++) {
5953-
ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
5954-
dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
5958+
GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
5959+
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
5960+
5961+
if (nb10 == sizeof(ggml_fp16_t)) {
5962+
for (int j = ith; j < n; j += nth) {
5963+
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
5964+
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
5965+
for (int i = 0; i < nc; i++) {
5966+
ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
5967+
dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
5968+
}
59555969
}
59565970
}
5971+
else {
5972+
// src1 is not contiguous
5973+
GGML_ASSERT(false);
5974+
}
59575975
}
59585976

59595977
static void ggml_compute_forward_add_q_f32(

0 commit comments

Comments
 (0)