@@ -7517,10 +7517,17 @@ static void ggml_compute_forward_upscale_f32(
75177517 float sf1 = (float )ne1/src0->ne [1 ];
75187518 float sf2 = (float )ne2/src0->ne [2 ];
75197519 float sf3 = (float )ne3/src0->ne [3 ];
7520+ float pixel_offset = 0 .5f ;
75207521
75217522 const int32_t mode_flags = ggml_get_op_params_i32 (dst, 0 );
75227523 const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF );
75237524
7525+ if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7526+ pixel_offset = 0 .0f ;
7527+ sf0 = ne0 > 1 && ne00 > 1 ? (float )(ne0 - 1 ) / (ne00 - 1 ) : sf0;
7528+ sf1 = ne1 > 1 && ne01 > 1 ? (float )(ne1 - 1 ) / (ne01 - 1 ) : sf1;
7529+ }
7530+
75247531 if (mode == GGML_SCALE_MODE_NEAREST) {
75257532 for (int64_t i3 = 0 ; i3 < ne3; i3++) {
75267533 const int64_t i03 = i3 / sf3;
@@ -7540,13 +7547,6 @@ static void ggml_compute_forward_upscale_f32(
75407547 }
75417548 }
75427549 } else if (mode == GGML_SCALE_MODE_BILINEAR) {
7543- float pixel_offset = 0 .5f ;
7544- if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7545- pixel_offset = 0 .0f ;
7546- sf0 = ne0 > 1 && ne00 > 1 ? (float )(ne0 - 1 ) / (ne00 - 1 ) : sf0;
7547- sf1 = ne1 > 1 && ne01 > 1 ? (float )(ne1 - 1 ) / (ne01 - 1 ) : sf1;
7548- }
7549-
75507550 for (int64_t i3 = 0 ; i3 < ne3; i3++) {
75517551 const int64_t i03 = i3 / sf3;
75527552 for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
@@ -7581,6 +7581,51 @@ static void ggml_compute_forward_upscale_f32(
75817581
75827582 const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
75837583
7584+ float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7585+ *y_dst = val;
7586+ }
7587+ }
7588+ }
7589+ }
7590+ } else if (mode == GGML_SCALE_MODE_BICUBIC) {
7591+ // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
7592+ const float a = -0 .75f ; // use alpha = -0.75 (same as PyTorch)
7593+ auto weight1 = [a](float x) { return ((a + 2 ) * x - (a + 3 )) * x * x + 1 ; };
7594+ auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
7595+ auto bicubic = [=](float p0, float p1, float p2, float p3, float x) {
7596+ const float w0 = weight2 (x + 1 );
7597+ const float w1 = weight1 (x + 0 );
7598+ const float w2 = weight1 (1 - x);
7599+ const float w3 = weight2 (2 - x);
7600+ return p0*w0 + p1*w1 + p2*w2 + p3*w3;
7601+ };
7602+
7603+ for (int64_t i3 = 0 ; i3 < ne3; i3++) {
7604+ const int64_t i03 = i3 / sf3;
7605+ for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7606+ const int64_t i02 = i2 / sf2;
7607+ for (int64_t i1 = 0 ; i1 < ne1; i1++) {
7608+ const float y = ((float )i1 + pixel_offset) / sf1 - pixel_offset;
7609+ const int64_t y0 = (int64_t )floorf (y);
7610+ const float dy = y - (float )y0;
7611+
7612+ for (int64_t i0 = 0 ; i0 < ne0; i0++) {
7613+ const float x = ((float )i0 + pixel_offset) / sf0 - pixel_offset;
7614+ const int64_t x0 = (int64_t )floorf (x);
7615+ const float dx = x - (float )x0;
7616+
7617+ auto p = [=](int64_t x_off, int64_t y_off) -> float {
7618+ int64_t i00 = std::max (int64_t (0 ), std::min (x0 + x_off, ne00 - 1 ));
7619+ int64_t i01 = std::max (int64_t (0 ), std::min (y0 + y_off, ne01 - 1 ));
7620+ return *(const float *)((const char *)src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7621+ };
7622+
7623+ const float val = bicubic (
7624+ bicubic (p (-1 ,-1 ), p (0 ,-1 ), p (1 ,-1 ), p (2 ,-1 ), dx),
7625+ bicubic (p (-1 , 0 ), p (0 , 0 ), p (1 , 0 ), p (2 , 0 ), dx),
7626+ bicubic (p (-1 , 1 ), p (0 , 1 ), p (1 , 1 ), p (2 , 1 ), dx),
7627+ bicubic (p (-1 , 2 ), p (0 , 2 ), p (1 , 2 ), p (2 , 2 ), dx), dy);
7628+
75847629 float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
75857630 *y_dst = val;
75867631 }
0 commit comments