9
9
import contextlib
10
10
import enum
11
11
import warnings
12
- from collections .abc import Callable , Iterator , Sequence
12
+ from collections .abc import Callable , Generator , Iterator , Sequence
13
13
from functools import wraps
14
14
from types import ModuleType
15
15
from typing import TYPE_CHECKING , Any , ParamSpec , TypeVar , cast
@@ -216,8 +216,11 @@ def test_myfunc(xp):
216
216
217
217
218
218
def patch_lazy_xp_functions (
219
- request : pytest .FixtureRequest , monkeypatch : pytest .MonkeyPatch , * , xp : ModuleType
220
- ) -> None :
219
+ request : pytest .FixtureRequest ,
220
+ monkeypatch : pytest .MonkeyPatch | None = None ,
221
+ * ,
222
+ xp : ModuleType ,
223
+ ) -> contextlib .AbstractContextManager [None ]:
221
224
"""
222
225
Test lazy execution of functions tagged with :func:`lazy_xp_function`.
223
226
@@ -233,10 +236,15 @@ def patch_lazy_xp_functions(
233
236
This function should be typically called by your library's `xp` fixture that runs
234
237
tests on multiple backends::
235
238
236
- @pytest.fixture(params=[numpy, array_api_strict, jax.numpy, dask.array])
237
- def xp(request, monkeypatch):
238
- patch_lazy_xp_functions(request, monkeypatch, xp=request.param)
239
- return request.param
239
+ @pytest.fixture(params=[
240
+ numpy,
241
+ array_api_strict,
242
+ pytest.param(jax.numpy, marks=pytest.mark.thread_unsafe),
243
+ pytest.param(dask.array, marks=pytest.mark.thread_unsafe),
244
+ ])
245
+ def xp(request):
246
+ with patch_lazy_xp_functions(request, xp=request.param):
247
+ yield request.param
240
248
241
249
but it can be otherwise be called by the test itself too.
242
250
@@ -245,18 +253,50 @@ def xp(request, monkeypatch):
245
253
request : pytest.FixtureRequest
246
254
Pytest fixture, as acquired by the test itself or by one of its fixtures.
247
255
monkeypatch : pytest.MonkeyPatch
248
- Pytest fixture, as acquired by the test itself or by one of its fixtures.
256
+ Deprecated
249
257
xp : array_namespace
250
258
Array namespace to be tested.
251
259
252
260
See Also
253
261
--------
254
262
lazy_xp_function : Tag a function to be tested on lazy backends.
255
263
pytest.FixtureRequest : `request` test function parameter.
264
+
265
+ Notes
266
+ -----
267
+ This context manager monkey-patches modules and as such is thread unsafe
268
+ on Dask and JAX. If you run your test suite with
269
+ `pytest-run-parallel <https://github.com/Quansight-Labs/pytest-run-parallel/>`_,
270
+ you should mark these backends with ``@pytest.mark.thread_unsafe``, as shown in
271
+ the example above.
256
272
"""
257
273
mod = cast (ModuleType , request .module )
258
274
mods = [mod , * cast (list [ModuleType ], getattr (mod , "lazy_xp_modules" , []))]
259
275
276
+ to_revert : list [tuple [ModuleType , str , object ]] = []
277
+
278
+ def temp_setattr (mod : ModuleType , name : str , func : object ) -> None :
279
+ """
280
+ Variant of monkeypatch.setattr, which allows monkey-patching only selected
281
+ parameters of a test so that pytest-run-parallel can run on the remainder.
282
+ """
283
+ assert hasattr (mod , name )
284
+ to_revert .append ((mod , name , getattr (mod , name )))
285
+ setattr (mod , name , func )
286
+
287
+ if monkeypatch is not None :
288
+ warnings .warn (
289
+ (
290
+ "The `monkeypatch` parameter is deprecated and will be removed in a "
291
+ "future version. "
292
+ "Use `patch_lazy_xp_function` as a context manager instead."
293
+ ),
294
+ DeprecationWarning ,
295
+ stacklevel = 2 ,
296
+ )
297
+ # Enable using patch_lazy_xp_function not as a context manager
298
+ temp_setattr = monkeypatch .setattr # type: ignore[assignment] # pyright: ignore[reportAssignmentType]
299
+
260
300
def iter_tagged () -> (
261
301
Iterator [tuple [ModuleType , str , Callable [..., Any ], dict [str , Any ]]]
262
302
):
@@ -279,13 +319,26 @@ def iter_tagged() -> (
279
319
elif n is False :
280
320
n = 0
281
321
wrapped = _dask_wrap (func , n )
282
- monkeypatch . setattr (mod , name , wrapped )
322
+ temp_setattr (mod , name , wrapped )
283
323
284
324
elif is_jax_namespace (xp ):
285
325
for mod , name , func , tags in iter_tagged ():
286
326
if tags ["jax_jit" ]:
287
327
wrapped = jax_autojit (func )
288
- monkeypatch .setattr (mod , name , wrapped )
328
+ temp_setattr (mod , name , wrapped )
329
+
330
+ # We can't just decorate patch_lazy_xp_functions with
331
+ # @contextlib.contextmanager because it would not work with the
332
+ # deprecated monkeypatch when not used as a context manager.
333
+ @contextlib .contextmanager
334
+ def revert_on_exit () -> Generator [None ]:
335
+ try :
336
+ yield
337
+ finally :
338
+ for mod , name , orig_func in to_revert :
339
+ setattr (mod , name , orig_func )
340
+
341
+ return revert_on_exit ()
289
342
290
343
291
344
class CountingDaskScheduler (SchedulerGetCallable ):
0 commit comments