Commit 738a3ae

Support pure float16 add/sub/mul/div operations in the CUDA (and CPU) backend (#1121)
* Support float16-to-float16 add/sub/mul/div operations in the CUDA backend
* Add fp16 support for add/sub/mul/div on the CPU backend
* Add test cases for fp16 add/sub/mul/div
1 parent c21d976 commit 738a3ae
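With this commit, a graph whose add/sub/mul/div inputs and output are all F16 stays in F16 end to end, instead of being rejected (CUDA) or limited to F32 `src1` (CPU). Below is a minimal CPU-side sketch, not part of the commit, assuming the usual ggml graph API (in post-split trees `ggml_graph_compute_with_ctx` lives in `ggml-cpu.h`):

```c
#include <stdio.h>
#include "ggml.h"
#include "ggml-cpu.h" // for ggml_graph_compute_with_ctx in post-split trees

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // two F16 inputs
    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 8);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 8);
    for (int i = 0; i < 8; ++i) {
        ((ggml_fp16_t *) a->data)[i] = ggml_fp32_to_fp16((float) i);
        ((ggml_fp16_t *) b->data)[i] = ggml_fp32_to_fp16(0.5f);
    }

    // dst inherits src0's type, so c is F16 as well
    struct ggml_tensor * c = ggml_add(ctx, a, b);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    printf("c[3] = %f\n", (double) ggml_fp16_to_fp32(((ggml_fp16_t *) c->data)[3])); // 3.5
    ggml_free(ctx);
    return 0;
}
```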

File tree

3 files changed: +238 -44 lines changed

src/ggml-cpu/ggml-cpu.c

Lines changed: 202 additions & 11 deletions
```diff
@@ -1415,15 +1415,35 @@ inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x)
 inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
+inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
+    for (int i = 0; i < n; ++i) {
+        z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) + GGML_FP16_TO_FP32(y[i]));
+    }
+}
 inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
 inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
 inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; }
 inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; }
+inline static void ggml_vec_sub_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
+    for (int i = 0; i < n; ++i) {
+        z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) - GGML_FP16_TO_FP32(y[i]));
+    }
+}
 inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
 inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; }
 inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
+inline static void ggml_vec_mul_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
+    for (int i = 0; i < n; ++i) {
+        z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) * GGML_FP16_TO_FP32(y[i]));
+    }
+}
 inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
+inline static void ggml_vec_div_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
+    for (int i = 0; i < n; ++i) {
+        z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) / GGML_FP16_TO_FP32(y[i]));
+    }
+}

 static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
     assert(nrc == 1);
```
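All four new helpers use the same convert-compute-convert pattern: each element is widened to F32 with `GGML_FP16_TO_FP32`, combined, and rounded back once with `GGML_FP32_TO_FP16`. A standalone sketch of the per-element behavior (illustrative only; it uses the public `ggml_fp16_to_fp32`/`ggml_fp32_to_fp16` functions from `ggml.h` instead of the file-local macros):

```c
#include <stdio.h>
#include "ggml.h"

int main(void) {
    // 0.1 and 0.2 are already inexact in F16
    ggml_fp16_t x = ggml_fp32_to_fp16(0.1f);
    ggml_fp16_t y = ggml_fp32_to_fp16(0.2f);

    // exactly what ggml_vec_add_f16 does per element:
    // widen to F32, add, round back to F16
    ggml_fp16_t z = ggml_fp32_to_fp16(ggml_fp16_to_fp32(x) + ggml_fp16_to_fp32(y));

    // prints a value near, but not exactly, 0.3
    printf("%f\n", (double) ggml_fp16_to_fp32(z));
    return 0;
}
```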
```diff
@@ -4379,7 +4399,7 @@ static void ggml_compute_forward_add_f16_f16(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];

-    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));

     const int ith = params->ith;
     const int nth = params->nth;
@@ -4404,17 +4424,22 @@ static void ggml_compute_forward_add_f16_f16(
     if (nb10 == sizeof(ggml_fp16_t)) {
         for (int ir = ir0; ir < ir1; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);

-            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
-            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;

-            for (int i = 0; i < ne0; i++) {
-                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(src1_ptr[i]));
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0; r < nr0; ++r) {
+                ggml_vec_add_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
             }
         }
     }
```
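The broadcast arithmetic in the rewritten loop: src1's indices are the dst indices wrapped into src1's (possibly smaller) extents with `%`, and when src1's innermost dimension is shorter, the same src1 row is reused `nr0 = ne00/ne10` times per dst row. A toy, self-contained sketch of that index math, with made-up shapes:

```c
#include <stdio.h>

int main(void) {
    // hypothetical shapes: src0/dst is 8x4, src1 is 4x1 (broadcast in both dims)
    const long ne00 = 8, ne01 = 4;
    const long ne10 = 4, ne11 = 1;
    const long nr0  = ne00 / ne10; // src1's row of 4 is applied twice along dim 0

    for (long i01 = 0; i01 < ne01; ++i01) { // one iteration per dst row
        const long i11 = i01 % ne11;        // wrap into src1's extent: always 0 here
        printf("dst row %ld <- src1 row %ld, segment of %ld elements repeated %ld times\n",
               i01, i11, ne10, nr0);
    }
    return 0;
}
```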
```diff
@@ -5202,6 +5227,62 @@ static void ggml_compute_forward_sub_f32(
         }
     }
 }

+static void ggml_compute_forward_sub_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (nb10 == sizeof(ggml_fp16_t)) {
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
+
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0; r < nr0; ++r) {
+                ggml_vec_sub_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
+            }
+        }
+    } else {
+        // src1 is not contiguous
+        GGML_ABORT("unimplemented error");
+    }
+}
+
 static void ggml_compute_forward_sub(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -5213,6 +5294,10 @@ static void ggml_compute_forward_sub(
             {
                 ggml_compute_forward_sub_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_sub_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
```
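One stylistic wrinkle worth noting: `ggml_compute_forward_sub_f16` hands each thread a contiguous chunk of rows (`dr`/`ir0`/`ir1`), while the mul and div variants below stride rows across threads (`for (ir = ith; ir < nr; ir += nth)`). Both partitions cover every row exactly once; a toy comparison:

```c
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int nr = 10, nth = 3; // 10 rows, 3 threads

    for (int ith = 0; ith < nth; ++ith) {
        // chunked, as in ggml_compute_forward_sub_f16
        const int dr  = (nr + nth - 1)/nth;
        const int ir0 = dr*ith;
        const int ir1 = MIN(ir0 + dr, nr);
        printf("thread %d: chunked rows [%d,%d), strided rows:", ith, ir0, ir1);

        // strided, as in the mul/div f16 paths
        for (int ir = ith; ir < nr; ir += nth) {
            printf(" %d", ir);
        }
        printf("\n");
    }
    return 0;
}
```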
```diff
@@ -5293,20 +5378,73 @@ static void ggml_compute_forward_mul_f32(
         }
     }
 }

+static void ggml_compute_forward_mul_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    if (nb10 == sizeof(ggml_fp16_t)) {
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
+
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0; r < nr0; ++r) {
+                ggml_vec_mul_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
+            }
+        }
+    } else {
+        // src1 is not contiguous
+        GGML_ABORT("unimplemented error");
+    }
+}
+
 static void ggml_compute_forward_mul(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {

     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];

-    GGML_ASSERT(src1->type == GGML_TYPE_F32 && "only f32 src1 supported for now");
+    GGML_ASSERT((src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16) && "only f32/f16 src1 supported for now");

     switch (src0->type) {
         case GGML_TYPE_F32:
             {
                 ggml_compute_forward_mul_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_mul_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
```
```diff
@@ -5387,6 +5525,55 @@ static void ggml_compute_forward_div_f32(
         }
     }
 }

+static void ggml_compute_forward_div_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    if (nb10 == sizeof(ggml_fp16_t)) {
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
+
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0; r < nr0; ++r) {
+                ggml_vec_div_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
+            }
+        }
+    } else {
+        // src1 is not contiguous
+        GGML_ABORT("unimplemented error");
+    }
+}
+
 static void ggml_compute_forward_div(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -5398,6 +5585,10 @@ static void ggml_compute_forward_div(
             {
                 ggml_compute_forward_div_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_div_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
```

src/ggml-cuda/binbcast.cu

Lines changed: 4 additions & 2 deletions
```diff
@@ -294,11 +294,13 @@ static void ggml_cuda_op_bin_bcast(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) {

-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);

    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
        op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
-    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (half *) dst_dd, stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
        op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream);
    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
        op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
```
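The key change is the new middle branch: when all three tensors are F16, `src1_dd` is now cast to `const half *`, so the broadcast kernel is instantiated with half for both operands instead of routing src1 through F32; the existing F16 x F32 -> F16 path is kept intact. An illustrative C model of the widened dispatch ladder (enum and function names are invented; the real dispatch is the `if`/`else if` chain above):

```c
#include <stdio.h>

typedef enum { F32, F16 } ty;

// mirrors the branch order in ggml_cuda_op_bin_bcast
static const char * bin_bcast_route(ty src0, ty src1, ty dst) {
    if (src0 == F32 && dst == F32)                return "float x float -> float";
    if (src0 == F16 && src1 == F16 && dst == F16) return "half x half -> half (new)";
    if (src0 == F16 && src1 == F32 && dst == F16) return "half x float -> half";
    if (src0 == F16 && dst == F32)                return "half x float -> float";
    return "unsupported";
}

int main(void) {
    printf("%s\n", bin_bcast_route(F16, F16, F16));
    printf("%s\n", bin_bcast_route(F16, F32, F16));
    return 0;
}
```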

tests/test-backend-ops.cpp

Lines changed: 32 additions & 31 deletions
```diff
@@ -3943,37 +3943,38 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr));
         }
     };
-
-    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 8, 1}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1, 1}, {32, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 320, 320}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 1, 1}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 1}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 2});
-    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 2});
-    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 2, 2});
-    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {2, 2, 2, 2});
-
-    // stable diffusion
-    add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 16, 16, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {1280, 16, 16, 1}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 256, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1280, 1}, {16, 16, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {16, 16, 1280, 1}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1920, 1}, {16, 16, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 2560, 1}, {16, 16, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1280, 1}, {32, 32, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1920, 1}, {32, 32, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 640, 1}, {32, 32, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {5120, 1, 1, 1}, {1, 256, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {640, 1, 1, 1}, {1, 1, 1, 1});
-    //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {1, 1, 1, 1});
-    //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {2, 1, 1, 1});
+    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+        add_test_bin_bcast(type, {1, 1, 8, 1}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 1, 1}, {32, 1, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 320, 320}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {10, 5, 1, 1}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {10, 5, 4, 1}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 1, 1, 1});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 1, 1});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 1});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 2});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 2});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2});
+        add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2});
+
+        // stable diffusion
+        add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 16, 16, 1});
+        add_test_bin_bcast(type, {1280, 16, 16, 1}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 256, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 1280, 1}, {16, 16, 1, 1});
+        add_test_bin_bcast(type, {16, 16, 1280, 1}, {1, 1, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 1920, 1}, {16, 16, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 2560, 1}, {16, 16, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 1280, 1}, {32, 32, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 1920, 1}, {32, 32, 1, 1});
+        add_test_bin_bcast(type, {1, 1, 640, 1}, {32, 32, 1, 1});
+        add_test_bin_bcast(type, {5120, 1, 1, 1}, {1, 256, 1, 1});
+        add_test_bin_bcast(type, {640, 1, 1, 1}, {1, 1, 1, 1});
+        //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {1, 1, 1, 1});
+        //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1});
+    }

     test_cases.emplace_back(new test_add1());
     test_cases.emplace_back(new test_scale());
```
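Wrapping the existing broadcast shapes in a loop over `{GGML_TYPE_F16, GGML_TYPE_F32}` roughly doubles the bin-bcast coverage: every add/sub/mul/div broadcast case now also runs end to end in F16 and is compared against the CPU reference within the test's usual error tolerance. Assuming the driver's current CLI, something like `test-backend-ops test -o ADD` should exercise just these cases.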
