4
4
import inspect
5
5
import sys
6
6
import time
7
+ import warnings
7
8
from concurrent .futures import Future
8
9
from types import TracebackType
9
10
from typing import Any
11
+ from typing import Callable
10
12
11
13
import cloudpickle
12
14
from pybaum .tree_util import tree_map
13
15
from pytask import console
14
16
from pytask import ExecutionReport
17
+ from pytask import get_marks
15
18
from pytask import hookimpl
19
+ from pytask import Mark
16
20
from pytask import remove_internal_traceback_frames_from_exc_info
17
21
from pytask import Session
18
22
from pytask import Task
19
23
from pytask_parallel .backends import PARALLEL_BACKENDS
20
24
from rich .console import ConsoleOptions
21
25
from rich .traceback import Traceback
22
26
27
+ # Can be removed if pinned to pytask >= 0.2.6.
28
+ try :
29
+ from pytask import parse_warning_filter
30
+ from pytask import warning_record_to_str
31
+ from pytask import WarningReport
32
+ except ImportError :
33
+ from _pytask .warnings import parse_warning_filter
34
+ from _pytask .warnings import warning_record_to_str
35
+ from _pytask .warnings_utils import WarningReport
36
+
23
37
24
38
@hookimpl
25
39
def pytask_post_parse (config : dict [str , Any ]) -> None :
@@ -85,42 +99,38 @@ def pytask_execute_build(session: Session) -> bool | None:
85
99
86
100
for task_name in list (running_tasks ):
87
101
future = running_tasks [task_name ]
88
- if future .done () and (
89
- future .exception () is not None
90
- or future .result () is not None
91
- ):
92
- task = session .dag .nodes [task_name ]["task" ]
93
- if future .exception () is not None :
94
- exception = future .exception ()
95
- exc_info = (
96
- type (exception ),
97
- exception ,
98
- exception .__traceback__ ,
99
- )
100
- else :
101
- exc_info = future .result ()
102
-
103
- newly_collected_reports .append (
104
- ExecutionReport .from_task_and_exception (task , exc_info )
102
+ if future .done ():
103
+ warning_reports , task_exception = future .result ()
104
+ session .warnings .extend (warning_reports )
105
+ exc_info = (
106
+ _parse_future_exception (future .exception ())
107
+ or task_exception
105
108
)
106
- running_tasks .pop (task_name )
107
- session .scheduler .done (task_name )
108
- elif future .done () and future .exception () is None :
109
- task = session .dag .nodes [task_name ]["task" ]
110
- try :
111
- session .hook .pytask_execute_task_teardown (
112
- session = session , task = task
113
- )
114
- except Exception :
115
- report = ExecutionReport .from_task_and_exception (
116
- task , sys .exc_info ()
109
+ if exc_info is not None :
110
+ task = session .dag .nodes [task_name ]["task" ]
111
+ newly_collected_reports .append (
112
+ ExecutionReport .from_task_and_exception (
113
+ task , exc_info
114
+ )
117
115
)
116
+ running_tasks .pop (task_name )
117
+ session .scheduler .done (task_name )
118
118
else :
119
- report = ExecutionReport .from_task (task )
120
-
121
- running_tasks .pop (task_name )
122
- newly_collected_reports .append (report )
123
- session .scheduler .done (task_name )
119
+ task = session .dag .nodes [task_name ]["task" ]
120
+ try :
121
+ session .hook .pytask_execute_task_teardown (
122
+ session = session , task = task
123
+ )
124
+ except Exception :
125
+ report = ExecutionReport .from_task_and_exception (
126
+ task , sys .exc_info ()
127
+ )
128
+ else :
129
+ report = ExecutionReport .from_task (task )
130
+
131
+ running_tasks .pop (task_name )
132
+ newly_collected_reports .append (report )
133
+ session .scheduler .done (task_name )
124
134
else :
125
135
pass
126
136
@@ -144,6 +154,17 @@ def pytask_execute_build(session: Session) -> bool | None:
144
154
return None
145
155
146
156
157
+ def _parse_future_exception (
158
+ exception : BaseException | None ,
159
+ ) -> tuple [type [BaseException ], BaseException , TracebackType ] | None :
160
+ """Parse a future exception."""
161
+ return (
162
+ None
163
+ if exception is None
164
+ else (type (exception ), exception , exception .__traceback__ )
165
+ )
166
+
167
+
147
168
class ProcessesNameSpace :
148
169
"""The name space for hooks related to processes."""
149
170
@@ -167,6 +188,9 @@ def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
167
188
bytes_kwargs = bytes_kwargs ,
168
189
show_locals = session .config ["show_locals" ],
169
190
console_options = console .options ,
191
+ session_filterwarnings = session .config ["filterwarnings" ],
192
+ task_filterwarnings = get_marks (task , "filterwarnings" ),
193
+ task_short_name = task .short_name ,
170
194
)
171
195
return None
172
196
@@ -176,7 +200,10 @@ def _unserialize_and_execute_task(
176
200
bytes_kwargs : bytes ,
177
201
show_locals : bool ,
178
202
console_options : ConsoleOptions ,
179
- ) -> tuple [type [BaseException ], BaseException , str ] | None :
203
+ session_filterwarnings : tuple [str , ...],
204
+ task_filterwarnings : tuple [Mark , ...],
205
+ task_short_name : str ,
206
+ ) -> tuple [list [WarningReport ], tuple [type [BaseException ], BaseException , str ] | None ]:
180
207
"""Unserialize and execute task.
181
208
182
209
This function receives bytes and unpickles them to a task which is them execute in a
@@ -188,13 +215,40 @@ def _unserialize_and_execute_task(
188
215
task = cloudpickle .loads (bytes_function )
189
216
kwargs = cloudpickle .loads (bytes_kwargs )
190
217
191
- try :
192
- task .execute (** kwargs )
193
- except Exception :
194
- exc_info = sys .exc_info ()
195
- processed_exc_info = _process_exception (exc_info , show_locals , console_options )
196
- return processed_exc_info
197
- return None
218
+ with warnings .catch_warnings (record = True ) as log :
219
+ # mypy can't infer that record=True means log is not None; help it.
220
+ assert log is not None
221
+
222
+ for arg in session_filterwarnings :
223
+ warnings .filterwarnings (* parse_warning_filter (arg , escape = False ))
224
+
225
+ # apply filters from "filterwarnings" marks
226
+ for mark in task_filterwarnings :
227
+ for arg in mark .args :
228
+ warnings .filterwarnings (* parse_warning_filter (arg , escape = False ))
229
+
230
+ try :
231
+ task .execute (** kwargs )
232
+ except Exception :
233
+ exc_info = sys .exc_info ()
234
+ processed_exc_info = _process_exception (
235
+ exc_info , show_locals , console_options
236
+ )
237
+ else :
238
+ processed_exc_info = None
239
+
240
+ warning_reports = []
241
+ for warning_message in log :
242
+ fs_location = warning_message .filename , warning_message .lineno
243
+ warning_reports .append (
244
+ WarningReport (
245
+ message = warning_record_to_str (warning_message ),
246
+ fs_location = fs_location ,
247
+ id_ = task_short_name ,
248
+ )
249
+ )
250
+
251
+ return warning_reports , processed_exc_info
198
252
199
253
200
254
def _process_exception (
@@ -224,11 +278,33 @@ def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
224
278
"""
225
279
if session .config ["n_workers" ] > 1 :
226
280
kwargs = _create_kwargs_for_task (task )
227
- return session .executor .submit (task .execute , ** kwargs )
281
+ return session .executor .submit (
282
+ _mock_processes_for_threads , func = task .execute , ** kwargs
283
+ )
228
284
else :
229
285
return None
230
286
231
287
288
+ def _mock_processes_for_threads (
289
+ func : Callable [..., Any ], ** kwargs : Any
290
+ ) -> tuple [list [Any ], tuple [type [BaseException ], BaseException , TracebackType ] | None ]:
291
+ """Mock execution function such that it returns the same as for processes.
292
+
293
+ The function for processes returns ``warning_reports`` and an ``exception``. With
294
+ threads, these object are collected by the main and not the subprocess. So, we just
295
+ return placeholders.
296
+
297
+ """
298
+ __tracebackhide__ = True
299
+ try :
300
+ func (** kwargs )
301
+ except Exception :
302
+ exc_info = sys .exc_info ()
303
+ else :
304
+ exc_info = None
305
+ return [], exc_info
306
+
307
+
232
308
def _create_kwargs_for_task (task : Task ) -> dict [Any , Any ]:
233
309
"""Create kwargs for task function."""
234
310
kwargs = {** task .kwargs }
0 commit comments