@@ -5893,27 +5893,36 @@ static void ggml_compute_forward_add_f16_f32(
5893
5893
const int n = ggml_nrows (src0 );
5894
5894
const int nc = src0 -> ne [0 ];
5895
5895
5896
- // const size_t nb00 = src0->nb[0];
5896
+ const size_t nb00 = src0 -> nb [0 ];
5897
5897
const size_t nb01 = src0 -> nb [1 ];
5898
5898
5899
5899
const size_t nb10 = src1 -> nb [0 ];
5900
5900
const size_t nb11 = src1 -> nb [1 ];
5901
5901
5902
- // const size_t nb0 = dst->nb[0];
5902
+ const size_t nb0 = dst -> nb [0 ];
5903
5903
const size_t nb1 = dst -> nb [1 ];
5904
5904
5905
5905
GGML_ASSERT (src0 -> type == GGML_TYPE_F16 );
5906
5906
GGML_ASSERT (src1 -> type == GGML_TYPE_F32 );
5907
5907
GGML_ASSERT (dst -> type == GGML_TYPE_F16 );
5908
5908
5909
- for (int j = ith ; j < n ; j += nth ) {
5910
- ggml_fp16_t * dst_ptr = (ggml_fp16_t * ) ((char * ) dst -> data + j * nb1 );
5911
- ggml_fp16_t * src0_ptr = (ggml_fp16_t * ) ((char * ) src0 -> data + j * nb01 );
5912
- for (int i = 0 ; i < nc ; i ++ ) {
5913
- float * src1_ptr = (float * ) ((char * ) src1 -> data + j * nb11 + i * nb10 );
5914
- dst_ptr [i ] = GGML_FP32_TO_FP16 (GGML_FP16_TO_FP32 (src0_ptr [i ]) + * src1_ptr );
5909
+ GGML_ASSERT ( nb0 == sizeof (ggml_fp16_t ));
5910
+ GGML_ASSERT (nb00 == sizeof (ggml_fp16_t ));
5911
+
5912
+ if (nb10 == sizeof (float )) {
5913
+ for (int j = ith ; j < n ; j += nth ) {
5914
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t * ) ((char * ) dst -> data + j * nb1 );
5915
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t * ) ((char * ) src0 -> data + j * nb01 );
5916
+ for (int i = 0 ; i < nc ; i ++ ) {
5917
+ float * src1_ptr = (float * ) ((char * ) src1 -> data + j * nb11 + i * nb10 );
5918
+ dst_ptr [i ] = GGML_FP32_TO_FP16 (GGML_FP16_TO_FP32 (src0_ptr [i ]) + * src1_ptr );
5919
+ }
5915
5920
}
5916
5921
}
5922
+ else {
5923
+ // src1 is not contiguous
5924
+ GGML_ASSERT (false);
5925
+ }
5917
5926
}
5918
5927
5919
5928
static void ggml_compute_forward_add_f16_f16 (
@@ -5933,27 +5942,36 @@ static void ggml_compute_forward_add_f16_f16(
5933
5942
const int n = ggml_nrows (src0 );
5934
5943
const int nc = src0 -> ne [0 ];
5935
5944
5936
- // const size_t nb00 = src0->nb[0];
5945
+ const size_t nb00 = src0 -> nb [0 ];
5937
5946
const size_t nb01 = src0 -> nb [1 ];
5938
5947
5939
5948
const size_t nb10 = src1 -> nb [0 ];
5940
5949
const size_t nb11 = src1 -> nb [1 ];
5941
5950
5942
- // const size_t nb0 = dst->nb[0];
5951
+ const size_t nb0 = dst -> nb [0 ];
5943
5952
const size_t nb1 = dst -> nb [1 ];
5944
5953
5945
5954
GGML_ASSERT (src0 -> type == GGML_TYPE_F16 );
5946
5955
GGML_ASSERT (src1 -> type == GGML_TYPE_F16 );
5947
5956
GGML_ASSERT (dst -> type == GGML_TYPE_F16 );
5948
5957
5949
- for (int j = ith ; j < n ; j += nth ) {
5950
- ggml_fp16_t * dst_ptr = (ggml_fp16_t * ) ((char * ) dst -> data + j * nb1 );
5951
- ggml_fp16_t * src0_ptr = (ggml_fp16_t * ) ((char * ) src0 -> data + j * nb01 );
5952
- for (int i = 0 ; i < nc ; i ++ ) {
5953
- ggml_fp16_t * src1_ptr = (ggml_fp16_t * ) ((char * ) src1 -> data + j * nb11 + i * nb10 );
5954
- dst_ptr [i ] = GGML_FP32_TO_FP16 (GGML_FP16_TO_FP32 (src0_ptr [i ]) + GGML_FP16_TO_FP32 (* src1_ptr ));
5958
+ GGML_ASSERT ( nb0 == sizeof (ggml_fp16_t ));
5959
+ GGML_ASSERT (nb00 == sizeof (ggml_fp16_t ));
5960
+
5961
+ if (nb10 == sizeof (ggml_fp16_t )) {
5962
+ for (int j = ith ; j < n ; j += nth ) {
5963
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t * ) ((char * ) dst -> data + j * nb1 );
5964
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t * ) ((char * ) src0 -> data + j * nb01 );
5965
+ for (int i = 0 ; i < nc ; i ++ ) {
5966
+ ggml_fp16_t * src1_ptr = (ggml_fp16_t * ) ((char * ) src1 -> data + j * nb11 + i * nb10 );
5967
+ dst_ptr [i ] = GGML_FP32_TO_FP16 (GGML_FP16_TO_FP32 (src0_ptr [i ]) + GGML_FP16_TO_FP32 (* src1_ptr ));
5968
+ }
5955
5969
}
5956
5970
}
5971
+ else {
5972
+ // src1 is not contiguous
5973
+ GGML_ASSERT (false);
5974
+ }
5957
5975
}
5958
5976
5959
5977
static void ggml_compute_forward_add_q_f32 (
0 commit comments