@@ -2597,91 +2597,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
2597
2597
template [[host_name(" kernel_flash_attn_ext_vec_f16_h128" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128 >;
2598
2598
// template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
2599
2599
2600
- kernel void kernel_cpy_f16_f16 (
2601
- device const half * src0,
2602
- device half * dst,
2603
- constant int64_t & ne00,
2604
- constant int64_t & ne01,
2605
- constant int64_t & ne02,
2606
- constant int64_t & ne03,
2607
- constant uint64_t & nb00,
2608
- constant uint64_t & nb01,
2609
- constant uint64_t & nb02,
2610
- constant uint64_t & nb03,
2611
- constant int64_t & ne0,
2612
- constant int64_t & ne1,
2613
- constant int64_t & ne2,
2614
- constant int64_t & ne3,
2615
- constant uint64_t & nb0,
2616
- constant uint64_t & nb1,
2617
- constant uint64_t & nb2,
2618
- constant uint64_t & nb3,
2619
- uint3 tgpig[[threadgroup_position_in_grid]],
2620
- uint3 tpitg[[thread_position_in_threadgroup]],
2621
- uint3 ntg[[threads_per_threadgroup]]) {
2622
- const int64_t i03 = tgpig[2 ];
2623
- const int64_t i02 = tgpig[1 ];
2624
- const int64_t i01 = tgpig[0 ];
2625
-
2626
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2627
-
2628
- const int64_t i3 = n / (ne2*ne1*ne0);
2629
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2630
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2631
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2632
-
2633
- device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2634
-
2635
- for (int64_t i00 = tpitg.x ; i00 < ne00; i00 += ntg.x ) {
2636
- device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2637
- dst_data[i00] = src[0 ];
2638
- }
2639
- }
2640
-
2641
- kernel void kernel_cpy_f16_f32 (
2642
- device const half * src0,
2643
- device float * dst,
2644
- constant int64_t & ne00,
2645
- constant int64_t & ne01,
2646
- constant int64_t & ne02,
2647
- constant int64_t & ne03,
2648
- constant uint64_t & nb00,
2649
- constant uint64_t & nb01,
2650
- constant uint64_t & nb02,
2651
- constant uint64_t & nb03,
2652
- constant int64_t & ne0,
2653
- constant int64_t & ne1,
2654
- constant int64_t & ne2,
2655
- constant int64_t & ne3,
2656
- constant uint64_t & nb0,
2657
- constant uint64_t & nb1,
2658
- constant uint64_t & nb2,
2659
- constant uint64_t & nb3,
2660
- uint3 tgpig[[threadgroup_position_in_grid]],
2661
- uint3 tpitg[[thread_position_in_threadgroup]],
2662
- uint3 ntg[[threads_per_threadgroup]]) {
2663
- const int64_t i03 = tgpig[2 ];
2664
- const int64_t i02 = tgpig[1 ];
2665
- const int64_t i01 = tgpig[0 ];
2666
-
2667
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2668
-
2669
- const int64_t i3 = n / (ne2*ne1*ne0);
2670
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2671
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2672
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2673
-
2674
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2675
-
2676
- for (int64_t i00 = tpitg.x ; i00 < ne00; i00 += ntg.x ) {
2677
- device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2678
- dst_data[i00] = src[0 ];
2679
- }
2680
- }
2681
-
2682
- kernel void kernel_cpy_f32_f16 (
2683
- device const float * src0,
2684
- device half * dst,
2600
+ template <typename T0, typename T1>
2601
+ kernel void kernel_cpy (
2602
+ device const void * src0,
2603
+ device void * dst,
2685
2604
constant int64_t & ne00,
2686
2605
constant int64_t & ne01,
2687
2606
constant int64_t & ne02,
@@ -2712,56 +2631,22 @@ kernel void kernel_cpy_f32_f16(
2712
2631
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2713
2632
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2714
2633
2715
- device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2634
+ device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2716
2635
2717
2636
for (int64_t i00 = tpitg.x ; i00 < ne00; i00 += ntg.x ) {
2718
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2719
-
2720
- dst_data[i00] = src[0 ];
2637
+ device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2638
+ dst_data[i00] = (T1) src[0 ];
2721
2639
}
2722
2640
}
2723
2641
2724
- kernel void kernel_cpy_f32_f32 (
2725
- device const float * src0,
2726
- device float * dst,
2727
- constant int64_t & ne00,
2728
- constant int64_t & ne01,
2729
- constant int64_t & ne02,
2730
- constant int64_t & ne03,
2731
- constant uint64_t & nb00,
2732
- constant uint64_t & nb01,
2733
- constant uint64_t & nb02,
2734
- constant uint64_t & nb03,
2735
- constant int64_t & ne0,
2736
- constant int64_t & ne1,
2737
- constant int64_t & ne2,
2738
- constant int64_t & ne3,
2739
- constant uint64_t & nb0,
2740
- constant uint64_t & nb1,
2741
- constant uint64_t & nb2,
2742
- constant uint64_t & nb3,
2743
- uint3 tgpig[[threadgroup_position_in_grid]],
2744
- uint3 tpitg[[thread_position_in_threadgroup]],
2745
- uint3 ntg[[threads_per_threadgroup]]) {
2746
- const int64_t i03 = tgpig[2 ];
2747
- const int64_t i02 = tgpig[1 ];
2748
- const int64_t i01 = tgpig[0 ];
2749
-
2750
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2751
-
2752
- const int64_t i3 = n / (ne2*ne1*ne0);
2753
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2754
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2755
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2756
-
2757
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2642
+ typedef decltype (kernel_cpy<float , float >) kernel_cpy_t;
2758
2643
2759
- for ( int64_t i00 = tpitg. x ; i00 < ne00; i00 += ntg. x ) {
2760
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00) ;
2761
-
2762
- dst_data[i00] = src[ 0 ] ;
2763
- }
2764
- }
2644
+ template [[host_name( " kernel_cpy_f32_f32 " )]] kernel kernel_cpy_t kernel_cpy< float , float >;
2645
+ template [[host_name( " kernel_cpy_f32_bf16 " )]] kernel kernel_cpy_t kernel_cpy< float , bfloat> ;
2646
+ template [[host_name( " kernel_cpy_f32_f16 " )]] kernel kernel_cpy_t kernel_cpy< float , half>;
2647
+ template [[host_name( " kernel_cpy_bf16_f32 " )]] kernel kernel_cpy_t kernel_cpy<bfloat, float > ;
2648
+ template [[host_name( " kernel_cpy_f16_f16 " )]] kernel kernel_cpy_t kernel_cpy<half, half>;
2649
+ template [[host_name( " kernel_cpy_f16_f32 " )]] kernel kernel_cpy_t kernel_cpy<half, float >;
2765
2650
2766
2651
kernel void kernel_cpy_f32_q8_0 (
2767
2652
device const float * src0,
0 commit comments