16
16
17
17
class FortranType (ABC ):
18
18
19
- def __init__ (self , name , ** kwargs ):
19
+ def __init__ (self , name , bigcount = False , ** kwargs ):
20
20
self .name = name
21
- self .bigcount = False
21
+ self .bigcount = bigcount
22
22
23
23
TYPES = {}
24
24
@@ -93,6 +93,9 @@ def declare(self):
93
93
else :
94
94
return f'INTEGER, INTENT(IN) :: { self .name } '
95
95
96
+ def use (self ):
97
+ return [('mpi_f08_types' , 'MPI_COUNT_KIND' )]
98
+
96
99
def c_parameter (self ):
97
100
type_ = 'MPI_Count' if self .bigcount else 'MPI_Fint'
98
101
return f'{ type_ } *{ self .name } '
@@ -209,20 +212,6 @@ class PrototypeParseError(Exception):
209
212
"""Thrown when a parsing error is encountered."""
210
213
211
214
212
- def fortran_f08_name (base_name ):
213
- """Produce the final f08 name from base_name."""
214
- return f'MPI_{ base_name .capitalize ()} _f08'
215
-
216
-
217
- def c_func_name (base_name ):
218
- """Produce the final C func name from base_name."""
219
- return f'ompi_{ base_name } _wrapper_f08'
220
-
221
-
222
- def c_api_func_name (base_name ):
223
- """Produce the actual MPI API function name to call into."""
224
- return f'PMPI_{ base_name .capitalize ()} '
225
-
226
215
227
216
def print_header ():
228
217
"""Print the fortran f08 file header."""
@@ -232,7 +221,8 @@ def print_header():
232
221
233
222
class FortranBinding :
234
223
235
- def __init__ (self , fname ):
224
+ def __init__ (self , fname , bigcount = False ):
225
+ self .bigcount = bigcount
236
226
with open (fname ) as fp :
237
227
data = []
238
228
for line in fp :
@@ -251,7 +241,26 @@ def __init__(self, fname):
251
241
type_ , name = param .split ()
252
242
type_ = FortranType .get (type_ )
253
243
indent = ' '
254
- self .parameters .append (type_ (name ))
244
+ self .parameters .append (type_ (name , bigcount = bigcount ))
245
+
246
+ def _fn_name_suffix (self ):
247
+ """Return a suffix for function names."""
248
+ return '_c' if self .bigcount else ''
249
+
250
+ @property
251
+ def fortran_f08_name (self ):
252
+ """Produce the final f08 name from base_name."""
253
+ return f'MPI_{ self .fn_name .capitalize ()} _f08{ self ._fn_name_suffix ()} '
254
+
255
+ @property
256
+ def c_func_name (self ):
257
+ """Produce the final C func name from base_name."""
258
+ return f'ompi_{ self .fn_name } _wrapper_f08{ self ._fn_name_suffix ()} '
259
+
260
+ @property
261
+ def c_api_func_name (self ):
262
+ """Produce the actual MPI API function name to call into."""
263
+ return f'PMPI_{ self .fn_name .capitalize ()} { self ._fn_name_suffix ()} '
255
264
256
265
def _param_list (self ):
257
266
return ',' .join (type_ .name for type_ in self .parameters )
@@ -277,7 +286,7 @@ def _use_stmts(self):
277
286
278
287
def _print_fortran_interface (self ):
279
288
"""Output the C subroutine binding for the Fortran code."""
280
- name = c_func_name ( self .fn_name )
289
+ name = self .c_func_name
281
290
print (' interface' )
282
291
print (f' subroutine { name } ({ self ._param_list ()} ,ierror) &' )
283
292
print (f' BIND(C, name="{ name } ")' )
@@ -297,8 +306,8 @@ def print_f_source(self):
297
306
298
307
print_header ()
299
308
300
- sub_name = fortran_f08_name ( self .fn_name )
301
- c_func = c_func_name ( self .fn_name )
309
+ sub_name = self .fortran_f08_name
310
+ c_func = self .c_func_name
302
311
print ('subroutine' , f'{ sub_name } ({ self ._param_list ()} ,ierror)' )
303
312
# Use statements
304
313
use_stmts = self ._use_stmts ()
@@ -335,7 +344,7 @@ def print_c_source(self):
335
344
print ('#include "ompi/mpi/fortran/mpif-h/status-conversion.h"' )
336
345
print ('#include "ompi/mpi/fortran/base/constants.h"' )
337
346
print ('#include "ompi/mpi/fortran/base/fint_2_int.h"' )
338
- c_func = c_func_name ( self .fn_name )
347
+ c_func = self .c_func_name
339
348
parameters = [param .c_parameter () for param in self .parameters ]
340
349
# Always append the integer error
341
350
parameters .append ('MPI_Fint *ierr' )
@@ -348,25 +357,25 @@ def print_c_source(self):
348
357
for param in self .parameters :
349
358
for line in param .c_prepare ():
350
359
print (f' { line } ' )
351
- c_api_func = c_api_func_name ( self .fn_name )
360
+ c_api_func = self .c_api_func_name
352
361
arguments = [param .c_argument () for param in self .parameters ]
353
362
arguments = ', ' .join (arguments )
354
363
print (f' { C_ERROR_TEMP_NAME } = { c_api_func } ({ arguments } );' )
364
+ print (f' *ierr = OMPI_INT_2_FINT({ C_ERROR_TEMP_NAME } );' )
355
365
for param in self .parameters :
356
366
for line in param .c_post ():
357
367
print (f' { line } ' )
358
- # TODO: Is this NULL check necessary for mpi_f08?
359
- print (f' if (NULL != ierr) *ierr = OMPI_INT_2_FINT({ C_ERROR_TEMP_NAME } );' )
360
368
print ('}' )
361
369
362
370
363
371
def main ():
364
372
parser = argparse .ArgumentParser (description = 'generate fortran binding files' )
365
373
parser .add_argument ('lang' , choices = ('fortran' , 'c' ), help = 'generate dependent files in C or Fortran' )
366
374
parser .add_argument ('template' , help = 'template file to use' )
375
+ parser .add_argument ('--bigcount' , action = 'store_true' , help = 'generate bigcount interface for function' )
367
376
args = parser .parse_args ()
368
377
369
- binding = FortranBinding (args .template )
378
+ binding = FortranBinding (args .template , bigcount = args . bigcount )
370
379
if args .lang == 'fortran' :
371
380
binding .print_f_source ()
372
381
else :
0 commit comments