Skip to content

Commit 3d252a1

Browse files
AclyGitty Burstein
authored andcommitted
ggml-cpu : bicubic interpolation (#16891)
1 parent 642e43d commit 3d252a1

File tree

2 files changed

+53
-7
lines changed

2 files changed

+53
-7
lines changed

ggml/include/ggml.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2107,6 +2107,7 @@ extern "C" {
21072107
enum ggml_scale_mode {
21082108
GGML_SCALE_MODE_NEAREST = 0,
21092109
GGML_SCALE_MODE_BILINEAR = 1,
2110+
GGML_SCALE_MODE_BICUBIC = 2,
21102111

21112112
GGML_SCALE_MODE_COUNT
21122113
};

ggml/src/ggml-cpu/ops.cpp

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7517,10 +7517,17 @@ static void ggml_compute_forward_upscale_f32(
75177517
float sf1 = (float)ne1/src0->ne[1];
75187518
float sf2 = (float)ne2/src0->ne[2];
75197519
float sf3 = (float)ne3/src0->ne[3];
7520+
float pixel_offset = 0.5f;
75207521

75217522
const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
75227523
const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
75237524

7525+
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7526+
pixel_offset = 0.0f;
7527+
sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
7528+
sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
7529+
}
7530+
75247531
if (mode == GGML_SCALE_MODE_NEAREST) {
75257532
for (int64_t i3 = 0; i3 < ne3; i3++) {
75267533
const int64_t i03 = i3 / sf3;
@@ -7540,13 +7547,6 @@ static void ggml_compute_forward_upscale_f32(
75407547
}
75417548
}
75427549
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
7543-
float pixel_offset = 0.5f;
7544-
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7545-
pixel_offset = 0.0f;
7546-
sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
7547-
sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
7548-
}
7549-
75507550
for (int64_t i3 = 0; i3 < ne3; i3++) {
75517551
const int64_t i03 = i3 / sf3;
75527552
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
@@ -7581,6 +7581,51 @@ static void ggml_compute_forward_upscale_f32(
75817581

75827582
const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
75837583

7584+
float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7585+
*y_dst = val;
7586+
}
7587+
}
7588+
}
7589+
}
7590+
} else if (mode == GGML_SCALE_MODE_BICUBIC) {
7591+
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
7592+
const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
7593+
auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
7594+
auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
7595+
auto bicubic = [=](float p0, float p1, float p2, float p3, float x) {
7596+
const float w0 = weight2(x + 1);
7597+
const float w1 = weight1(x + 0);
7598+
const float w2 = weight1(1 - x);
7599+
const float w3 = weight2(2 - x);
7600+
return p0*w0 + p1*w1 + p2*w2 + p3*w3;
7601+
};
7602+
7603+
for (int64_t i3 = 0; i3 < ne3; i3++) {
7604+
const int64_t i03 = i3 / sf3;
7605+
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7606+
const int64_t i02 = i2 / sf2;
7607+
for (int64_t i1 = 0; i1 < ne1; i1++) {
7608+
const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
7609+
const int64_t y0 = (int64_t)floorf(y);
7610+
const float dy = y - (float)y0;
7611+
7612+
for (int64_t i0 = 0; i0 < ne0; i0++) {
7613+
const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
7614+
const int64_t x0 = (int64_t)floorf(x);
7615+
const float dx = x - (float)x0;
7616+
7617+
auto p = [=](int64_t x_off, int64_t y_off) -> float {
7618+
int64_t i00 = std::max(int64_t(0), std::min(x0 + x_off, ne00 - 1));
7619+
int64_t i01 = std::max(int64_t(0), std::min(y0 + y_off, ne01 - 1));
7620+
return *(const float *)((const char *)src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7621+
};
7622+
7623+
const float val = bicubic(
7624+
bicubic(p(-1,-1), p(0,-1), p(1,-1), p(2,-1), dx),
7625+
bicubic(p(-1, 0), p(0, 0), p(1, 0), p(2, 0), dx),
7626+
bicubic(p(-1, 1), p(0, 1), p(1, 1), p(2, 1), dx),
7627+
bicubic(p(-1, 2), p(0, 2), p(1, 2), p(2, 2), dx), dy);
7628+
75847629
float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
75857630
*y_dst = val;
75867631
}

0 commit comments

Comments
 (0)