@@ -54,25 +54,26 @@ def run(n, backend, datatype, benchmark_mode):
5454 if backend == "sharpy" :
5555 import sharpy as np
5656 from sharpy import fini , init , sync
57- from sharpy .numpy import fromfunction as _fromfunction
5857
5958 device = os .getenv ("SHARPY_DEVICE" , "" )
6059 create_full = partial (np .full , device = device )
61- fromfunction = partial (_fromfunction , device = device )
60+
61+ def transpose (a ):
62+ return np .permute_dims (a , [1 , 0 ])
6263
6364 all_axes = [0 , 1 ]
6465 init (False )
6566
6667 elif backend == "numpy" :
6768 import numpy as np
68- from numpy import fromfunction
6969
7070 if comm is not None :
7171 assert (
7272 comm .Get_size () == 1
7373 ), "Numpy backend only supports serial execution."
7474
7575 create_full = np .full
76+ transpose = np .transpose
7677
7778 fini = sync = lambda x = None : None
7879 all_axes = None
@@ -110,34 +111,32 @@ def run(n, backend, datatype, benchmark_mode):
110111 t_export = 0.02
111112 t_end = 1.0
112113
113- # coordinate arrays
114- x_t_2d = fromfunction (
115- lambda i , j : xmin + i * dx + dx / 2 ,
116- (nx , ny ),
117- dtype = dtype ,
118- )
119- y_t_2d = fromfunction (
120- lambda i , j : ymin + j * dy + dy / 2 ,
121- (nx , ny ),
122- dtype = dtype ,
123- )
124- x_u_2d = fromfunction (lambda i , j : xmin + i * dx , (nx + 1 , ny ), dtype = dtype )
125- y_u_2d = fromfunction (
126- lambda i , j : ymin + j * dy + dy / 2 ,
127- (nx + 1 , ny ),
128- dtype = dtype ,
129- )
130- x_v_2d = fromfunction (
131- lambda i , j : xmin + i * dx + dx / 2 ,
132- (nx , ny + 1 ),
133- dtype = dtype ,
134- )
135- y_v_2d = fromfunction (lambda i , j : ymin + j * dy , (nx , ny + 1 ), dtype = dtype )
114+ def ind_arr (shape , columns = False ):
115+ """Construct an (nx, ny) array where each row/col is an arange"""
116+ nx , ny = shape
117+ if columns :
118+ ind = np .arange (0 , nx * ny , 1 , dtype = np .int32 ) % nx
119+ ind = transpose (np .reshape (ind , (ny , nx )))
120+ else :
121+ ind = np .arange (0 , nx * ny , 1 , dtype = np .int32 ) % ny
122+ ind = np .reshape (ind , (nx , ny ))
123+ return ind .astype (dtype )
136124
125+ # coordinate arrays
137126 T_shape = (nx , ny )
138127 U_shape = (nx + 1 , ny )
139128 V_shape = (nx , ny + 1 )
140129 F_shape = (nx + 1 , ny + 1 )
130+ sync ()
131+ x_t_2d = xmin + ind_arr (T_shape , True ) * dx + dx / 2
132+ y_t_2d = ymin + ind_arr (T_shape ) * dy + dy / 2
133+
134+ x_u_2d = xmin + ind_arr (U_shape , True ) * dx
135+ y_u_2d = ymin + ind_arr (U_shape ) * dy + dy / 2
136+
137+ x_v_2d = xmin + ind_arr (V_shape , True ) * dx + dx / 2
138+ y_v_2d = ymin + ind_arr (V_shape ) * dy
139+ sync ()
141140
142141 dofs_T = int (numpy .prod (numpy .asarray (T_shape )))
143142 dofs_U = int (numpy .prod (numpy .asarray (U_shape )))
@@ -205,14 +204,6 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
205204 bath = 1.0
206205 return bath * create_full (T_shape , 1.0 , dtype )
207206
208- # inital elevation
209- u0 , v0 , e0 = exact_solution (
210- 0 , x_t_2d , y_t_2d , x_u_2d , y_u_2d , x_v_2d , y_v_2d
211- )
212- e [:, :] = e0
213- u [:, :] = u0
214- v [:, :] = v0
215-
216207 # set bathymetry
217208 h [:, :] = bathymetry (x_t_2d , y_t_2d , lx , ly )
218209 # steady state potential energy
@@ -329,6 +320,18 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
329320 v [:, 1 :- 1 ] = v [:, 1 :- 1 ] / 3.0 + 2.0 / 3.0 * (v2 [:, 1 :- 1 ] + dt * dvdt )
330321 e [:, :] = e [:, :] / 3.0 + 2.0 / 3.0 * (e2 [:, :] + dt * dedt )
331322
323+ # warm up jit cache
324+ step (u , v , e , u1 , v1 , e1 , u2 , v2 , e2 )
325+ sync ()
326+
327+ # initial solution
328+ u0 , v0 , e0 = exact_solution (
329+ 0 , x_t_2d , y_t_2d , x_u_2d , y_u_2d , x_v_2d , y_v_2d
330+ )
331+ e [:, :] = e0
332+ u [:, :] = u0
333+ v [:, :] = v0
334+
332335 t = 0
333336 i_export = 0
334337 next_t_export = 0
0 commit comments