Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class EnEquivariantNetworkBlock(MessagePassing):
DOI: `<https://doi.org/10.48550/arXiv.2102.09844>`_.
"""

def __init__(
def __init__( # pylint: disable=R0913, R0917
self,
node_feature_dim,
edge_feature_dim,
Expand Down Expand Up @@ -143,7 +143,9 @@ def __init__(
func=activation,
)

def forward(self, x, pos, edge_index, edge_attr=None, vel=None):
def forward(
self, x, pos, edge_index, edge_attr=None, vel=None
): # pylint: disable=R0917
"""
Forward pass of the block, triggering the message-passing routine.

Expand All @@ -169,7 +171,9 @@ def forward(self, x, pos, edge_index, edge_attr=None, vel=None):
edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr, vel=vel
)

def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
def message(
self, x_i, x_j, pos_i, pos_j, edge_attr
): # pylint: disable=R0917
"""
Compute the message to be passed between nodes and edges.

Expand Down Expand Up @@ -234,7 +238,9 @@ def aggregate(self, inputs, index, ptr=None, dim_size=None):

return agg_message, agg_m_ij

def update(self, aggregated_inputs, x, pos, edge_index, vel):
def update(
self, aggregated_inputs, x, pos, edge_index, vel
): # pylint: disable=R0917
"""
Update node features, positions, and optionally velocities.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class EquivariantGraphNeuralOperatorBlock(torch.nn.Module):
<https://arxiv.org/abs/2401.11037>`_
"""

def __init__(
def __init__( # pylint: disable=R0913, R0917
self,
node_feature_dim,
edge_feature_dim,
Expand Down Expand Up @@ -101,7 +101,9 @@ def __init__(
flow=flow,
)

def forward(self, x, pos, vel, edge_index, edge_attr=None):
def forward( # pylint: disable=R0917
self, x, pos, vel, edge_index, edge_attr=None
):
"""
Forward pass of the Equivariant Graph Neural Operator block.

Expand Down Expand Up @@ -182,7 +184,11 @@ def _convolution(self, x, einsum_idx, real, img):
weights = torch.complex(real[..., :modes], img[..., :modes])

# Convolution in Fourier space
fourier = torch.fft.rfftn(x, dim=[0])[:modes]
# torch.fft.rfftn and irfftn are callable functions, but pylint
# incorrectly flags them as E1102 (not callable).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would remove these comments — the pylint disable is sufficient both to suppress the error and to trace it back if needed.

fourier = torch.fft.rfftn(x, dim=[0])[:modes] # pylint: disable=E1102
out = torch.einsum(einsum_idx, fourier, weights)

return torch.fft.irfftn(out, s=x.shape[0], dim=0)
return torch.fft.irfftn( # pylint: disable=E1102
out, s=x.shape[0], dim=0
)
8 changes: 6 additions & 2 deletions pina/model/equivariant_graph_neural_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from .block.message_passing import EquivariantGraphNeuralOperatorBlock


class EquivariantGraphNeuralOperator(torch.nn.Module):
# Disable pylint warnings for too few public methods (since this is a simple
# model class in a standard PyTorch style)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would remove these comments — the pylint disable is sufficient both to suppress the error and to trace it back if needed.

class EquivariantGraphNeuralOperator(torch.nn.Module): # pylint: disable=R0903
"""
Equivariant Graph Neural Operator (EGNO) for modeling 3D dynamics.

Expand All @@ -32,7 +34,9 @@ class EquivariantGraphNeuralOperator(torch.nn.Module):
<https://arxiv.org/abs/2401.11037>`_
"""

def __init__(
# Disable pylint warnings for too many arguments in init (since this is a
# model class with many configurable parameters)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would remove these comments — the pylint disable is sufficient both to suppress the error and to trace it back if needed.

def __init__( # pylint: disable=R0913, R0917, R0914
self,
n_egno_layers,
node_feature_dim,
Expand Down