@@ -1231,7 +1231,7 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
-        ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool());
+        ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
@@ -1241,7 +1241,7 @@ static void ggml_cuda_op_mul_mat_cublas(
         }
         const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();
 
-        ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool());
+        ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool(id));
         if (src1->type != GGML_TYPE_F16) {
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
@@ -1250,7 +1250,7 @@ static void ggml_cuda_op_mul_mat_cublas(
             to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
-        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(), row_diff*src1_ncols);
+        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
 
         const half alpha_f16 = 1.0f;
         const half beta_f16 = 0.0f;
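Note: the three hunks above make the same change. The fp16 staging buffers for the cuBLAS path are now taken from ctx.pool(id), the pool belonging to device id that executes this slice of the multiplication, instead of ctx.pool(), which resolves to the context's main device. A minimal sketch of the pattern, with invented names (pool_sketch and ctx_sketch are illustrations, not ggml's ggml_cuda_pool/ggml_backend_cuda_context):

// Illustrative only: one scratch pool per GPU, selected by device id, so that
// temporaries used by a cuBLAS call live on the device that issues the call.
#include <cuda_runtime.h>
#include <vector>

struct pool_sketch {
    int device = 0;
    void * alloc(size_t nbytes) {
        cudaSetDevice(device);        // allocations land on this pool's device
        void * ptr = nullptr;
        cudaMalloc(&ptr, nbytes);     // error checking omitted for brevity
        return ptr;
    }
};

struct ctx_sketch {
    int main_device = 0;
    std::vector<pool_sketch> pools;   // one pool per visible GPU
    pool_sketch & pool()       { return pools[main_device]; }  // what the old call sites used
    pool_sketch & pool(int id) { return pools[id];          }  // what the patch switches to
};

With more than one GPU, allocating the conversion buffers from the main device's pool while the cuBLAS call runs on another device puts the temporaries on the wrong GPU; passing id keeps them local to the device doing the work.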
@@ -1960,20 +1960,84 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     }
 }
 
+struct mmid_row_mapping {
+    int64_t i1;
+    int64_t i2;
+};
+
+static __global__ void k_copy_src1_to_contiguous(const char * src1_original, char * src1_contiguous,
+                                                 int * cur_src1_row, mmid_row_mapping * row_mapping,
+                                                 const char * ids_dev, int64_t i02, int64_t ids_nb1, int64_t ids_nb0,
+                                                 int64_t ids_ne1, int64_t n_ids,
+                                                 int64_t ne11,
+                                                 size_t nb11, size_t nb12) {
+    int64_t iid1 = blockIdx.x;
+    int64_t id = blockIdx.y;
+
+    if (iid1 >= ids_ne1 || id >= n_ids) {
+        return;
+    }
+
+    const int32_t row_id_i = *(const int32_t *) (ids_dev + iid1*ids_nb1 + id*ids_nb0);
+
+    if (row_id_i != i02) {
+        return;
+    }
+
+    const int64_t i11 = id % ne11;
+    const int64_t i12 = iid1;
+
+    __shared__ int src1_row;
+    if (threadIdx.x == 0) {
+        src1_row = atomicAdd(cur_src1_row, 1);
+        row_mapping[src1_row] = {id, iid1};
+    }
+    __syncthreads();
+
+    const char * src1_row_original = src1_original + i11*nb11 + i12*nb12;
+    char * src1_row_contiguous = src1_contiguous + src1_row*nb11;
+
+    for (int i = threadIdx.x; i < nb11; i += blockDim.x) {
+        src1_row_contiguous[i] = src1_row_original[i];
+    }
+}
+
+static __global__ void k_copy_dst_from_contiguous(char * dst_original, const char * dst_contiguous,
+                                                  const mmid_row_mapping * row_mapping,
+                                                  int64_t n_rows,
+                                                  int64_t nb1, int64_t nb2) {
+    int64_t i = blockIdx.x;
+
+    if (i >= n_rows) {
+        return;
+    }
+
+    const int64_t i1 = row_mapping[i].i1;
+    const int64_t i2 = row_mapping[i].i2;
+
+    const char * dst_row_contiguous = dst_contiguous + i*nb1;
+    char * dst_row_original = dst_original + i1*nb1 + i2*nb2;
+
+    for (int j = threadIdx.x; j < nb1; j += blockDim.x) {
+        dst_row_original[j] = dst_row_contiguous[j];
+    }
+}
+
+// #define MMID_MEMCPY
+
 static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * ids = dst->src[2];
 
+    GGML_TENSOR_BINARY_OP_LOCALS
+
     GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");
 
     cudaStream_t stream = ctx.stream();
 
-    const size_t nb11 = src1->nb[1];
-    const size_t nb1  = dst->nb[1];
-
-    const int32_t id = ((int32_t *) dst->op_params)[0];
-    const int32_t n_as = src0->ne[2];
+    const int64_t n_as = ne02;
+    const int64_t n_ids = ids->ne[0];
 
     std::vector<char> ids_host(ggml_nbytes(ids));
     const char * ids_dev = (const char *) ids->data;
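The new kernels above implement a device-side gather/scatter for MUL_MAT_ID. For each expert i02, every (token iid1, expert-slot id) pair whose routing entry in ids selects that expert contributes one src1 row: one CUDA block handles one pair, claims its destination row in the contiguous staging buffer with atomicAdd on cur_src1_row (so the packed order is nondeterministic, which is fine because row_mapping records where each result row scatters back into dst), and the block's threads then copy the row's nb11 bytes cooperatively. A host-side reference of the same gather, as a sketch of the index math only (it assumes ids is a contiguous int32 matrix, i.e. ids_nb0 == sizeof(int32_t) and ids_nb1 == n_ids*sizeof(int32_t); gather_src1_rows_ref and mmid_row_mapping_ref are names invented here):

#include <cstdint>
#include <cstring>

struct mmid_row_mapping_ref { int64_t i1; int64_t i2; };  // same layout as the struct added in the patch

static int64_t gather_src1_rows_ref(const char * src1, char * src1_contiguous,
                                    mmid_row_mapping_ref * row_mapping,
                                    const int32_t * ids, int64_t ids_ne1, int64_t n_ids,
                                    int64_t i02, int64_t ne11, size_t nb11, size_t nb12) {
    int64_t num_rows = 0;
    for (int64_t iid1 = 0; iid1 < ids_ne1; iid1++) {       // token
        for (int64_t id = 0; id < n_ids; id++) {           // expert slot for that token
            if (ids[iid1*n_ids + id] != i02) {
                continue;                                  // this slot picked a different expert
            }
            const int64_t i11 = id % ne11;                 // src1 row (broadcast when ne11 < n_ids)
            const int64_t i12 = iid1;                      // src1 batch index
            memcpy(src1_contiguous + num_rows*nb11, src1 + i11*nb11 + i12*nb12, nb11);
            row_mapping[num_rows] = {id, iid1};            // where the matching dst row scatters back
            num_rows++;
        }
    }
    return num_rows;
}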
@@ -1982,27 +2046,47 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
 
     ggml_tensor src0_row = *src0;
     ggml_tensor src1_row = *src1;
-    ggml_tensor dst_row = *dst;
+    ggml_tensor dst_row  = *dst;
 
     char * src0_original = (char *) src0->data;
     char * src1_original = (char *) src1->data;
     char * dst_original = (char *) dst->data;
 
     src0_row.ne[2] = 1;
     src0_row.ne[3] = 1;
-    src0_row.nb[3] = src0->nb[2];
+    src0_row.nb[3] = nb02;
 
-    if (src1->ne[1] == 1) {
-        for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-            const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+    src1_row.ne[1] = 1;
+    src1_row.ne[2] = 1;
+    src1_row.ne[3] = 1;
+    src1_row.nb[2] = nb11;
+    src1_row.nb[3] = nb11;
 
-            GGML_ASSERT(row_id >= 0 && row_id < n_as);
+    dst_row.ne[1] = 1;
+    dst_row.ne[2] = 1;
+    dst_row.ne[3] = 1;
+    dst_row.nb[2] = nb1;
+    dst_row.nb[3] = nb1;
 
-            src0_row.data = src0_original + row_id*src0->nb[2];
-            src1_row.data = src1_original + i01*src1->nb[1];
-            dst_row.data = dst_original + i01*dst->nb[1];
+    if (ne12 == 1) {
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+            for (int64_t id = 0; id < n_ids; id++) {
+                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
 
-            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+                GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+                const int64_t i11 = id % ne11;
+                const int64_t i12 = iid1;
+
+                const int64_t i1 = id;
+                const int64_t i2 = i12;
+
+                src0_row.data = src0_original + i02*nb02;
+                src1_row.data = src1_original + i11*nb11 + i12*nb12;
+                dst_row.data = dst_original + i1*nb1 + i2*nb2;
+
+                ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+            }
         }
     } else {
         ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
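In the single-token branch (ne12 == 1) of the hunk above, the old code handled one expert slot per op, selected by the id read from dst->op_params; the reworked op covers all n_ids slots at once, so the loop walks every (token iid1, expert-slot id) pair and issues one small ggml_cuda_mul_mat per pair, pointing the *_row views at the chosen expert matrix (i02*nb02), the input row (i11*nb11 + i12*nb12) and the output row (i1*nb1 + i2*nb2). The strides nb02, nb11, nb12, nb1, nb2 come from GGML_TENSOR_BINARY_OP_LOCALS. A tiny standalone walk-through of the index mapping with made-up shapes (2 experts used per token, 3 tokens, ne11 == 1; the numbers in ids are invented):

#include <cstdio>

int main() {
    const long n_ids = 2, ids_ne1 = 3, ne11 = 1;
    const int ids[3][2] = { {0, 2}, {1, 0}, {2, 1} };   // expert picked per (token, slot)

    for (long iid1 = 0; iid1 < ids_ne1; ++iid1) {       // token
        for (long id = 0; id < n_ids; ++id) {           // expert slot
            const int  i02 = ids[iid1][id];             // which expert matrix of src0
            const long i11 = id % ne11, i12 = iid1;     // src1 row that feeds it (broadcast: ne11 == 1)
            const long i1  = id,        i2  = iid1;     // dst row that receives the result
            printf("token %ld slot %ld -> expert %d, src1[%ld,%ld], dst[%ld,%ld]\n",
                   iid1, id, i02, i11, i12, i1, i2);
        }
    }
}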
@@ -2011,55 +2095,104 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         src1_row.data = src1_contiguous.get();
         dst_row.data = dst_contiguous.get();
 
-        for (int32_t row_id = 0; row_id < n_as; ++row_id) {
+        for (int64_t i02 = 0; i02 < n_as; i02++) {
             int64_t num_src1_rows = 0;
-            for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
 
-                if (row_id_i != row_id) {
-                    continue;
-                }
+            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+                for (int64_t id = 0; id < n_ids; id++) {
+                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+
+                    if (row_id_i != i02) {
+                        continue;
+                    }
 
-                GGML_ASSERT(row_id >= 0 && row_id < n_as);
+                    GGML_ASSERT(i02 >= 0 && i02 < n_as);
 
-                CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, src1_original + i01*nb11,
-                                        nb11, cudaMemcpyDeviceToDevice, stream));
-                num_src1_rows++;
+#ifdef MMID_MEMCPY
+                    const int64_t i11 = id % ne11;
+                    const int64_t i12 = iid1;
+                    CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11,
+                                               src1_original + i11*nb11 + i12*nb12,
+                                               nb11, cudaMemcpyDeviceToDevice, stream));
+#endif
+                    num_src1_rows++;
+                }
             }
 
             if (num_src1_rows == 0) {
                 continue;
             }
 
-            src0_row.data = src0_original + row_id*src0->nb[2];
+#ifndef MMID_MEMCPY
+            ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
+            ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
+            CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
 
-            src1_row.ne[1] = num_src1_rows;
-            dst_row.ne[1] = num_src1_rows;
+            {
+                dim3 block_dims(std::min((uint)nb11, 1024u));
+                dim3 grid_dims(ids->ne[1], n_ids);
+                k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                        src1_original, src1_contiguous.get(),
+                        dev_cur_src1_row.get(), dev_row_mapping.get(),
+                        ids_dev, i02, ids->nb[1], ids->nb[0],
+                        ids->ne[1], n_ids,
+                        ne11,
+                        nb11, nb12);
+                CUDA_CHECK(cudaGetLastError());
+            }
+#endif
+
+            src0_row.data = src0_original + i02*nb02;
 
+            GGML_ASSERT(nb11 == sizeof(float)*ne10);
+            GGML_ASSERT(nb1 == sizeof(float)*ne0);
+
+            src1_row.ne[1] = num_src1_rows;
             src1_row.nb[1] = nb11;
             src1_row.nb[2] = num_src1_rows*nb11;
             src1_row.nb[3] = num_src1_rows*nb11;
 
+            dst_row.ne[1] = num_src1_rows;
             dst_row.nb[1] = nb1;
             dst_row.nb[2] = num_src1_rows*nb1;
             dst_row.nb[3] = num_src1_rows*nb1;
 
             ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
 
+#ifndef MMID_MEMCPY
+            {
+                dim3 block_dims(std::min((uint)nb1, 1024u));
+                dim3 grid_dims(num_src1_rows);
+                k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                        dst_original, dst_contiguous.get(),
+                        dev_row_mapping.get(),
+                        num_src1_rows, nb1, nb2);
+                CUDA_CHECK(cudaGetLastError());
+            }
+#endif
+
+#ifdef MMID_MEMCPY
             num_src1_rows = 0;
-            for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+                for (int64_t id = 0; id < n_ids; id++) {
+                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
 
-                if (row_id_i != row_id) {
-                    continue;
-                }
+                    if (row_id_i != i02) {
+                        continue;
+                    }
 
-                GGML_ASSERT(row_id >= 0 && row_id < n_as);
+                    GGML_ASSERT(i02 >= 0 && i02 < n_as);
 
-                CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1,
-                                        nb1, cudaMemcpyDeviceToDevice, stream));
-                num_src1_rows++;
+                    const int64_t i1 = id;
+                    const int64_t i2 = iid1;
+
+                    CUDA_CHECK(cudaMemcpyAsync(dst_original + i1*nb1 + i2*nb2,
+                                               dst_contiguous.get() + num_src1_rows*nb1,
+                                               nb1, cudaMemcpyDeviceToDevice, stream));
+                    num_src1_rows++;
+                }
             }
+#endif
         }
     }
 }
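For batched tokens the path in the hunk above is: gather the matching src1 rows into src1_contiguous with k_copy_src1_to_contiguous (grid = (ids->ne[1], n_ids), one block per candidate pair; block size = min(nb11, 1024) threads, each copying a strided share of the row's bytes), run one batched ggml_cuda_mul_mat per expert over the packed rows, then scatter the results back with k_copy_dst_from_contiguous using the recorded row_mapping. The MMID_MEMCPY define keeps the previous behaviour of one cudaMemcpyAsync per row, presumably for comparison; it is left commented out so the kernel path is the default. The slot-claiming trick (one block per candidate, atomicAdd on a global counter, the slot cached in shared memory for the whole block) is ordinary stream compaction; a standalone toy version of just that pattern, with invented names (k_compact, keep_if) and error checking omitted:

#include <cuda_runtime.h>
#include <cstdio>

static __global__ void k_compact(const int * in, int * out, int * counter, int keep_if, int n) {
    const int i = blockIdx.x;
    if (i >= n || in[i] != keep_if) {
        return;                               // this block contributes nothing
    }
    __shared__ int slot;
    if (threadIdx.x == 0) {
        slot = atomicAdd(counter, 1);         // claim the next free output row
    }
    __syncthreads();                          // all threads of the block see the claimed slot
    if (threadIdx.x == 0) {
        out[slot] = i;                        // record which input landed in this slot
    }
}

int main() {
    const int n = 8;
    const int h_in[n] = {2, 0, 2, 1, 2, 2, 0, 2};
    int *d_in, *d_out, *d_cnt;
    cudaMalloc((void **) &d_in, n*sizeof(int));
    cudaMalloc((void **) &d_out, n*sizeof(int));
    cudaMalloc((void **) &d_cnt, sizeof(int));
    cudaMemcpy(d_in, h_in, n*sizeof(int), cudaMemcpyHostToDevice);
    cudaMemset(d_cnt, 0, sizeof(int));
    k_compact<<<n, 32>>>(d_in, d_out, d_cnt, /*keep_if=*/2, n);
    int count = 0;
    cudaMemcpy(&count, d_cnt, sizeof(int), cudaMemcpyDeviceToHost);
    printf("%d rows selected (packed order is not deterministic)\n", count);
    cudaFree(d_in); cudaFree(d_out); cudaFree(d_cnt);
}

Because blocks may claim slots in any order, the packed buffer is unordered; the real kernels tolerate this by storing, alongside each packed row, the (i1, i2) position it must scatter back to.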
@@ -2487,7 +2620,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
 GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
     const int min_batch_size = 32;
 
-    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
+    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
+           (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
 
     GGML_UNUSED(backend);
 }
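The offload heuristic previously looked only at op->ne[1]. A MUL_MAT_ID produced by the single-op MoE path keeps the token batch in ne[2] (ne[1] is the small number of experts used per token), so under the old check it would rarely look large enough to be worth offloading. The amended condition can be read as the disjunction of two batch tests; a restatement for clarity only, with a hypothetical helper name (worth_offloading), not a drop-in for the backend code:

#include "ggml.h"

static bool worth_offloading(const struct ggml_tensor * op) {
    const int min_batch_size = 32;
    const bool batched_rows    = op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
    const bool batched_experts = op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID;
    return batched_rows || batched_experts;
}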