Skip to content

Commit fca3db7

Browse files
fix bugs for helmholtz and advection (#686)
1 parent 64930c4 commit fca3db7

File tree

3 files changed

+9
-10
lines changed

3 files changed

+9
-10
lines changed

pina/equation/equation_factory.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -239,19 +239,19 @@ def equation(input_, output_):
239239
)
240240

241241
# Ensure consistency of c length
242-
if len(self.c) != (len(input_lbl) - 1) and len(self.c) > 1:
242+
if self.c.shape[-1] != len(input_lbl) - 1 and self.c.shape[-1] > 1:
243243
raise ValueError(
244244
"If 'c' is passed as a list, its length must be equal to "
245245
"the number of spatial dimensions."
246246
)
247247

248248
# Repeat c to ensure consistent shape for advection
249-
self.c = self.c.repeat(output_.shape[0], 1)
250-
if self.c.shape[1] != (len(input_lbl) - 1):
251-
self.c = self.c.repeat(1, len(input_lbl) - 1)
249+
c = self.c.repeat(output_.shape[0], 1)
250+
if c.shape[1] != (len(input_lbl) - 1):
251+
c = c.repeat(1, len(input_lbl) - 1)
252252

253253
# Add a dimension to c for the following operations
254-
self.c = self.c.unsqueeze(-1)
254+
c = c.unsqueeze(-1)
255255

256256
# Compute the time derivative and the spatial gradient
257257
time_der = grad(output_, input_, components=None, d="t")
@@ -262,7 +262,7 @@ def equation(input_, output_):
262262
tmp = tmp.transpose(-1, -2)
263263

264264
# Compute advection term
265-
adv = (tmp * self.c).sum(dim=tmp.tensor.ndim - 2)
265+
adv = (tmp * c).sum(dim=tmp.tensor.ndim - 2)
266266

267267
return time_der + adv
268268

pina/problem/zoo/helmholtz.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,10 @@ def __init__(self, alpha=3.0):
4848
:type alpha: float | int
4949
"""
5050
super().__init__()
51-
52-
self.alpha = alpha
5351
check_consistency(alpha, (int, float))
52+
self.alpha = alpha
5453

55-
def forcing_term(self, input_):
54+
def forcing_term(input_):
5655
"""
5756
Implementation of the forcing term.
5857
"""

tests/test_equation/test_equation_factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def test_advection_equation(c):
104104

105105
# Should fail if c is a list and its length != spatial dimension
106106
with pytest.raises(ValueError):
107-
Advection([1, 2, 3])
107+
equation = Advection([1, 2, 3])
108108
residual = equation.residual(pts, u)
109109

110110

0 commit comments

Comments
 (0)