Skip to content

Commit 497f44f

Browse files
Merge pull request #697 from mathLab/dev
Dev updates 0.2.5
2 parents 9c3e55d + f07e59b commit 497f44f

File tree

14 files changed

+92
-25
lines changed

14 files changed

+92
-25
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ SPDX-License-Identifier: Apache-2.0
77
<table>
88
<tr>
99
<td>
10-
<a href="https://github.com/mathLab/PINA/raw/master/readme/pina_logo.png">
11-
<img src="https://github.com/mathLab/PINA/raw/master/readme/pina_logo.png"
10+
<a href="readme/pina_logo.png">
11+
<img src="readme/pina_logo.png"
1212
alt="PINA logo"
1313
style="width: 220px; aspect-ratio: 1 / 1; object-fit: contain;">
1414
</a>
234 KB
Loading

pina/equation/equation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Module for the Equation."""
22

33
import inspect
4-
54
from .equation_interface import EquationInterface
65

76

@@ -49,6 +48,10 @@ def residual(self, input_, output_, params_=None):
4948
:raises RuntimeError: If the underlying equation signature length is not
5049
2 (direct problem) or 3 (inverse problem).
5150
"""
51+
# Move the equation to the input_ device
52+
self.to(input_.device)
53+
54+
# Call the underlying equation based on its signature length
5255
if self.__len_sig == 2:
5356
return self.__equation(input_, output_)
5457
if self.__len_sig == 3:

pina/equation/equation_factory.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -239,19 +239,19 @@ def equation(input_, output_):
239239
)
240240

241241
# Ensure consistency of c length
242-
if len(self.c) != (len(input_lbl) - 1) and len(self.c) > 1:
242+
if self.c.shape[-1] != len(input_lbl) - 1 and self.c.shape[-1] > 1:
243243
raise ValueError(
244244
"If 'c' is passed as a list, its length must be equal to "
245245
"the number of spatial dimensions."
246246
)
247247

248248
# Repeat c to ensure consistent shape for advection
249-
self.c = self.c.repeat(output_.shape[0], 1)
250-
if self.c.shape[1] != (len(input_lbl) - 1):
251-
self.c = self.c.repeat(1, len(input_lbl) - 1)
249+
c = self.c.repeat(output_.shape[0], 1)
250+
if c.shape[1] != (len(input_lbl) - 1):
251+
c = c.repeat(1, len(input_lbl) - 1)
252252

253253
# Add a dimension to c for the following operations
254-
self.c = self.c.unsqueeze(-1)
254+
c = c.unsqueeze(-1)
255255

256256
# Compute the time derivative and the spatial gradient
257257
time_der = grad(output_, input_, components=None, d="t")
@@ -262,7 +262,7 @@ def equation(input_, output_):
262262
tmp = tmp.transpose(-1, -2)
263263

264264
# Compute advection term
265-
adv = (tmp * self.c).sum(dim=tmp.tensor.ndim - 2)
265+
adv = (tmp * c).sum(dim=tmp.tensor.ndim - 2)
266266

267267
return time_der + adv
268268

pina/equation/equation_interface.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Module for the Equation Interface."""
22

33
from abc import ABCMeta, abstractmethod
4+
import torch
45

56

67
class EquationInterface(metaclass=ABCMeta):
@@ -33,3 +34,33 @@ def residual(self, input_, output_, params_):
3334
:return: The computed residual of the equation.
3435
:rtype: LabelTensor
3536
"""
37+
38+
def to(self, device):
39+
"""
40+
Move all tensor attributes to the specified device.
41+
42+
:param torch.device device: The target device to move the tensors to.
43+
:return: The instance moved to the specified device.
44+
:rtype: EquationInterface
45+
"""
46+
# Iterate over all attributes of the Equation
47+
for key, val in self.__dict__.items():
48+
49+
# Move tensors in dictionaries to the specified device
50+
if isinstance(val, dict):
51+
self.__dict__[key] = {
52+
k: v.to(device) if torch.is_tensor(v) else v
53+
for k, v in val.items()
54+
}
55+
56+
# Move tensors in lists to the specified device
57+
elif isinstance(val, list):
58+
self.__dict__[key] = [
59+
v.to(device) if torch.is_tensor(v) else v for v in val
60+
]
61+
62+
# Move tensor attributes to the specified device
63+
elif torch.is_tensor(val):
64+
self.__dict__[key] = val.to(device)
65+
66+
return self

pina/equation/system_equation.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,18 @@ def residual(self, input_, output_, params_=None):
101101
:return: The aggregated residuals of the system of equations.
102102
:rtype: LabelTensor
103103
"""
104+
# Move the equation to the input_ device
105+
self.to(input_.device)
106+
107+
# Compute the residual for each equation
104108
residual = torch.hstack(
105109
[
106110
equation.residual(input_, output_, params_)
107111
for equation in self.equations
108112
]
109113
)
110114

115+
# Skip reduction if not specified
111116
if self.reduction is None:
112117
return residual
113118

pina/problem/zoo/helmholtz.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,10 @@ def __init__(self, alpha=3.0):
4848
:type alpha: float | int
4949
"""
5050
super().__init__()
51-
52-
self.alpha = alpha
5351
check_consistency(alpha, (int, float))
52+
self.alpha = alpha
5453

55-
def forcing_term(self, input_):
54+
def forcing_term(input_):
5655
"""
5756
Implementation of the forcing term.
5857
"""

pina/solver/physics_informed_solver/pinn_interface.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,7 @@ def setup(self, stage):
7171
"""
7272
# Override the compilation, compiling only for torch < 2.8, see
7373
# related issue at https://github.com/mathLab/PINA/issues/621
74-
if torch.__version__ < "2.8":
75-
self.trainer.compile = True
76-
else:
74+
if torch.__version__ >= "2.8":
7775
self.trainer.compile = False
7876
warnings.warn(
7977
"Compilation is disabled for torch >= 2.8. "

pina/solver/solver.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,7 @@ def setup(self, stage):
174174
:return: The result of the parent class ``setup`` method.
175175
:rtype: Any
176176
"""
177-
if stage == "fit" and self.trainer.compile:
178-
self._setup_compile()
179-
if stage == "test" and (
180-
self.trainer.compile and not self._is_compiled()
181-
):
177+
if self.trainer.compile and not self._is_compiled():
182178
self._setup_compile()
183179
return super().setup(stage)
184180

pina/trainer.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
"""Module for the Trainer."""
22

33
import sys
4+
import warnings
45
import torch
56
import lightning
6-
from .utils import check_consistency
7+
from .utils import check_consistency, custom_warning_format
78
from .data import PinaDataModule
89
from .solver import SolverInterface, PINNInterface
910

11+
# set the warning for compile options
12+
warnings.formatwarning = custom_warning_format
13+
warnings.filterwarnings("always", category=UserWarning)
14+
1015

1116
class Trainer(lightning.pytorch.Trainer):
1217
"""
@@ -49,7 +54,8 @@ def __init__(
4954
:param float val_size: The percentage of elements to include in the
5055
validation dataset. Default is ``0.0``.
5156
:param bool compile: If ``True``, the model is compiled before training.
52-
Default is ``False``. For Windows users, it is always disabled.
57+
Default is ``False``. For Windows users, it is always disabled. Not
58+
supported for python version greater or equal than 3.14.
5359
:param bool repeat: Whether to repeat the dataset data in each
5460
condition during training. For further details, see the
5561
:class:`~pina.data.data_module.PinaDataModule` class. Default is
@@ -104,8 +110,17 @@ def __init__(
104110
super().__init__(**kwargs)
105111

106112
# checking compilation and automatic batching
107-
if compile is None or sys.platform == "win32":
113+
# compilation disabled for Windows and for Python 3.14+
114+
if (
115+
compile is None
116+
or sys.platform == "win32"
117+
or sys.version_info >= (3, 14)
118+
):
108119
compile = False
120+
warnings.warn(
121+
"Compilation is disabled for Python 3.14+ and for Windows.",
122+
UserWarning,
123+
)
109124

110125
repeat = repeat if repeat is not None else False
111126

@@ -325,3 +340,23 @@ def _check_consistency_and_set_defaults(
325340
if batch_size is not None:
326341
check_consistency(batch_size, int)
327342
return pin_memory, num_workers, shuffle, batch_size
343+
344+
@property
345+
def compile(self):
346+
"""
347+
Whether compilation is required or not.
348+
349+
:return: ``True`` if compilation is required, ``False`` otherwise.
350+
:rtype: bool
351+
"""
352+
return self._compile
353+
354+
@compile.setter
355+
def compile(self, value):
356+
"""
357+
Setting the value of compile.
358+
359+
:param bool value: Whether compilation is required or not.
360+
"""
361+
check_consistency(value, bool)
362+
self._compile = value

0 commit comments

Comments
 (0)