From 89f98ebd357e7bd7d0089e87557606248e47c80b Mon Sep 17 00:00:00 2001 From: Thomas Robitaille Date: Thu, 19 Sep 2024 11:57:25 +0100 Subject: [PATCH] Added benchmarks for parallel_fit_dask --- benchmarks/modeling/parallel_fitting.py | 41 +++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 benchmarks/modeling/parallel_fitting.py diff --git a/benchmarks/modeling/parallel_fitting.py b/benchmarks/modeling/parallel_fitting.py new file mode 100644 index 0000000000..041cd2e151 --- /dev/null +++ b/benchmarks/modeling/parallel_fitting.py @@ -0,0 +1,41 @@ +# Test the parallel_fit_dask function in synchronous mode + +import numpy as np + +from astropy.modeling.models import Gaussian1D, Const1D +from astropy.modeling.fitting import TRFLSQFitter, parallel_fit_dask + +x = np.linspace(0, 100, 20) + +y = 5 * np.exp(-((x - 30) ** 2) / (2 * 10**2)) + np.random.normal(0, 1, 20) +y = y.reshape((20, 1, 1)) +y = np.broadcast_to(y, (20, 30, 10)) + +y_plus_y0 = y + 3 + +g = Gaussian1D(1, 20, 1) +g_plus_const = g + Const1D(0) + +fitter = TRFLSQFitter() + + +def time_parallel_gaussian_fit(): + parallel_fit_dask( + model=g, + fitter=fitter, + data=y, + fitting_axes=0, + world=(x,), + scheduler="synchronous", + ) + + +def time_parallel_compound_fit(): + parallel_fit_dask( + model=g_plus_const, + fitter=fitter, + data=y, + fitting_axes=0, + world=(x,), + scheduler="synchronous", + )