Commit fb76ec3

ggml : fix YARN + add tests + add asserts (#7617)
* tests : add rope tests ggml-ci
* ggml : fixes (hopefully) ggml-ci
* tests : add non-cont tests ggml-ci
* cuda : add asserts for rope/norm + fix DS2 ggml-ci
* ggml : assert contiguousness
* tests : reduce RoPE tests ggml-ci
1 parent cce3dcf commit fb76ec3

12 files changed: +168 / -105 lines

ggml-cuda.cu

Lines changed: 3 additions & 1 deletion
@@ -1870,7 +1870,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         }
     }
 #else
-    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
         CUBLAS_CHECK(

@@ -2886,7 +2886,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
+            return true;
         case GGML_OP_ROPE:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
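
The new predicate replaces the inline stride comparison; judging by the ggml_is_contiguous_2() definition added in ggml.c further down, it is the same dim-2/3 condition plus a requirement that the elements of a row are packed. A minimal sketch of the check (not part of the commit, names as in ggml.h):

    #include "ggml.h"

    // contiguity across dims 2 and 3: row elements packed, dims 1 and 2 may have arbitrary strides
    static bool is_contiguous_2(const struct ggml_tensor * t) {
        return t->nb[0] == ggml_type_size(t->type) &&   // elements within a row are packed
               t->nb[3] == t->nb[2]*t->ne[2];           // dim-3 stride follows dim 2, no gaps between slices
    }
    // previously checked inline:
    //   src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]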

ggml-cuda/norm.cu

Lines changed: 6 additions & 0 deletions
@@ -170,6 +170,8 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();

+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

@@ -188,6 +190,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();

+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

@@ -202,6 +206,8 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();

+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

ggml-cuda/rope.cu

Lines changed: 8 additions & 10 deletions
@@ -61,7 +61,7 @@ static __global__ void rope(
 template<typename T, bool has_pos, bool has_freq_facs>
 static __global__ void rope_neox(
     const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors
 ) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

@@ -85,15 +85,13 @@ static __global__ void rope_neox(
     const int i  = row*ncols + ib*n_dims + ic/2;
     const int i2 = row/p_delta_rows;

-    float cur_rot = inv_ndims * ic - ib;
-
     const int p = has_pos ? pos[i2] : 0;
     const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;

-    const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;
+    const float theta_base = p*powf(theta_scale, col/2.0f)/freq_factor;

     float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn(theta_base, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);

     const float x0 = x[i + 0];
     const float x1 = x[i + n_dims/2];

@@ -174,30 +172,29 @@ static void rope_neox_cuda(
     const dim3 block_nums(nrows, num_blocks_x, 1);

     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float inv_ndims = -1.0f / n_dims;

     if (pos == nullptr) {
         if (freq_factors == nullptr) {
             rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
                 x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, inv_ndims, freq_factors
+                theta_scale, freq_factors
             );
         } else {
             rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
                 x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, inv_ndims, freq_factors
+                theta_scale, freq_factors
             );
         }
     } else {
         if (freq_factors == nullptr) {
             rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
                 x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, inv_ndims, freq_factors
+                theta_scale, freq_factors
             );
         } else {
             rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
                 x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, inv_ndims, freq_factors
+                theta_scale, freq_factors
             );
         }
     }

@@ -254,6 +251,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();

+    GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
     GGML_ASSERT(src0->type == dst->type);
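
For reference, the base angle the fixed NeoX kernel computes per rotated pair is p * theta_scale^(ic/2) / freq_factor, with theta_scale = freq_base^(-2/n_dims). freq_scale is no longer folded into it, since rope_yarn() receives freq_scale separately, and the YARN ramp index passed to rope_yarn() is now the plain dimension index ic. A minimal host-side sketch of the base-angle computation (not the kernel itself; parameter names follow the code above):

    #include <math.h>

    // base RoPE angles for one token at position p (NeoX ordering), before the
    // YARN correction that rope_yarn() applies using freq_scale and corr_dims
    static void rope_neox_base_angles(int p, int n_dims, float freq_base,
                                      const float * freq_factors,   // may be NULL
                                      float * theta /* n_dims/2 entries */) {
        const float theta_scale = powf(freq_base, -2.0f/n_dims);
        for (int ic = 0; ic < n_dims; ic += 2) {
            const float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
            theta[ic/2] = p*powf(theta_scale, ic/2) / freq_factor;
        }
    }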

ggml-kompute.cpp

Lines changed: 3 additions & 1 deletion
@@ -1597,7 +1597,9 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                 {
                     GGML_ASSERT(ne00 == ne10);

-                    // TODO: assert that dim2 and dim3 are contiguous
+                    ggml_is_contiguous_2(src0);
+                    ggml_is_contiguous_2(src1);
+
                     GGML_ASSERT(ne12 % ne02 == 0);
                     GGML_ASSERT(ne13 % ne03 == 0);

ggml-metal.m

Lines changed: 7 additions & 1 deletion
@@ -1519,7 +1519,9 @@ static enum ggml_status ggml_metal_graph_compute(
                 {
                     GGML_ASSERT(ne00 == ne10);

-                    // TODO: assert that dim2 and dim3 are contiguous
+                    ggml_is_contiguous_2(src0);
+                    ggml_is_contiguous_2(src1);
+
                     GGML_ASSERT(ne12 % ne02 == 0);
                     GGML_ASSERT(ne13 % ne03 == 0);

@@ -2187,6 +2189,7 @@ static enum ggml_status ggml_metal_graph_compute(
             case GGML_OP_RMS_NORM:
                 {
                     GGML_ASSERT(ne00 % 4 == 0);
+                    GGML_ASSERT(ggml_is_contiguous_1(src0));

                     float eps;
                     memcpy(&eps, dst->op_params, sizeof(float));

@@ -2214,6 +2217,7 @@ static enum ggml_status ggml_metal_graph_compute(
             case GGML_OP_GROUP_NORM:
                 {
                     GGML_ASSERT(ne00 % 4 == 0);
+                    GGML_ASSERT(ggml_is_contiguous(src0));

                     //float eps;
                     //memcpy(&eps, dst->op_params, sizeof(float));

@@ -2247,6 +2251,8 @@ static enum ggml_status ggml_metal_graph_compute(
                 } break;
             case GGML_OP_NORM:
                 {
+                    GGML_ASSERT(ggml_is_contiguous_1(src0));
+
                     float eps;
                     memcpy(&eps, dst->op_params, sizeof(float));

ggml-metal.metal

Lines changed: 6 additions & 10 deletions
@@ -1767,13 +1767,13 @@ kernel void kernel_rope(

     const int64_t p = pos[i2];

-    const float theta_0 = (float)p;
+    const float theta_base = (float)p;
     const float inv_ndims = -1.f/n_dims;

     if (!is_neox) {
         for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);

-            const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
             float cos_theta, sin_theta;
             rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

@@ -1789,18 +1789,14 @@ kernel void kernel_rope(
     } else {
         for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
             if (ic < n_dims) {
-                const int64_t ib = 0;
+                const int64_t i0 = ic/2;

-                // simplified from `(ib * n_dims + ic) * inv_ndims`
-                const float cur_rot = inv_ndims*ic - ib;
-                const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
+                const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;

-                const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
+                const float theta = theta_base * pow(freq_base, inv_ndims*ic);

                 float cos_theta, sin_theta;
-                rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-                const int64_t i0 = ib*n_dims + ic/2;
+                rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);

                 device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                 device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
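
Note that the Metal kernel keeps the inv_ndims form of the exponent while the CUDA kernel above uses theta_scale; for the rotated dimensions these describe the same base angle, since with inv_ndims = -1/n_dims and theta_scale = freq_base^(-2/n_dims):

    freq_base^(inv_ndims*ic) = freq_base^(-ic/n_dims) = (freq_base^(-2/n_dims))^(ic/2) = theta_scale^(ic/2)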

ggml-sycl.cpp

Lines changed: 1 addition & 1 deletion
@@ -15183,7 +15183,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;

-    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
             *g_sycl_handles[g_main_device], oneapi::mkl::transpose::trans,

ggml.c

Lines changed: 36 additions & 38 deletions
@@ -3221,7 +3221,11 @@ GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }

-static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) {
+GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {
+    return ggml_is_contiguous(tensor);
+}
+
+GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

     return

@@ -3230,6 +3234,14 @@ static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * te
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }

+GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        tensor->nb[0] == ggml_type_size(tensor->type) &&
+        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
+
 GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
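
The three predicates form a small family: _0 is full contiguity, _1 (the renamed ggml_is_contiguous_except_dim_1) allows an arbitrary stride for dim 1, and _2 only constrains nb[0] and nb[3]. A hypothetical row-padded view, with invented sizes, illustrates the difference:

    // illustration only (values invented): an F32 view with padded rows
    //   ne = {4, 2, 1, 1}
    //   nb = {4, 32, 64, 64}        // 4-element rows spaced 32 bytes apart
    // ggml_is_contiguous(t)   -> false  // nb[1] != ne[0]*nb[0] (16)
    // ggml_is_contiguous_1(t) -> true   // nb[2] == nb[1]*ne[1] and nb[3] == nb[2]*ne[2]
    // ggml_is_contiguous_2(t) -> true   // only nb[0] and nb[3] are checked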

@@ -11420,8 +11432,8 @@ static void ggml_compute_forward_gelu_f32(

     const struct ggml_tensor * src0 = dst->src[0];

-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));

     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {

@@ -11483,8 +11495,8 @@ static void ggml_compute_forward_gelu_quick_f32(

     const struct ggml_tensor * src0 = dst->src[0];

-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));

     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {

@@ -11546,8 +11558,8 @@ static void ggml_compute_forward_silu_f32(

     const struct ggml_tensor * src0 = dst->src[0];

-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));

     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {

@@ -11658,9 +11670,9 @@ static void ggml_compute_forward_silu_back_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * grad = dst->src[1];

-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(grad));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_are_same_shape(src0, grad));
@@ -14358,7 +14370,7 @@ static void ggml_compute_forward_rope_f32(
     int ir = 0;

     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float inv_ndims = -1.f/n_dims;
+
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);

@@ -14407,7 +14419,7 @@ static void ggml_compute_forward_rope_f32(
             const float cos_block_theta = cosf(block_theta);
             const float sin_block_theta = sinf(block_theta) * sin_sign;

-            theta_base *= theta_scale;
+            theta_base  *= theta_scale;
             block_theta *= theta_scale;

             const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -14442,29 +14454,22 @@ static void ggml_compute_forward_rope_f32(
                     dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
                 }
             } else {
-                // TODO: this might be wrong for ne0 != n_dims - need double check
-                // it seems we have to rope just the first n_dims elements and do nothing with the rest
-                // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
-                theta_base *= freq_scale;
+                // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
                 for (int64_t ic = 0; ic < ne0; ic += 2) {
                     if (ic < n_dims) {
-                        const int64_t ib = 0;
+                        const int64_t i0 = ic/2;

-                        // simplified from `(ib * n_dims + ic) * inv_ndims`
-                        float cur_rot = inv_ndims * ic - ib;
-                        float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
+                        const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;

                         float cos_theta, sin_theta;
                         rope_yarn(
-                            theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                            theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
                             &cos_theta, &sin_theta
                         );
-                        sin_theta *= sin_sign;

+                        sin_theta *= sin_sign;
                         theta_base *= theta_scale;

-                        const int64_t i0 = ib*n_dims + ic/2;
-
                         const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                               float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
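
The net change in this f32 NeoX branch (the f16 branch below is updated identically) is twofold: theta_base is no longer pre-multiplied by freq_scale — rope_yarn() already receives freq_scale and applies the scaling itself, so the previous code appears to have applied it twice on this path — and rope_yarn() now gets the plain pair index ic for its correction ramp, matching the non-NeoX branch, which already passes the dimension index i0. The destination offset simplifies to i0 = ic/2.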

@@ -14543,7 +14548,7 @@ static void ggml_compute_forward_rope_f16(
     int ir = 0;

     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float inv_ndims = -1.f/n_dims;
+
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);

@@ -14592,7 +14597,7 @@ static void ggml_compute_forward_rope_f16(
             const float cos_block_theta = cosf(block_theta);
             const float sin_block_theta = sinf(block_theta) * sin_sign;

-            theta_base *= theta_scale;
+            theta_base  *= theta_scale;
             block_theta *= theta_scale;

             const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);

@@ -14623,29 +14628,22 @@ static void ggml_compute_forward_rope_f16(
                     dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                 }
             } else {
-                // TODO: this might be wrong for ne0 != n_dims - need double check
-                // it seems we have to rope just the first n_dims elements and do nothing with the rest
-                // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
-                theta_base *= freq_scale;
+                // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
                 for (int64_t ic = 0; ic < ne0; ic += 2) {
                     if (ic < n_dims) {
-                        const int64_t ib = 0;
+                        const int64_t i0 = ic/2;

-                        // simplified from `(ib * n_dims + ic) * inv_ndims`
-                        float cur_rot = inv_ndims * ic - ib;
-                        float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
+                        const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;

                         float cos_theta, sin_theta;
                         rope_yarn(
-                            theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                            theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
                             &cos_theta, &sin_theta
                         );
-                        sin_theta *= sin_sign;

+                        sin_theta *= sin_sign;
                         theta_base *= theta_scale;

-                        const int64_t i0 = ib*n_dims + ic/2;
-
                         const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                               ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
