|
1 | 1 | """Module for the Trainer.""" |
2 | 2 |
|
3 | 3 | import sys |
| 4 | +import warnings |
4 | 5 | import torch |
5 | 6 | import lightning |
6 | | -from .utils import check_consistency |
| 7 | +from .utils import check_consistency, custom_warning_format |
7 | 8 | from .data import PinaDataModule |
8 | 9 | from .solver import SolverInterface, PINNInterface |
9 | 10 |
|
| 11 | +# set the warning for compile options |
| 12 | +warnings.formatwarning = custom_warning_format |
| 13 | +warnings.filterwarnings("always", category=UserWarning) |
| 14 | + |
10 | 15 |
|
11 | 16 | class Trainer(lightning.pytorch.Trainer): |
12 | 17 | """ |
@@ -49,7 +54,8 @@ def __init__( |
49 | 54 | :param float val_size: The percentage of elements to include in the |
50 | 55 | validation dataset. Default is ``0.0``. |
51 | 56 | :param bool compile: If ``True``, the model is compiled before training. |
52 | | - Default is ``False``. For Windows users, it is always disabled. |
| 57 | + Default is ``False``. For Windows users, it is always disabled. Not |
| 58 | + supported for python version greater or equal than 3.14. |
53 | 59 | :param bool repeat: Whether to repeat the dataset data in each |
54 | 60 | condition during training. For further details, see the |
55 | 61 | :class:`~pina.data.data_module.PinaDataModule` class. Default is |
@@ -104,8 +110,17 @@ def __init__( |
104 | 110 | super().__init__(**kwargs) |
105 | 111 |
|
106 | 112 | # checking compilation and automatic batching |
107 | | - if compile is None or sys.platform == "win32": |
| 113 | + # compilation disabled for Windows and for Python 3.14+ |
| 114 | + if ( |
| 115 | + compile is None |
| 116 | + or sys.platform == "win32" |
| 117 | + or sys.version_info >= (3, 14) |
| 118 | + ): |
108 | 119 | compile = False |
| 120 | + warnings.warn( |
| 121 | + "Compilation is disabled for Python 3.14+ and for Windows.", |
| 122 | + UserWarning, |
| 123 | + ) |
109 | 124 |
|
110 | 125 | repeat = repeat if repeat is not None else False |
111 | 126 |
|
@@ -325,3 +340,23 @@ def _check_consistency_and_set_defaults( |
325 | 340 | if batch_size is not None: |
326 | 341 | check_consistency(batch_size, int) |
327 | 342 | return pin_memory, num_workers, shuffle, batch_size |
| 343 | + |
| 344 | + @property |
| 345 | + def compile(self): |
| 346 | + """ |
| 347 | + Whether compilation is required or not. |
| 348 | +
|
| 349 | + :return: ``True`` if compilation is required, ``False`` otherwise. |
| 350 | + :rtype: bool |
| 351 | + """ |
| 352 | + return self._compile |
| 353 | + |
| 354 | + @compile.setter |
| 355 | + def compile(self, value): |
| 356 | + """ |
| 357 | + Setting the value of compile. |
| 358 | +
|
| 359 | + :param bool value: Whether compilation is required or not. |
| 360 | + """ |
| 361 | + check_consistency(value, bool) |
| 362 | + self._compile = value |
0 commit comments