Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 0 additions & 31 deletions pina/equation/equation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Module for the Equation."""

import inspect
import torch
from .equation_interface import EquationInterface


Expand Down Expand Up @@ -61,33 +60,3 @@ def residual(self, input_, output_, params_=None):
f"Unexpected number of arguments in equation: {self.__len_sig}. "
"Expected either 2 (direct problem) or 3 (inverse problem)."
)

def to(self, device):
"""
Move all tensor attributes of the Equation to the specified device.

:param torch.device device: The target device to move the tensors to.
:return: The Equation instance moved to the specified device.
:rtype: Equation
"""
# Iterate over all attributes of the Equation
for key, val in self.__dict__.items():

# Move tensors in dictionaries to the specified device
if isinstance(val, dict):
self.__dict__[key] = {
k: v.to(device) if torch.is_tensor(v) else v
for k, v in val.items()
}

# Move tensors in lists to the specified device
elif isinstance(val, list):
self.__dict__[key] = [
v.to(device) if torch.is_tensor(v) else v for v in val
]

# Move tensor attributes to the specified device
elif torch.is_tensor(val):
self.__dict__[key] = val.to(device)

return self
31 changes: 31 additions & 0 deletions pina/equation/equation_interface.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Module for the Equation Interface."""

from abc import ABCMeta, abstractmethod
import torch


class EquationInterface(metaclass=ABCMeta):
Expand Down Expand Up @@ -33,3 +34,33 @@ def residual(self, input_, output_, params_):
:return: The computed residual of the equation.
:rtype: LabelTensor
"""

def to(self, device):
"""
Move all tensor attributes to the specified device.

:param torch.device device: The target device to move the tensors to.
:return: The instance moved to the specified device.
:rtype: EquationInterface
"""
# Iterate over all attributes of the Equation
for key, val in self.__dict__.items():

# Move tensors in dictionaries to the specified device
if isinstance(val, dict):
self.__dict__[key] = {
k: v.to(device) if torch.is_tensor(v) else v
for k, v in val.items()
}

# Move tensors in lists to the specified device
elif isinstance(val, list):
self.__dict__[key] = [
v.to(device) if torch.is_tensor(v) else v for v in val
]

# Move tensor attributes to the specified device
elif torch.is_tensor(val):
self.__dict__[key] = val.to(device)

return self
5 changes: 5 additions & 0 deletions pina/equation/system_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,18 @@ def residual(self, input_, output_, params_=None):
:return: The aggregated residuals of the system of equations.
:rtype: LabelTensor
"""
# Move the equation to the input_ device
self.to(input_.device)

# Compute the residual for each equation
residual = torch.hstack(
[
equation.residual(input_, output_, params_)
for equation in self.equations
]
)

# Skip reduction if not specified
if self.reduction is None:
return residual

Expand Down