|
6 | 6 | """
|
7 | 7 | from abc import ABC, abstractmethod
|
8 | 8 | import argparse
|
| 9 | +from collections import namedtuple |
9 | 10 | import re
|
10 | 11 |
|
11 | 12 | FORTRAN_ERROR_NAME = 'ierror'
|
@@ -378,47 +379,55 @@ def c_api_func_name_profile(fn_name, bigcount=False):
|
378 | 379 | return f'P{c_api_func_name(fn_name, bigcount)}'
|
379 | 380 |
|
380 | 381 |
|
381 |
| -def print_header(): |
382 |
| - """Print the fortran f08 file header.""" |
383 |
| - print('#include "ompi/mpi/fortran/configure-fortran-output.h"') |
384 |
| - print('#include "mpi-f08-rename.h"') |
| 382 | +FortranParameter = namedtuple('FortranParameter', ['type_', 'name', 'dep_name']) |
| 383 | +FortranPrototype = namedtuple('FortranPrototype', ['fn_name', 'parameters']) |
385 | 384 |
|
386 | 385 |
|
387 |
| -class FortranBinding: |
388 |
| - |
389 |
| - def __init__(self, fname, bigcount=False): |
390 |
| - self.bigcount = bigcount |
391 |
| - with open(fname) as fp: |
392 |
| - data = [] |
393 |
| - for line in fp: |
394 |
| - data.append(line.strip()) |
395 |
| - data = ' '.join(data) |
396 |
| - data = data.strip() |
397 |
| - if PROTOTYPE_RE.match(data) is None: |
398 |
| - raise PrototypeParseError('Invalid function prototype for Fortran interface') |
399 |
| - start = data.index('(') |
400 |
| - end = data.index(')') |
401 |
| - self.fn_name = data[:start].strip() |
402 |
| - |
403 |
| - parameters = data[start+1:end].split(',') |
404 |
| - self.parameters = [] |
405 |
| - param_map = {} |
406 |
| - dep_params = {} |
| 386 | +def load_prototypes(fname): |
| 387 | + """Load the prototypes from a file.""" |
| 388 | + with open(fname) as fp: |
| 389 | + prototypes = [] |
| 390 | + for i, line in enumerate(fp): |
| 391 | + lno = i + 1 |
| 392 | + if PROTOTYPE_RE.match(line) is None: |
| 393 | + raise PrototypeParseError(f'Invalid function prototype for Fortran interface on line {lno}') |
| 394 | + start = line.index('(') |
| 395 | + end = line.index(')') |
| 396 | + fn_name = line[:start].strip() |
| 397 | + parameters = line[start+1:end].split(',') |
| 398 | + parsed_parameters = [] |
407 | 399 | for param in parameters:
|
408 | 400 | param = param.strip()
|
409 | 401 | type_, name = param.split()
|
410 |
| - type_ = FortranType.get(type_) |
| 402 | + dep_name = None |
411 | 403 | # Check for 'param:other_param' parameters, indicating a
|
412 | 404 | # dependency on that other parameter (such as for a count)
|
413 | 405 | if ':' in name:
|
414 | 406 | name, dep_name = name.split(':')
|
415 |
| - dep_params[name] = dep_name |
416 |
| - param = type_(name, self.fn_name, bigcount=bigcount) |
417 |
| - self.parameters.append(param) |
418 |
| - param_map[name] = param |
419 |
| - # Set dependent parameters for those that need them |
420 |
| - for name, dep_name in dep_params.items(): |
421 |
| - param_map[name].dep_param = param_map[dep_name] |
| 407 | + parsed_parameters.append(FortranParameter(type_, name, dep_name)) |
| 408 | + prototypes.append(FortranPrototype(fn_name, parsed_parameters)) |
| 409 | + return prototypes |
| 410 | + |
| 411 | + |
| 412 | +class FortranBinding: |
| 413 | + """Class for generating the binding for a single function.""" |
| 414 | + |
| 415 | + def __init__(self, prototype, bigcount=False): |
| 416 | + self.bigcount = bigcount |
| 417 | + self.fn_name = prototype.fn_name |
| 418 | + self.parameters = [] |
| 419 | + param_map = {} |
| 420 | + dep_params = {} |
| 421 | + for param in prototype.parameters: |
| 422 | + type_ = FortranType.get(param.type_) |
| 423 | + param_type = type_(param.name, self.fn_name, bigcount=bigcount) |
| 424 | + self.parameters.append(param_type) |
| 425 | + param_map[param.name] = param_type |
| 426 | + if param.dep_name is not None: |
| 427 | + dep_params[param.name] = param.dep_name |
| 428 | + # Set dependent parameters for those that need them |
| 429 | + for name, dep_name in dep_params.items(): |
| 430 | + param_map[name].dep_param = param_map[dep_name] |
422 | 431 |
|
423 | 432 | def _fn_name_suffix(self):
|
424 | 433 | """Return a suffix for function names."""
|
@@ -474,10 +483,6 @@ def _print_fortran_interface(self):
|
474 | 483 |
|
475 | 484 | def print_f_source(self):
|
476 | 485 | """Output the main MPI Fortran subroutine."""
|
477 |
| - print(f'! {GENERATED_MESSAGE}') |
478 |
| - |
479 |
| - print_header() |
480 |
| - |
481 | 486 | sub_name = self.fortran_f08_name
|
482 | 487 | c_func = self.c_func_name
|
483 | 488 | print('subroutine', f'{sub_name}({self._param_list()},{FORTRAN_ERROR_NAME})')
|
@@ -510,14 +515,6 @@ def print_f_source(self):
|
510 | 515 |
|
511 | 516 | def print_c_source(self):
|
512 | 517 | """Output the C source and function that the Fortran calls into."""
|
513 |
| - print(f'/* {GENERATED_MESSAGE} */') |
514 |
| - print('#include "ompi_config.h"') |
515 |
| - print('#include "mpi.h"') |
516 |
| - print('#include "ompi/errhandler/errhandler.h"') |
517 |
| - print('#include "ompi/mpi/fortran/mpif-h/status-conversion.h"') |
518 |
| - print('#include "ompi/mpi/fortran/base/constants.h"') |
519 |
| - print('#include "ompi/mpi/fortran/base/fint_2_int.h"') |
520 |
| - print('#include "ompi/request/request.h"') |
521 | 518 | parameters = [param.c_parameter() for param in self.parameters]
|
522 | 519 | # Always append the integer error
|
523 | 520 | parameters.append(f'MPI_Fint *{C_ERROR_NAME}')
|
@@ -565,18 +562,52 @@ def print_c_source(self):
|
565 | 562 | print('}')
|
566 | 563 |
|
567 | 564 |
|
| 565 | +def print_f_source_header(): |
| 566 | + """Print the fortran f08 file header.""" |
| 567 | + print(f'! {GENERATED_MESSAGE}') |
| 568 | + print('#include "ompi/mpi/fortran/configure-fortran-output.h"') |
| 569 | + print('#include "mpi-f08-rename.h"') |
| 570 | + |
| 571 | + |
| 572 | +def print_c_source_header(): |
| 573 | + """Print the header of the C source file.""" |
| 574 | + print(f'/* {GENERATED_MESSAGE} */') |
| 575 | + print('#include "ompi_config.h"') |
| 576 | + print('#include "mpi.h"') |
| 577 | + print('#include "ompi/errhandler/errhandler.h"') |
| 578 | + print('#include "ompi/mpi/fortran/mpif-h/status-conversion.h"') |
| 579 | + print('#include "ompi/mpi/fortran/base/constants.h"') |
| 580 | + print('#include "ompi/mpi/fortran/base/fint_2_int.h"') |
| 581 | + print('#include "ompi/request/request.h"') |
| 582 | + |
| 583 | + |
| 584 | +def print_binding(prototype, lang, bigcount=False): |
| 585 | + """Print the binding with or without bigcount.""" |
| 586 | + binding = FortranBinding(prototype, bigcount=bigcount) |
| 587 | + if lang == 'fortran': |
| 588 | + binding.print_f_source() |
| 589 | + else: |
| 590 | + binding.print_c_source() |
| 591 | + |
| 592 | + |
568 | 593 | def main():
|
569 | 594 | parser = argparse.ArgumentParser(description='generate fortran binding files')
|
570 |
| - parser.add_argument('lang', choices=('fortran', 'c'), help='generate dependent files in C or Fortran') |
| 595 | + parser.add_argument('lang', choices=('fortran', 'c'), |
| 596 | + help='generate dependent files in C or Fortran') |
571 | 597 | parser.add_argument('template', help='template file to use')
|
572 |
| - parser.add_argument('--bigcount', action='store_true', help='generate bigcount interface for function') |
| 598 | + parser.add_argument('--bigcount', action='store_true', |
| 599 | + help='generate bigcount interface for function') |
573 | 600 | args = parser.parse_args()
|
574 | 601 |
|
575 |
| - binding = FortranBinding(args.template, bigcount=args.bigcount) |
| 602 | + prototypes = load_prototypes(args.template) |
576 | 603 | if args.lang == 'fortran':
|
577 |
| - binding.print_f_source() |
| 604 | + print_f_source_header() |
578 | 605 | else:
|
579 |
| - binding.print_c_source() |
| 606 | + print_c_source_header() |
| 607 | + for prototype in prototypes: |
| 608 | + print_binding(prototype, args.lang) |
| 609 | + if any(param.type_ == 'COUNT' for param in prototype.parameters): |
| 610 | + print_binding(prototype, args.lang, bigcount=True) |
580 | 611 |
|
581 | 612 |
|
582 | 613 | if __name__ == '__main__':
|
|
0 commit comments