Skip to content

Commit d4fa3ea

Browse files
move to method to the interface (#694)
1 parent fca3db7 commit d4fa3ea

File tree

3 files changed

+36
-31
lines changed

3 files changed

+36
-31
lines changed

pina/equation/equation.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Module for the Equation."""
22

33
import inspect
4-
import torch
54
from .equation_interface import EquationInterface
65

76

@@ -61,33 +60,3 @@ def residual(self, input_, output_, params_=None):
6160
f"Unexpected number of arguments in equation: {self.__len_sig}. "
6261
"Expected either 2 (direct problem) or 3 (inverse problem)."
6362
)
64-
65-
def to(self, device):
66-
"""
67-
Move all tensor attributes of the Equation to the specified device.
68-
69-
:param torch.device device: The target device to move the tensors to.
70-
:return: The Equation instance moved to the specified device.
71-
:rtype: Equation
72-
"""
73-
# Iterate over all attributes of the Equation
74-
for key, val in self.__dict__.items():
75-
76-
# Move tensors in dictionaries to the specified device
77-
if isinstance(val, dict):
78-
self.__dict__[key] = {
79-
k: v.to(device) if torch.is_tensor(v) else v
80-
for k, v in val.items()
81-
}
82-
83-
# Move tensors in lists to the specified device
84-
elif isinstance(val, list):
85-
self.__dict__[key] = [
86-
v.to(device) if torch.is_tensor(v) else v for v in val
87-
]
88-
89-
# Move tensor attributes to the specified device
90-
elif torch.is_tensor(val):
91-
self.__dict__[key] = val.to(device)
92-
93-
return self

pina/equation/equation_interface.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Module for the Equation Interface."""
22

33
from abc import ABCMeta, abstractmethod
4+
import torch
45

56

67
class EquationInterface(metaclass=ABCMeta):
@@ -33,3 +34,33 @@ def residual(self, input_, output_, params_):
3334
:return: The computed residual of the equation.
3435
:rtype: LabelTensor
3536
"""
37+
38+
def to(self, device):
39+
"""
40+
Move all tensor attributes to the specified device.
41+
42+
:param torch.device device: The target device to move the tensors to.
43+
:return: The instance moved to the specified device.
44+
:rtype: EquationInterface
45+
"""
46+
# Iterate over all attributes of the Equation
47+
for key, val in self.__dict__.items():
48+
49+
# Move tensors in dictionaries to the specified device
50+
if isinstance(val, dict):
51+
self.__dict__[key] = {
52+
k: v.to(device) if torch.is_tensor(v) else v
53+
for k, v in val.items()
54+
}
55+
56+
# Move tensors in lists to the specified device
57+
elif isinstance(val, list):
58+
self.__dict__[key] = [
59+
v.to(device) if torch.is_tensor(v) else v for v in val
60+
]
61+
62+
# Move tensor attributes to the specified device
63+
elif torch.is_tensor(val):
64+
self.__dict__[key] = val.to(device)
65+
66+
return self

pina/equation/system_equation.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,18 @@ def residual(self, input_, output_, params_=None):
101101
:return: The aggregated residuals of the system of equations.
102102
:rtype: LabelTensor
103103
"""
104+
# Move the equation to the input_ device
105+
self.to(input_.device)
106+
107+
# Compute the residual for each equation
104108
residual = torch.hstack(
105109
[
106110
equation.residual(input_, output_, params_)
107111
for equation in self.equations
108112
]
109113
)
110114

115+
# Skip reduction if not specified
111116
if self.reduction is None:
112117
return residual
113118

0 commit comments

Comments
 (0)