@@ -1657,6 +1657,10 @@ void ggml_metal_graph_compute(
1657
1657
}
1658
1658
};
1659
1659
1660
+ if (ggml_is_quantized (src0t)) {
1661
+ GGML_ASSERT (ne00 >= nth0*nth1);
1662
+ }
1663
+
1660
1664
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1661
1665
[encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1662
1666
[encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
@@ -1715,6 +1719,9 @@ void ggml_metal_graph_compute(
1715
1719
// TODO: make this more general
1716
1720
GGML_ASSERT (n_as <= 8 );
1717
1721
1722
+ // max size of the src1ids array in the kernel stack
1723
+ GGML_ASSERT (ne11 <= 512 );
1724
+
1718
1725
struct ggml_tensor * src2 = gf->nodes [i]->src [2 ];
1719
1726
1720
1727
const int64_t ne20 = src2 ? src2->ne [0 ] : 0 ;
@@ -1732,32 +1739,29 @@ void ggml_metal_graph_compute(
1732
1739
GGML_ASSERT (!ggml_is_transposed (src2));
1733
1740
GGML_ASSERT (!ggml_is_transposed (src1));
1734
1741
1735
- GGML_ASSERT (ne20 % 32 == 0 );
1736
- // !!!!!!!!! TODO: this assert is probably required but not sure!
1737
- // GGML_ASSERT(ne20 >= 64);
1738
1742
GGML_ASSERT (src1t == GGML_TYPE_F32);
1739
1743
1740
1744
const uint r2 = ne12/ne22;
1741
1745
const uint r3 = ne13/ne23;
1742
1746
1743
1747
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
1744
1748
// to the matrix-vector kernel
1745
- int ne11_mm_min = 1 ;
1749
+ int ne11_mm_min = n_as ;
1746
1750
1747
1751
const int idx = ((int32_t *) dst->op_params )[0 ];
1748
1752
1749
1753
// batch size
1750
1754
GGML_ASSERT (ne01 == ne11);
1751
1755
1752
- const int64_t _ne1 = 1 ; // kernel_mul_mm_impl needs a reference in constant memory
1753
-
1754
1756
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1755
1757
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1756
1758
// !!!
1757
1759
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
1758
1760
// indirect matrix multiplication
1759
1761
// !!!
1760
- if ([ctx->device supportsFamily: MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
1762
+ if ([ctx->device supportsFamily: MTLGPUFamilyApple7] &&
1763
+ ne20 % 32 == 0 && ne20 >= 64 &&
1764
+ ne11 > ne11_mm_min) {
1761
1765
switch (src2->type ) {
1762
1766
case GGML_TYPE_F32: [encoder setComputePipelineState: ctx->pipeline_mul_mm_id_f32_f32]; break ;
1763
1767
case GGML_TYPE_F16: [encoder setComputePipelineState: ctx->pipeline_mul_mm_id_f16_f32]; break ;
@@ -1787,7 +1791,7 @@ void ggml_metal_graph_compute(
1787
1791
[encoder setBytes: &nb11 length: sizeof (nb11) atIndex: 11 ];
1788
1792
[encoder setBytes: &nb12 length: sizeof (nb12) atIndex: 12 ];
1789
1793
[encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 13 ];
1790
- [encoder setBytes: &_ne1 length: sizeof (_ne1) atIndex: 14 ];
1794
+ [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 14 ];
1791
1795
[encoder setBytes: &nb1 length: sizeof (nb1) atIndex: 15 ];
1792
1796
[encoder setBytes: &r2 length: sizeof (r2) atIndex: 16 ];
1793
1797
[encoder setBytes: &r3 length: sizeof (r3) atIndex: 17 ];
@@ -1805,8 +1809,7 @@ void ggml_metal_graph_compute(
1805
1809
1806
1810
[encoder setThreadgroupMemoryLength: 8192 atIndex: 0 ];
1807
1811
1808
- // TODO: processing one row at a time (ne11 -> 1) is not efficient
1809
- [encoder dispatchThreadgroups: MTLSizeMake ( (_ne1 + 31 )/32 , (ne21 + 63 )/64 , ne01*ne12*ne13) threadsPerThreadgroup: MTLSizeMake (128 , 1 , 1 )];
1812
+ [encoder dispatchThreadgroups: MTLSizeMake ((ne11 + 31 )/32 , (ne21 + 63 )/64 , n_as*ne12*ne13) threadsPerThreadgroup: MTLSizeMake (128 , 1 , 1 )];
1810
1813
} else {
1811
1814
int nth0 = 32 ;
1812
1815
int nth1 = 1 ;
@@ -1889,11 +1892,17 @@ void ggml_metal_graph_compute(
1889
1892
} break ;
1890
1893
default :
1891
1894
{
1892
- GGML_METAL_LOG_ERROR (" Asserting on type %d \n " , (int )src0t );
1895
+ GGML_METAL_LOG_ERROR (" Asserting on type %d \n " , (int )src2t );
1893
1896
GGML_ASSERT (false && " not implemented" );
1894
1897
}
1895
1898
};
1896
1899
1900
+ if (ggml_is_quantized (src2t)) {
1901
+ GGML_ASSERT (ne20 >= nth0*nth1);
1902
+ }
1903
+
1904
+ const int64_t _ne1 = 1 ; // kernels needs a reference in constant memory
1905
+
1897
1906
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1898
1907
[encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1899
1908
[encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
0 commit comments