1
1
#include "ggml.h"
2
2
3
+ #include <alloca.h>
3
4
#include <assert.h>
4
5
#include <time.h>
5
6
#include <math.h>
12
13
#include <pthread.h>
13
14
14
15
#define GGML_DEBUG 0
15
- #define GGML_MEM_ALIGN 16
16
+
17
// GGML_MEM_ALIGN: memory alignment constant, chosen from the pointer width.
// When UINTPTR_MAX equals 0xFFFFFFFF (32-bit pointers) the value is 4,
// otherwise it is 16.
// NOTE(review): UINTPTR_MAX is provided by <stdint.h>, which is not included
// directly in the lines visible here - presumably it arrives via "ggml.h";
// confirm. If the macro were undefined, the preprocessor would evaluate it
// as 0 in this #if and the 16-byte branch would be selected.
+ #if UINTPTR_MAX == 0xFFFFFFFF
18
+ #define GGML_MEM_ALIGN 4
19
+ #else
20
+ #define GGML_MEM_ALIGN 16
21
+ #endif
16
22
17
23
// Two-argument maximum / minimum helpers.
// The expansion mentions each argument twice (once in the comparison and once
// in the selected branch), so an argument with side effects (e.g. i++) may be
// evaluated more than once - pass only side-effect-free expressions.
#define MAX (a , b ) ((a) > (b) ? (a) : (b))
18
24
#define MIN (a , b ) ((a) < (b) ? (a) : (b))
@@ -305,6 +311,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
305
311
#ifdef __ARM_NEON
306
312
const int n32 = (n & ~31 );
307
313
314
+ #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC )
308
315
float16x8_t sum0 = vdupq_n_f16 (0 );
309
316
float16x8_t sum1 = vdupq_n_f16 (0 );
310
317
float16x8_t sum2 = vdupq_n_f16 (0 );
@@ -344,6 +351,61 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
344
351
345
352
float32x2_t sumf32 = vadd_f32 (vget_low_f32 (sum0f32 ), vget_high_f32 (sum0f32 ));
346
353
sumf = vget_lane_f32 (sumf32 , 0 ) + vget_lane_f32 (sumf32 , 1 );
354
+ #else
355
+ float32x4_t sum0 = vdupq_n_f32 (0 );
356
+ float32x4_t sum1 = vdupq_n_f32 (0 );
357
+ float32x4_t sum2 = vdupq_n_f32 (0 );
358
+ float32x4_t sum3 = vdupq_n_f32 (0 );
359
+ float32x4_t sum4 = vdupq_n_f32 (0 );
360
+ float32x4_t sum5 = vdupq_n_f32 (0 );
361
+ float32x4_t sum6 = vdupq_n_f32 (0 );
362
+ float32x4_t sum7 = vdupq_n_f32 (0 );
363
+
364
+ float32x4_t x0 , x1 , x2 , x3 , x4 , x5 , x6 , x7 ;
365
+ float32x4_t y0 , y1 , y2 , y3 , y4 , y5 , y6 , y7 ;
366
+
367
+ for (int i = 0 ; i < n32 ; i += 32 ) {
368
+ x0 = vcvt_f32_f16 (vld1_f16 (x + i + 0 ));
369
+ x1 = vcvt_f32_f16 (vld1_f16 (x + i + 4 ));
370
+ x2 = vcvt_f32_f16 (vld1_f16 (x + i + 8 ));
371
+ x3 = vcvt_f32_f16 (vld1_f16 (x + i + 12 ));
372
+ x4 = vcvt_f32_f16 (vld1_f16 (x + i + 16 ));
373
+ x5 = vcvt_f32_f16 (vld1_f16 (x + i + 20 ));
374
+ x6 = vcvt_f32_f16 (vld1_f16 (x + i + 24 ));
375
+ x7 = vcvt_f32_f16 (vld1_f16 (x + i + 28 ));
376
+
377
+ y0 = vcvt_f32_f16 (vld1_f16 (y + i + 0 ));
378
+ y1 = vcvt_f32_f16 (vld1_f16 (y + i + 4 ));
379
+ y2 = vcvt_f32_f16 (vld1_f16 (y + i + 8 ));
380
+ y3 = vcvt_f32_f16 (vld1_f16 (y + i + 12 ));
381
+ y4 = vcvt_f32_f16 (vld1_f16 (y + i + 16 ));
382
+ y5 = vcvt_f32_f16 (vld1_f16 (y + i + 20 ));
383
+ y6 = vcvt_f32_f16 (vld1_f16 (y + i + 24 ));
384
+ y7 = vcvt_f32_f16 (vld1_f16 (y + i + 28 ));
385
+
386
+ sum0 = vfmaq_f32 (sum0 , x0 , y0 );
387
+ sum1 = vfmaq_f32 (sum1 , x1 , y1 );
388
+ sum2 = vfmaq_f32 (sum2 , x2 , y2 );
389
+ sum3 = vfmaq_f32 (sum3 , x3 , y3 );
390
+ sum4 = vfmaq_f32 (sum4 , x4 , y4 );
391
+ sum5 = vfmaq_f32 (sum5 , x5 , y5 );
392
+ sum6 = vfmaq_f32 (sum6 , x6 , y6 );
393
+ sum7 = vfmaq_f32 (sum7 , x7 , y7 );
394
+ }
395
+
396
+ // reduce sum0..sum7 to sum0
397
+ sum0 = vaddq_f32 (sum0 , sum1 );
398
+ sum2 = vaddq_f32 (sum2 , sum3 );
399
+ sum4 = vaddq_f32 (sum4 , sum5 );
400
+ sum6 = vaddq_f32 (sum6 , sum7 );
401
+ sum0 = vaddq_f32 (sum0 , sum2 );
402
+ sum4 = vaddq_f32 (sum4 , sum6 );
403
+ sum0 = vaddq_f32 (sum0 , sum4 );
404
+
405
+ // reduce sum0 to sumf
406
+ float32x2_t sumf32 = vadd_f32 (vget_low_f32 (sum0 ), vget_high_f32 (sum0 ));
407
+ sumf = vget_lane_f32 (sumf32 , 0 ) + vget_lane_f32 (sumf32 , 1 );
408
+ #endif
347
409
348
410
// leftovers
349
411
for (int i = n32 ; i < n ; ++ i ) {
@@ -486,6 +548,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
486
548
// NEON 128-bit
487
549
const int n32 = (n & ~31 );
488
550
551
+ #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC )
489
552
const float16x8_t v8 = vdupq_n_f16 (v );
490
553
491
554
float16x8_t x0 , x1 , x2 , x3 ;
@@ -512,6 +575,51 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
512
575
vst1q_f16 (y + i + 16 , y2 );
513
576
vst1q_f16 (y + i + 24 , y3 );
514
577
}
578
+ #else
579
+ const float32x4_t v40 = vdupq_n_f32 (v );
580
+ const float32x4_t v41 = vdupq_n_f32 (v );
581
+
582
+ float32x4_t x0 , x1 , x2 , x3 , x4 , x5 , x6 , x7 ;
583
+ float32x4_t y0 , y1 , y2 , y3 , y4 , y5 , y6 , y7 ;
584
+
585
+ for (int i = 0 ; i < n32 ; i += 32 ) {
586
+ y0 = vcvt_f32_f16 (vld1_f16 (y + i + 0 ));
587
+ y1 = vcvt_f32_f16 (vld1_f16 (y + i + 4 ));
588
+ y2 = vcvt_f32_f16 (vld1_f16 (y + i + 8 ));
589
+ y3 = vcvt_f32_f16 (vld1_f16 (y + i + 12 ));
590
+ y4 = vcvt_f32_f16 (vld1_f16 (y + i + 16 ));
591
+ y5 = vcvt_f32_f16 (vld1_f16 (y + i + 20 ));
592
+ y6 = vcvt_f32_f16 (vld1_f16 (y + i + 24 ));
593
+ y7 = vcvt_f32_f16 (vld1_f16 (y + i + 28 ));
594
+
595
+ x0 = vcvt_f32_f16 (vld1_f16 (x + i + 0 ));
596
+ x1 = vcvt_f32_f16 (vld1_f16 (x + i + 4 ));
597
+ x2 = vcvt_f32_f16 (vld1_f16 (x + i + 8 ));
598
+ x3 = vcvt_f32_f16 (vld1_f16 (x + i + 12 ));
599
+ x4 = vcvt_f32_f16 (vld1_f16 (x + i + 16 ));
600
+ x5 = vcvt_f32_f16 (vld1_f16 (x + i + 20 ));
601
+ x6 = vcvt_f32_f16 (vld1_f16 (x + i + 24 ));
602
+ x7 = vcvt_f32_f16 (vld1_f16 (x + i + 28 ));
603
+
604
+ y0 = vfmaq_f32 (y0 , x0 , v40 );
605
+ y1 = vfmaq_f32 (y1 , x1 , v40 );
606
+ y2 = vfmaq_f32 (y2 , x2 , v40 );
607
+ y3 = vfmaq_f32 (y3 , x3 , v40 );
608
+ y4 = vfmaq_f32 (y4 , x4 , v41 );
609
+ y5 = vfmaq_f32 (y5 , x5 , v41 );
610
+ y6 = vfmaq_f32 (y6 , x6 , v41 );
611
+ y7 = vfmaq_f32 (y7 , x7 , v41 );
612
+
613
+ vst1_f16 (y + i + 0 , vcvt_f16_f32 (y0 ));
614
+ vst1_f16 (y + i + 4 , vcvt_f16_f32 (y1 ));
615
+ vst1_f16 (y + i + 8 , vcvt_f16_f32 (y2 ));
616
+ vst1_f16 (y + i + 12 , vcvt_f16_f32 (y3 ));
617
+ vst1_f16 (y + i + 16 , vcvt_f16_f32 (y4 ));
618
+ vst1_f16 (y + i + 20 , vcvt_f16_f32 (y5 ));
619
+ vst1_f16 (y + i + 24 , vcvt_f16_f32 (y6 ));
620
+ vst1_f16 (y + i + 28 , vcvt_f16_f32 (y7 ));
621
+ }
622
+ #endif
515
623
516
624
// leftovers
517
625
for (int i = n32 ; i < n ; ++ i ) {
@@ -911,16 +1019,18 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
911
1019
if (is_first_call ) {
912
1020
const uint64_t t_start = ggml_time_us (); UNUSED (t_start );
913
1021
1022
+ ggml_fp16_t ii ;
914
1023
for (int i = 0 ; i < (1 << 16 ); ++ i ) {
915
- uint16_t ii = (uint16_t ) i ;
916
- const float f = ggml_fp16_to_fp32 (* (ggml_fp16_t * )(& ii ));
1024
+ uint16_t ui = i ;
1025
+ memcpy (& ii , & ui , sizeof (ii ));
1026
+ const float f = ggml_fp16_to_fp32 (ii );
917
1027
table_gelu_f16 [i ] = ggml_fp32_to_fp16 (ggml_gelu_f32 (f ));
918
1028
table_exp_f16 [i ] = ggml_fp32_to_fp16 (exp (f ));
919
1029
}
920
1030
921
1031
const uint64_t t_end = ggml_time_us (); UNUSED (t_end );
922
1032
923
- GGML_PRINT_DEBUG ("%s: GELU table initialized in %f ms\n" , __func__ , (t_end - t_start )/1000.0f );
1033
+ GGML_PRINT_DEBUG ("%s: GELU and EXP tables initialized in %f ms\n" , __func__ , (t_end - t_start )/1000.0f );
924
1034
925
1035
is_first_call = false;
926
1036
}
@@ -4427,13 +4537,15 @@ void ggml_compute_forward_soft_max_f32(
4427
4537
4428
4538
ggml_float sum = 0.0 ;
4429
4539
4540
+ uint16_t ss ;
4430
4541
for (int i = 0 ; i < nc ; i ++ ) {
4431
4542
if (p [i ] == - INFINITY ) {
4432
4543
p [i ] = 0.0 ;
4433
4544
} else {
4434
4545
//const float val = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max);
4435
4546
ggml_fp16_t s = ggml_fp32_to_fp16 (p [i ] - max );
4436
- const float val = ggml_fp16_to_fp32 (table_exp_f16 [* (uint16_t * ) & s ]);
4547
+ memcpy (& ss , & s , sizeof (ss ));
4548
+ const float val = ggml_fp16_to_fp32 (table_exp_f16 [ss ]);
4437
4549
sum += val ;
4438
4550
p [i ] = val ;
4439
4551
}
@@ -5234,13 +5346,15 @@ void ggml_compute_forward_flash_attn_f32(
5234
5346
5235
5347
ggml_float sum = 0.0 ;
5236
5348
5349
+ uint16_t ss ;
5237
5350
for (int i = 0 ; i < M ; i ++ ) {
5238
5351
if (S [i ] == - INFINITY ) {
5239
5352
S [i ] = 0.0 ;
5240
5353
} else {
5241
5354
//const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max);
5242
5355
ggml_fp16_t s = ggml_fp32_to_fp16 (S [i ] - max );
5243
- const float val = ggml_fp16_to_fp32 (table_exp_f16 [* (uint16_t * ) & s ]);
5356
+ memcpy (& ss , & s , sizeof (ss ));
5357
+ const float val = ggml_fp16_to_fp32 (table_exp_f16 [ss ]);
5244
5358
sum += val ;
5245
5359
S [i ] = val ;
5246
5360
}
@@ -5413,13 +5527,15 @@ void ggml_compute_forward_flash_attn_f16(
5413
5527
5414
5528
ggml_float sum = 0.0 ;
5415
5529
5530
+ uint16_t ss ;
5416
5531
for (int i = 0 ; i < M ; i ++ ) {
5417
5532
if (S [i ] == - INFINITY ) {
5418
5533
S [i ] = 0.0 ;
5419
5534
} else {
5420
5535
//const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max);
5421
5536
ggml_fp16_t s = ggml_fp32_to_fp16 (S [i ] - max );
5422
- const float val = ggml_fp16_to_fp32 (table_exp_f16 [* (uint16_t * ) & s ]);
5537
+ memcpy (& ss , & s , sizeof (ss ));
5538
+ const float val = ggml_fp16_to_fp32 (table_exp_f16 [ss ]);
5423
5539
sum += val ;
5424
5540
S [i ] = val ;
5425
5541
}
0 commit comments