@@ -175,6 +175,30 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
175
175
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
176
176
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
177
177
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
178
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
179
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
180
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
181
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5,
182
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2,
183
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3,
184
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4,
185
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5,
186
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2,
187
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3,
188
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4,
189
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5,
190
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2,
191
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3,
192
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4,
193
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5,
194
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2,
195
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3,
196
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4,
197
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5,
198
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2,
199
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
200
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
201
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
178
202
GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
179
203
GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
180
204
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
@@ -699,6 +723,30 @@ @implementation GGMLMetalClass
699
723
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
700
724
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
701
725
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
726
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
727
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
728
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
729
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction);
730
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction);
731
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction);
732
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction);
733
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction);
734
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction);
735
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction);
736
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction);
737
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction);
738
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction);
739
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction);
740
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction);
741
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction);
742
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction);
743
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction);
744
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction);
745
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction);
746
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction);
747
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
748
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
749
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
702
750
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
703
751
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
704
752
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
@@ -1930,28 +1978,128 @@ static void ggml_metal_encode_node(
1930
1978
// to the matrix-vector kernel
1931
1979
int ne11_mm_min = 4 ;
1932
1980
1933
- #if 0
1934
- // the numbers below are measured on M2 Ultra for 7B and 13B models
1935
- // these numbers do not translate to other devices or model sizes
1936
- // TODO: need to find a better approach
1937
- if ([device.name isEqualToString:@"Apple M2 Ultra"]) {
1938
- switch (src0t) {
1939
- case GGML_TYPE_F16: ne11_mm_min = 2; break;
1940
- case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1941
- case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1942
- case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1943
- case GGML_TYPE_Q4_0:
1944
- case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1945
- case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1946
- case GGML_TYPE_Q5_0: // not tested yet
1947
- case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1948
- case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1949
- case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1950
- default: ne11_mm_min = 1; break;
1951
- }
1952
- }
1953
- #endif
1981
+ if ((src0t == GGML_TYPE_F16 || // TODO: helper function
1982
+ src0t == GGML_TYPE_Q4_0 ||
1983
+ src0t == GGML_TYPE_Q4_1 ||
1984
+ src0t == GGML_TYPE_Q5_0 ||
1985
+ src0t == GGML_TYPE_Q5_1 ||
1986
+ src0t == GGML_TYPE_Q8_0
1987
+ ) &&
1988
+ src1t == GGML_TYPE_F32 &&
1989
+ (ne00%256 == 0 ) && // TODO: this can be relaxed to 128 for nxpsg == 8
1990
+ (ne11 >= 2 && ne11 <= 8 )) {
1991
+
1992
+ // TODO: determine the optimal parameters based on grid utilization
1993
+ const int nsg = 2 ; // TODO: or 4?
1994
+ const int nxpsg = ne11 < 3 ? 16 : 8 ;
1995
+ const int nypsg = 32 /nxpsg;
1996
+ const int r0ptg = nypsg*nsg;
1997
+ int r1ptg = 4 ;
1998
+
1999
+ switch (ne11) {
2000
+ case 2 :
2001
+ r1ptg = 2 ; break ;
2002
+ case 3 :
2003
+ case 6 :
2004
+ r1ptg = 3 ; break ;
2005
+ case 4 :
2006
+ case 7 :
2007
+ case 8 :
2008
+ r1ptg = 4 ; break ;
2009
+ case 5 :
2010
+ r1ptg = 5 ; break ;
2011
+ };
2012
+
2013
+ assert (nxpsg >= 8 );
2014
+ assert (nxpsg%8 == 0 );
2015
+
2016
+ id <MTLComputePipelineState > pipeline = nil ;
2017
+
2018
+ switch (src0->type ) {
2019
+ case GGML_TYPE_F16:
2020
+ switch (r1ptg) {
2021
+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline ; break ;
2022
+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3].pipeline ; break ;
2023
+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4].pipeline ; break ;
2024
+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5].pipeline ; break ;
2025
+ default : GGML_ABORT (" not implemented" );
2026
+ } break ;
2027
+ case GGML_TYPE_Q4_0:
2028
+ switch (r1ptg) {
2029
+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2].pipeline ; break ;
2030
+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3].pipeline ; break ;
2031
+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4].pipeline ; break ;
2032
+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5].pipeline ; break ;
2033
+ default : GGML_ABORT (" not implemented" );
2034
+ } break ;
2035
+ case GGML_TYPE_Q4_1:
2036
+ switch (r1ptg) {
2037
+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2].pipeline ; break ;
2038
+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3].pipeline ; break ;
2039
+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4].pipeline ; break ;
2040
+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5].pipeline ; break ;
2041
+ default : GGML_ABORT (" not implemented" );
2042
+ } break ;
2043
+ case GGML_TYPE_Q5_0:
2044
+ switch (r1ptg) {
2045
+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2].pipeline ; break ;
2046
+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3].pipeline ; break ;
2047
+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4].pipeline ; break ;
2048
+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5].pipeline ; break ;
2049
+ default : GGML_ABORT (" not implemented" );
2050
+ } break ;
2051
+ case GGML_TYPE_Q5_1:
2052
+ switch (r1ptg) {
2053
+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2].pipeline ; break ;
2054
+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3].pipeline ; break ;
2055
+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4].pipeline ; break ;
2056
+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5].pipeline ; break ;
2057
+ default : GGML_ABORT (" not implemented" );
2058
+ } break ;
2059
+ case GGML_TYPE_Q8_0:
2060
+ switch (r1ptg) {
2061
+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2].pipeline ; break ;
2062
+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3].pipeline ; break ;
2063
+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4].pipeline ; break ;
2064
+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline ; break ;
2065
+ default : GGML_ABORT (" not implemented" );
2066
+ } break ;
2067
+ default : GGML_ABORT (" not implemented" );
2068
+ }
2069
+
2070
+ ggml_metal_kargs_mul_mv_ext args = {
2071
+ /* .ne00 =*/ ne00,
2072
+ /* .ne01 =*/ ne01,
2073
+ /* .ne02 =*/ ne02,
2074
+ /* .nb00 =*/ nb00,
2075
+ /* .nb01 =*/ nb01,
2076
+ /* .nb02 =*/ nb02,
2077
+ /* .nb03 =*/ nb03,
2078
+ /* .ne10 =*/ ne10,
2079
+ /* .ne11 =*/ ne11,
2080
+ /* .ne12 =*/ ne12,
2081
+ /* .nb10 =*/ nb10,
2082
+ /* .nb11 =*/ nb11,
2083
+ /* .nb12 =*/ nb12,
2084
+ /* .nb13 =*/ nb13,
2085
+ /* .ne0 =*/ ne0,
2086
+ /* .ne1 =*/ ne1,
2087
+ /* .r2 =*/ r2,
2088
+ /* .r3 =*/ r3,
2089
+ /* .nsg =*/ nsg,
2090
+ /* .nxpsg =*/ nxpsg,
2091
+ /* .r1ptg =*/ r1ptg,
2092
+ };
2093
+
2094
+ [encoder setComputePipelineState: pipeline];
2095
+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
2096
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
2097
+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
2098
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
1954
2099
2100
+ // printf("ne01 = %lld r0ptg = %d\n", ne01, r0ptg);
2101
+ [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + r0ptg - 1 )/r0ptg, (ne11 + r1ptg - 1 )/r1ptg, ne12*ne13) threadsPerThreadgroup: MTLSizeMake (32 , nsg, 1 )];
2102
+ } else
1955
2103
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1956
2104
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1957
2105
if ([device supportsFamily: MTLGPUFamilyApple7] &&
0 commit comments