|
6 | 6 | """ |
7 | 7 | from abc import ABC, abstractmethod |
8 | 8 | import argparse |
| 9 | +from collections import namedtuple |
9 | 10 | import re |
10 | 11 |
|
11 | 12 | FORTRAN_ERROR_NAME = 'ierror' |
@@ -378,47 +379,55 @@ def c_api_func_name_profile(fn_name, bigcount=False): |
378 | 379 | return f'P{c_api_func_name(fn_name, bigcount)}' |
379 | 380 |
|
380 | 381 |
|
381 | | -def print_header(): |
382 | | - """Print the fortran f08 file header.""" |
383 | | - print('#include "ompi/mpi/fortran/configure-fortran-output.h"') |
384 | | - print('#include "mpi-f08-rename.h"') |
| 382 | +FortranParameter = namedtuple('FortranParameter', ['type_', 'name', 'dep_name']) |
| 383 | +FortranPrototype = namedtuple('FortranPrototype', ['fn_name', 'parameters']) |
385 | 384 |
|
386 | 385 |
|
387 | | -class FortranBinding: |
388 | | - |
389 | | - def __init__(self, fname, bigcount=False): |
390 | | - self.bigcount = bigcount |
391 | | - with open(fname) as fp: |
392 | | - data = [] |
393 | | - for line in fp: |
394 | | - data.append(line.strip()) |
395 | | - data = ' '.join(data) |
396 | | - data = data.strip() |
397 | | - if PROTOTYPE_RE.match(data) is None: |
398 | | - raise PrototypeParseError('Invalid function prototype for Fortran interface') |
399 | | - start = data.index('(') |
400 | | - end = data.index(')') |
401 | | - self.fn_name = data[:start].strip() |
402 | | - |
403 | | - parameters = data[start+1:end].split(',') |
404 | | - self.parameters = [] |
405 | | - param_map = {} |
406 | | - dep_params = {} |
| 386 | +def load_prototypes(fname): |
| 387 | + """Load the prototypes from a file.""" |
| 388 | + with open(fname) as fp: |
| 389 | + prototypes = [] |
| 390 | + for i, line in enumerate(fp): |
| 391 | + lno = i + 1 |
| 392 | + if PROTOTYPE_RE.match(line) is None: |
| 393 | + raise PrototypeParseError(f'Invalid function prototype for Fortran interface on line {lno}') |
| 394 | + start = line.index('(') |
| 395 | + end = line.index(')') |
| 396 | + fn_name = line[:start].strip() |
| 397 | + parameters = line[start+1:end].split(',') |
| 398 | + parsed_parameters = [] |
407 | 399 | for param in parameters: |
408 | 400 | param = param.strip() |
409 | 401 | type_, name = param.split() |
410 | | - type_ = FortranType.get(type_) |
| 402 | + dep_name = None |
411 | 403 | # Check for 'param:other_param' parameters, indicating a |
412 | 404 | # dependency on that other parameter (such as for a count) |
413 | 405 | if ':' in name: |
414 | 406 | name, dep_name = name.split(':') |
415 | | - dep_params[name] = dep_name |
416 | | - param = type_(name, self.fn_name, bigcount=bigcount) |
417 | | - self.parameters.append(param) |
418 | | - param_map[name] = param |
419 | | - # Set dependent parameters for those that need them |
420 | | - for name, dep_name in dep_params.items(): |
421 | | - param_map[name].dep_param = param_map[dep_name] |
| 407 | + parsed_parameters.append(FortranParameter(type_, name, dep_name)) |
| 408 | + prototypes.append(FortranPrototype(fn_name, parsed_parameters)) |
| 409 | + return prototypes |
| 410 | + |
| 411 | + |
| 412 | +class FortranBinding: |
| 413 | + """Class for generating the binding for a single function.""" |
| 414 | + |
| 415 | + def __init__(self, prototype, bigcount=False): |
| 416 | + self.bigcount = bigcount |
| 417 | + self.fn_name = prototype.fn_name |
| 418 | + self.parameters = [] |
| 419 | + param_map = {} |
| 420 | + dep_params = {} |
| 421 | + for param in prototype.parameters: |
| 422 | + type_ = FortranType.get(param.type_) |
| 423 | + param_type = type_(param.name, self.fn_name, bigcount=bigcount) |
| 424 | + self.parameters.append(param_type) |
| 425 | + param_map[param.name] = param_type |
| 426 | + if param.dep_name is not None: |
| 427 | + dep_params[param.name] = param.dep_name |
| 428 | + # Set dependent parameters for those that need them |
| 429 | + for name, dep_name in dep_params.items(): |
| 430 | + param_map[name].dep_param = param_map[dep_name] |
422 | 431 |
|
423 | 432 | def _fn_name_suffix(self): |
424 | 433 | """Return a suffix for function names.""" |
@@ -474,10 +483,6 @@ def _print_fortran_interface(self): |
474 | 483 |
|
475 | 484 | def print_f_source(self): |
476 | 485 | """Output the main MPI Fortran subroutine.""" |
477 | | - print(f'! {GENERATED_MESSAGE}') |
478 | | - |
479 | | - print_header() |
480 | | - |
481 | 486 | sub_name = self.fortran_f08_name |
482 | 487 | c_func = self.c_func_name |
483 | 488 | print('subroutine', f'{sub_name}({self._param_list()},{FORTRAN_ERROR_NAME})') |
@@ -510,14 +515,6 @@ def print_f_source(self): |
510 | 515 |
|
511 | 516 | def print_c_source(self): |
512 | 517 | """Output the C source and function that the Fortran calls into.""" |
513 | | - print(f'/* {GENERATED_MESSAGE} */') |
514 | | - print('#include "ompi_config.h"') |
515 | | - print('#include "mpi.h"') |
516 | | - print('#include "ompi/errhandler/errhandler.h"') |
517 | | - print('#include "ompi/mpi/fortran/mpif-h/status-conversion.h"') |
518 | | - print('#include "ompi/mpi/fortran/base/constants.h"') |
519 | | - print('#include "ompi/mpi/fortran/base/fint_2_int.h"') |
520 | | - print('#include "ompi/request/request.h"') |
521 | 518 | parameters = [param.c_parameter() for param in self.parameters] |
522 | 519 | # Always append the integer error |
523 | 520 | parameters.append(f'MPI_Fint *{C_ERROR_NAME}') |
@@ -565,18 +562,52 @@ def print_c_source(self): |
565 | 562 | print('}') |
566 | 563 |
|
567 | 564 |
|
| 565 | +def print_f_source_header(): |
| 566 | + """Print the fortran f08 file header.""" |
| 567 | + print(f'! {GENERATED_MESSAGE}') |
| 568 | + print('#include "ompi/mpi/fortran/configure-fortran-output.h"') |
| 569 | + print('#include "mpi-f08-rename.h"') |
| 570 | + |
| 571 | + |
| 572 | +def print_c_source_header(): |
| 573 | + """Print the header of the C source file.""" |
| 574 | + print(f'/* {GENERATED_MESSAGE} */') |
| 575 | + print('#include "ompi_config.h"') |
| 576 | + print('#include "mpi.h"') |
| 577 | + print('#include "ompi/errhandler/errhandler.h"') |
| 578 | + print('#include "ompi/mpi/fortran/mpif-h/status-conversion.h"') |
| 579 | + print('#include "ompi/mpi/fortran/base/constants.h"') |
| 580 | + print('#include "ompi/mpi/fortran/base/fint_2_int.h"') |
| 581 | + print('#include "ompi/request/request.h"') |
| 582 | + |
| 583 | + |
| 584 | +def print_binding(prototype, lang, bigcount=False): |
| 585 | + """Print the binding with or without bigcount.""" |
| 586 | + binding = FortranBinding(prototype, bigcount=bigcount) |
| 587 | + if lang == 'fortran': |
| 588 | + binding.print_f_source() |
| 589 | + else: |
| 590 | + binding.print_c_source() |
| 591 | + |
| 592 | + |
568 | 593 | def main(): |
569 | 594 | parser = argparse.ArgumentParser(description='generate fortran binding files') |
570 | | - parser.add_argument('lang', choices=('fortran', 'c'), help='generate dependent files in C or Fortran') |
| 595 | + parser.add_argument('lang', choices=('fortran', 'c'), |
| 596 | + help='generate dependent files in C or Fortran') |
571 | 597 | parser.add_argument('template', help='template file to use') |
572 | | - parser.add_argument('--bigcount', action='store_true', help='generate bigcount interface for function') |
| 598 | + parser.add_argument('--bigcount', action='store_true', |
| 599 | + help='generate bigcount interface for function') |
573 | 600 | args = parser.parse_args() |
574 | 601 |
|
575 | | - binding = FortranBinding(args.template, bigcount=args.bigcount) |
| 602 | + prototypes = load_prototypes(args.template) |
576 | 603 | if args.lang == 'fortran': |
577 | | - binding.print_f_source() |
| 604 | + print_f_source_header() |
578 | 605 | else: |
579 | | - binding.print_c_source() |
| 606 | + print_c_source_header() |
| 607 | + for prototype in prototypes: |
| 608 | + print_binding(prototype, args.lang) |
| 609 | + if any(param.type_ == 'COUNT' for param in prototype.parameters): |
| 610 | + print_binding(prototype, args.lang, bigcount=True) |
580 | 611 |
|
581 | 612 |
|
582 | 613 | if __name__ == '__main__': |
|
0 commit comments