Skip to content

Commit 3262fc8

Browse files
committed
TST: Calculate RMS and diff image in C++
The current implementation is not slow, but uses a lot of memory per image. In `compare_images`, we have: - one actual and one expected image as uint8 (2×image) - both converted to int16 (though original is thrown away) (4×) which adds up to 4× the image allocated in this function. Then it calls `calculate_rms`, which has: - a difference between them as int16 (2×) - the difference cast to 64-bit float (8×) - the square of the difference as 64-bit float (though possibly the original difference was thrown away) (8×) which at its peak has 16× the image allocated in parallel. If the RMS is over the desired tolerance, then `save_diff_image` is called, which: - loads the actual and expected images _again_ as uint8 (2× image) - converts both to 64-bit float (throwing away the original) (16×) - calculates the difference (8×) - calculates the absolute value (8×) - multiples that by 10 (in-place, so no allocation) - clips to 0-255 (8×) - casts to uint8 (1×) which at peak uses 32× the image. So at their peak, `compare_images`→`calculate_rms` will have 20× the image allocated, and then `compare_images`→`save_diff_image` will have 36× the image allocated. This is generally not a problem, but on resource-constrained places like WASM, it can sometimes run out of memory just in `calculate_rms`. This implementation in C++ always allocates the diff image, even when not needed, but doesn't have all the temporaries, so it's a maximum of 3× the image size (plus a few scalar temporaries).
1 parent 2c1ec43 commit 3262fc8

File tree

2 files changed

+73
-9
lines changed

2 files changed

+73
-9
lines changed

lib/matplotlib/testing/compare.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from PIL import Image
2020

2121
import matplotlib as mpl
22-
from matplotlib import cbook
22+
from matplotlib import cbook, _image
2323
from matplotlib.testing.exceptions import ImageComparisonFailure
2424

2525
_log = logging.getLogger(__name__)
@@ -412,7 +412,7 @@ def compare_images(expected, actual, tol, in_decorator=False):
412412
413413
The two given filenames may point to files which are convertible to
414414
PNG via the `!converter` dictionary. The underlying RMS is calculated
415-
with the `.calculate_rms` function.
415+
in a similar way to the `.calculate_rms` function.
416416
417417
Parameters
418418
----------
@@ -483,17 +483,12 @@ def compare_images(expected, actual, tol, in_decorator=False):
483483
if np.array_equal(expected_image, actual_image):
484484
return None
485485

486-
# convert to signed integers, so that the images can be subtracted without
487-
# overflow
488-
expected_image = expected_image.astype(np.int16)
489-
actual_image = actual_image.astype(np.int16)
490-
491-
rms = calculate_rms(expected_image, actual_image)
486+
rms, abs_diff = _image.calculate_rms_and_diff(expected_image, actual_image)
492487

493488
if rms <= tol:
494489
return None
495490

496-
save_diff_image(expected, actual, diff_image)
491+
Image.fromarray(abs_diff).save(diff_image, format="png")
497492

498493
results = dict(rms=rms, expected=str(expected),
499494
actual=str(actual), diff=str(diff_image), tol=tol)

src/_image_wrapper.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include <pybind11/pybind11.h>
22
#include <pybind11/numpy.h>
33

4+
#include <algorithm>
5+
46
#include "_image_resample.h"
57
#include "py_converters.h"
68

@@ -202,6 +204,70 @@ image_resample(py::array input_array,
202204
}
203205

204206

207+
// This is used by matplotlib.testing.compare to calculate RMS and a difference image.
208+
static py::tuple
209+
calculate_rms_and_diff(py::array_t<unsigned char> expected_image,
210+
py::array_t<unsigned char> actual_image)
211+
{
212+
for (const auto & [image, name] : {std::pair{expected_image, "Expected"},
213+
std::pair{actual_image, "Actual"}})
214+
{
215+
if (image.ndim() != 3) {
216+
auto exceptions = py::module_::import("matplotlib.testing.exceptions");
217+
auto ImageComparisonFailure = exceptions.attr("ImageComparisonFailure");
218+
py::set_error(
219+
ImageComparisonFailure,
220+
"{name} image must be 3-dimensional, but is {ndim}-dimensional"_s.format(
221+
"name"_a=name, "ndim"_a=expected_image.ndim()));
222+
throw py::error_already_set();
223+
}
224+
}
225+
226+
auto height = expected_image.shape(0);
227+
auto width = expected_image.shape(1);
228+
auto depth = expected_image.shape(2);
229+
230+
if (height != actual_image.shape(0) || width != actual_image.shape(1) ||
231+
depth != actual_image.shape(2)) {
232+
auto exceptions = py::module_::import("matplotlib.testing.exceptions");
233+
auto ImageComparisonFailure = exceptions.attr("ImageComparisonFailure");
234+
py::set_error(
235+
ImageComparisonFailure,
236+
"Image sizes do not match expected size: {expected_image.shape} "_s
237+
"actual size {actual_image.shape}"_s.format(
238+
"expected_image"_a=expected_image, "actual_image"_a=actual_image));
239+
throw py::error_already_set();
240+
}
241+
auto expected = expected_image.unchecked<3>();
242+
auto actual = actual_image.unchecked<3>();
243+
244+
py::ssize_t diff_dims[3] = {height, width, 3};
245+
py::array_t<unsigned char> diff_image(diff_dims);
246+
auto diff = diff_image.mutable_unchecked<3>();
247+
248+
double total = 0.0;
249+
for (auto i = 0; i < height; i++) {
250+
for (auto j = 0; j < width; j++) {
251+
for (auto k = 0; k < depth; k++) {
252+
auto pixel_diff = static_cast<double>(expected(i, j, k)) -
253+
static_cast<double>(actual(i, j, k));
254+
255+
total += pixel_diff*pixel_diff;
256+
257+
if (k != 3) { // Hard-code a fully solid alpha channel by omitting it.
258+
diff(i, j, k) = static_cast<unsigned char>(std::clamp(
259+
abs(pixel_diff) * 10, // Expand differences in luminance domain.
260+
0.0, 255.0));
261+
}
262+
}
263+
}
264+
}
265+
total = total / (width * height * depth);
266+
267+
return py::make_tuple(sqrt(total), diff_image);
268+
}
269+
270+
205271
PYBIND11_MODULE(_image, m, py::mod_gil_not_used())
206272
{
207273
py::enum_<interpolation_e>(m, "_InterpolationType")
@@ -234,4 +300,7 @@ PYBIND11_MODULE(_image, m, py::mod_gil_not_used())
234300
"norm"_a = false,
235301
"radius"_a = 1,
236302
image_resample__doc__);
303+
304+
m.def("calculate_rms_and_diff", &calculate_rms_and_diff,
305+
"expected_image"_a, "actual_image"_a);
237306
}

0 commit comments

Comments
 (0)