diff --git a/pina/equation/equation.py b/pina/equation/equation.py index 677b0e54a..057c6bcf5 100644 --- a/pina/equation/equation.py +++ b/pina/equation/equation.py @@ -1,7 +1,6 @@ """Module for the Equation.""" import inspect -import torch from .equation_interface import EquationInterface @@ -61,33 +60,3 @@ def residual(self, input_, output_, params_=None): f"Unexpected number of arguments in equation: {self.__len_sig}. " "Expected either 2 (direct problem) or 3 (inverse problem)." ) - - def to(self, device): - """ - Move all tensor attributes of the Equation to the specified device. - - :param torch.device device: The target device to move the tensors to. - :return: The Equation instance moved to the specified device. - :rtype: Equation - """ - # Iterate over all attributes of the Equation - for key, val in self.__dict__.items(): - - # Move tensors in dictionaries to the specified device - if isinstance(val, dict): - self.__dict__[key] = { - k: v.to(device) if torch.is_tensor(v) else v - for k, v in val.items() - } - - # Move tensors in lists to the specified device - elif isinstance(val, list): - self.__dict__[key] = [ - v.to(device) if torch.is_tensor(v) else v for v in val - ] - - # Move tensor attributes to the specified device - elif torch.is_tensor(val): - self.__dict__[key] = val.to(device) - - return self diff --git a/pina/equation/equation_interface.py b/pina/equation/equation_interface.py index f1cc74754..82b86dbd0 100644 --- a/pina/equation/equation_interface.py +++ b/pina/equation/equation_interface.py @@ -1,6 +1,7 @@ """Module for the Equation Interface.""" from abc import ABCMeta, abstractmethod +import torch class EquationInterface(metaclass=ABCMeta): @@ -33,3 +34,33 @@ def residual(self, input_, output_, params_): :return: The computed residual of the equation. :rtype: LabelTensor """ + + def to(self, device): + """ + Move all tensor attributes to the specified device. + + :param torch.device device: The target device to move the tensors to. + :return: The instance moved to the specified device. + :rtype: EquationInterface + """ + # Iterate over all attributes of the Equation + for key, val in self.__dict__.items(): + + # Move tensors in dictionaries to the specified device + if isinstance(val, dict): + self.__dict__[key] = { + k: v.to(device) if torch.is_tensor(v) else v + for k, v in val.items() + } + + # Move tensors in lists to the specified device + elif isinstance(val, list): + self.__dict__[key] = [ + v.to(device) if torch.is_tensor(v) else v for v in val + ] + + # Move tensor attributes to the specified device + elif torch.is_tensor(val): + self.__dict__[key] = val.to(device) + + return self diff --git a/pina/equation/system_equation.py b/pina/equation/system_equation.py index 21cb27160..3e8550d9b 100644 --- a/pina/equation/system_equation.py +++ b/pina/equation/system_equation.py @@ -101,6 +101,10 @@ def residual(self, input_, output_, params_=None): :return: The aggregated residuals of the system of equations. :rtype: LabelTensor """ + # Move the equation to the input_ device + self.to(input_.device) + + # Compute the residual for each equation residual = torch.hstack( [ equation.residual(input_, output_, params_) @@ -108,6 +112,7 @@ def residual(self, input_, output_, params_=None): ] ) + # Skip reduction if not specified if self.reduction is None: return residual