33import itertools
44import textwrap
55from contextlib import ExitStack as does_not_raise # noqa: N813
6+ from typing import NamedTuple
67
78import _pytask .parametrize
89import pytask
910import pytest
1011from _pytask .parametrize import _arg_value_to_id_component
12+ from _pytask .parametrize import _check_if_n_arg_names_matches_n_arg_values
1113from _pytask .parametrize import _parse_arg_names
14+ from _pytask .parametrize import _parse_arg_values
1215from _pytask .parametrize import _parse_parametrize_markers
1316from _pytask .parametrize import pytask_parametrize_task
1417from _pytask .pluginmanager import get_plugin_manager
@@ -54,14 +57,7 @@ def test_pytask_generate_tasks_1(session):
5457 def func (i , j ): # noqa: U100
5558 pass
5659
57- names_and_objs = pytask_parametrize_task (session , "func" , func )
58-
59- for (name , func ), values in zip (
60- names_and_objs , itertools .product (range (2 ), range (2 ))
61- ):
62- assert name == f"func[{ values [0 ]} -{ values [1 ]} ]"
63- assert func .keywords ["i" ] == values [0 ]
64- assert func .keywords ["j" ] == values [1 ]
60+ pytask_parametrize_task (session , "func" , func )
6561
6662
6763@pytest .mark .integration
@@ -72,16 +68,7 @@ def test_pytask_generate_tasks_2(session):
7268 def func (i , j , k ): # noqa: U100
7369 pass
7470
75- names_and_objs = pytask_parametrize_task (session , "func" , func )
76-
77- for (name , func ), values in zip (
78- names_and_objs ,
79- [(i , j , k ) for i in range (2 ) for j in range (2 ) for k in range (2 )],
80- ):
81- assert name == f"func[{ values [0 ]} -{ values [1 ]} -{ values [2 ]} ]"
82- assert func .keywords ["i" ] == values [0 ]
83- assert func .keywords ["j" ] == values [1 ]
84- assert func .keywords ["k" ] == values [2 ]
71+ pytask_parametrize_task (session , "func" , func )
8572
8673
8774@pytest .mark .integration
@@ -109,9 +96,32 @@ def func():
10996 (["i" , "j" ], ("i" , "j" )),
11097 ],
11198)
112- def test_parse_argnames (arg_names , expected ):
113- parsed_argnames = _parse_arg_names (arg_names )
114- assert parsed_argnames == expected
99+ def test_parse_arg_names (arg_names , expected ):
100+ parsed_arg_names = _parse_arg_names (arg_names )
101+ assert parsed_arg_names == expected
102+
103+
104+ class TaskArguments (NamedTuple ):
105+ a : int
106+ b : int
107+
108+
109+ @pytest .mark .unit
110+ @pytest .mark .parametrize (
111+ "arg_values, expected" ,
112+ [
113+ (["a" , "b" , "c" ], [("a" ,), ("b" ,), ("c" ,)]),
114+ ([(0 , 0 ), (0 , 1 ), (1 , 0 )], [(0 , 0 ), (0 , 1 ), (1 , 0 )]),
115+ ([[0 , 0 ], [0 , 1 ], [1 , 0 ]], [(0 , 0 ), (0 , 1 ), (1 , 0 )]),
116+ ({"a" : 0 , "b" : 1 }, [("a" ,), ("b" ,)]),
117+ ([TaskArguments (1 , 2 )], [(1 , 2 )]),
118+ ([TaskArguments (a = 1 , b = 2 )], [(1 , 2 )]),
119+ ([TaskArguments (b = 2 , a = 1 )], [(1 , 2 )]),
120+ ],
121+ )
122+ def test_parse_arg_values (arg_values , expected ):
123+ parsed_arg_values = _parse_arg_values (arg_values )
124+ assert parsed_arg_values == expected
115125
116126
117127@pytest .mark .unit
@@ -267,28 +277,20 @@ def task_func(i):
267277
268278
269279@pytest .mark .end_to_end
270- @pytest .mark .xfail (strict = True , reason = "Cartesian task product is disabled." )
271- def test_two_parametrize_w_ids (tmp_path ):
272- tmp_path .joinpath ("task_module.py" ).write_text (
273- textwrap .dedent (
274- """
275- import pytask
280+ def test_two_parametrize_w_ids (runner , tmp_path ):
281+ source = """
282+ import pytask
276283
277- @pytask.mark.parametrize('i', range(2), ids=["2.1", "2.2"])
278- @pytask.mark.parametrize('j', range(2), ids=["1.1", "1.2"])
279- def task_func(i, j):
280- pass
281- """
282- )
283- )
284- session = main ({"paths" : tmp_path })
284+ @pytask.mark.parametrize('i', range(2), ids=["2.1", "2.2"])
285+ @pytask.mark.parametrize('j', range(2), ids=["1.1", "1.2"])
286+ def task_func(i, j):
287+ pass
288+ """
289+ tmp_path .joinpath ("task_module.py" ).write_text (textwrap .dedent (source ))
290+ result = runner .invoke (cli , [tmp_path .as_posix ()])
285291
286- assert session .exit_code == 0
287- assert len (session .tasks ) == 4
288- for task , id_ in zip (
289- session .tasks , ["[1.1-2.1]" , "[1.1-2.2]" , "[1.2-2.1]" , "[1.2-2.2]" ]
290- ):
291- assert id_ in task .name
292+ assert result .exit_code == ExitCode .COLLECTION_FAILED
293+ assert "You cannot apply @pytask.mark.parametrize multiple" in result .output
292294
293295
294296@pytest .mark .end_to_end
@@ -430,3 +432,89 @@ def task_example(produces):
430432 session = main ({"paths" : tmp_path })
431433 assert session .exit_code == 0
432434 assert session .tasks [0 ].function .__wrapped__ .pytaskmark == []
435+
436+
437+ @pytest .mark .end_to_end
438+ def test_parametrizing_tasks_with_namedtuples (runner , tmp_path ):
439+ source = """
440+ from typing import NamedTuple
441+ import pytask
442+ from pathlib import Path
443+
444+
445+ class Task(NamedTuple):
446+ i: int
447+ produces: Path
448+
449+
450+ @pytask.mark.parametrize('i, produces', [
451+ Task(i=1, produces="1.txt"), Task(produces="2.txt", i=2),
452+ ])
453+ def task_write_numbers_to_file(produces, i):
454+ produces.write_text(str(i))
455+ """
456+ tmp_path .joinpath ("task_module.py" ).write_text (textwrap .dedent (source ))
457+
458+ result = runner .invoke (cli , [tmp_path .as_posix ()])
459+
460+ assert result .exit_code == 0
461+ for i in range (1 , 3 ):
462+ assert tmp_path .joinpath (f"{ i } .txt" ).read_text () == str (i )
463+
464+
465+ @pytest .mark .end_to_end
466+ def test_parametrization_with_different_n_of_arg_names_and_arg_values (runner , tmp_path ):
467+ source = """
468+ import pytask
469+
470+ @pytask.mark.parametrize('i, produces', [(1, "1.txt"), (2, 3, "2.txt")])
471+ def task_write_numbers_to_file(produces, i):
472+ produces.write_text(str(i))
473+ """
474+ tmp_path .joinpath ("task_module.py" ).write_text (textwrap .dedent (source ))
475+
476+ result = runner .invoke (cli , [tmp_path .as_posix ()])
477+
478+ assert result .exit_code == ExitCode .COLLECTION_FAILED
479+ assert "Task 'task_write_numbers_to_file' is parametrized with 2" in result .output
480+
481+
482+ @pytest .mark .unit
483+ @pytest .mark .parametrize (
484+ "arg_names, arg_values, name, expectation" ,
485+ [
486+ pytest .param (
487+ ("a" ,),
488+ [(1 ,), (2 ,)],
489+ "task_name" ,
490+ does_not_raise (),
491+ id = "normal one argument parametrization" ,
492+ ),
493+ pytest .param (
494+ ("a" , "b" ),
495+ [(1 , 2 ), (3 , 4 )],
496+ "task_name" ,
497+ does_not_raise (),
498+ id = "normal two argument argument parametrization" ,
499+ ),
500+ pytest .param (
501+ ("a" ,),
502+ [(1 , 2 ), (2 ,)],
503+ "task_name" ,
504+ pytest .raises (ValueError , match = "Task 'task_name' is parametrized with 1" ),
505+ id = "error with one argument parametrization" ,
506+ ),
507+ pytest .param (
508+ ("a" , "b" ),
509+ [(1 , 2 ), (3 , 4 , 5 )],
510+ "task_name" ,
511+ pytest .raises (ValueError , match = "Task 'task_name' is parametrized with 2" ),
512+ id = "error with two argument argument parametrization" ,
513+ ),
514+ ],
515+ )
516+ def test_check_if_n_arg_names_matches_n_arg_values (
517+ arg_names , arg_values , name , expectation
518+ ):
519+ with expectation :
520+ _check_if_n_arg_names_matches_n_arg_values (arg_names , arg_values , name )
0 commit comments