diff --git a/doc/specs/index.md b/doc/specs/index.md index 1378fa8b8..b61b16042 100644 --- a/doc/specs/index.md +++ b/doc/specs/index.md @@ -22,6 +22,7 @@ This is an index/directory of the specifications (specs) for each new module/fea - [io](./stdlib_io.html) - Input/output helper & convenience - [kinds](./stdlib_kinds.html) - Kind parameters - [linalg](./stdlib_linalg.html) - Linear Algebra + - [linalg_state_type](./stdlib_linalg_state_type.html) - Linear Algebra state and error handling - [logger](./stdlib_logger.html) - Runtime logging system - [math](./stdlib_math.html) - General purpose mathematical functions - [optval](./stdlib_optval.html) - Fallback value for optional arguments diff --git a/doc/specs/stdlib_linalg_state_type.md b/doc/specs/stdlib_linalg_state_type.md new file mode 100644 index 000000000..54070fe4b --- /dev/null +++ b/doc/specs/stdlib_linalg_state_type.md @@ -0,0 +1,64 @@ +--- +title: linalg_state_type +--- + +# Linear Algebra -- State and Error Handling Module + +[TOC] + +## Introduction + +The `stdlib_linalg_state` module provides a derived type holding information on the +state of linear algebra operations, and procedures for expert control of linear algebra workflows. +All linear algebra procedures are engineered to support returning an optional `linalg_state_type` +variable to holds such information, as a form of expert API. If the user does not require state +information, but fatal errors are encountered during the execution of linear algebra routines, the +program will undergo a hard stop. +Instead, if the state argument is present, the program will never stop, but will return detailed error +information into the state handler. + +## Derived types provided + + +### The `linalg_state_type` derived type + +The `linalg_state_type` is defined as a derived type containing an integer error flag, and +fixed-size character strings to store an error message and the location of the error state change. +Fixed-size string storage was chosen to facilitate the compiler's memory allocation and ultimately +ensure maximum computational performance. + +A similarly named generic interface, `linalg_state_type`, is provided to allow the developer to +create diagnostic messages and raise error flags easily. The call starts with an error flag or +the location of the event, and is followed by an arbitrary list of `integer`, `real`, `complex` or +`character` variables. Numeric variables may be provided as either scalars or rank-1 (array) inputs. + +#### Type-bound procedures + +The following convenience type-bound procedures are provided: +- `print()` returns an allocatable character string containing state location, message, and error flag; +- `print_message()` returns an allocatable character string containing the state message; +- `ok()` returns a `logical` flag that is `.true.` in case of successful state (`flag==LINALG_SUCCESS`); +- `error()` returns a `logical` flag that is `.true.` in case of error state (`flag/=LINALG_SUCCESS`). + +#### Status + +Experimental + +#### Example + +```fortran +{!example/linalg/example_state1.f90!} +``` + +## Error flags provided + +The module provides the following state flags: +- `LINALG_SUCCESS`: Successful execution +- `LINALG_VALUE_ERROR`: Numerical errors (such as infinity, not-a-number, range bounds) are encountered. +- `LINALG_ERROR`: Linear Algebra errors are encountered, such as: non-converging iterations, impossible operations, etc. +- `LINALG_INTERNAL_ERROR`: Provided as a developer safeguard for internal errors that should never occur. + +## Comparison operators provided + +The module provides overloaded comparison operators for all comparisons of a `linalg_state_type` variable +with an integer error flag: `<`, `<=`, `==`, `>=`, `>`, `/=`. diff --git a/example/linalg/CMakeLists.txt b/example/linalg/CMakeLists.txt index 3f31a5574..1a5875502 100644 --- a/example/linalg/CMakeLists.txt +++ b/example/linalg/CMakeLists.txt @@ -14,3 +14,7 @@ ADD_EXAMPLE(is_symmetric) ADD_EXAMPLE(is_triangular) ADD_EXAMPLE(outer_product) ADD_EXAMPLE(trace) +ADD_EXAMPLE(state1) +ADD_EXAMPLE(state2) +ADD_EXAMPLE(blas_gemv) +ADD_EXAMPLE(lapack_getrf) diff --git a/example/linalg/example_state1.f90 b/example/linalg/example_state1.f90 new file mode 100644 index 000000000..d373a318d --- /dev/null +++ b/example/linalg/example_state1.f90 @@ -0,0 +1,20 @@ +program example_state1 + use stdlib_linalg_state, only: linalg_state_type, LINALG_SUCCESS, LINALG_VALUE_ERROR, & + operator(/=) + implicit none + type(linalg_state_type) :: err + + ! To create a state variable, we enter its integer state flag, followed by a list of variables + ! that will be automatically assembled into a formatted error message. No need to provide string formats + err = linalg_state_type(LINALG_VALUE_ERROR,'just an example with scalar ',& + 'integer=',1,'real=',2.0,'complex=',(3.0,1.0),'and array ',[1,2,3],'inputs') + + ! Print flag + print *, err%print() + + ! Check success + print *, 'Check error: ',err%error() + print *, 'Check flag : ',err /= LINALG_SUCCESS + + +end program example_state1 diff --git a/example/linalg/example_state2.f90 b/example/linalg/example_state2.f90 new file mode 100644 index 000000000..2b8b26d33 --- /dev/null +++ b/example/linalg/example_state2.f90 @@ -0,0 +1,64 @@ +program example_state2 + !! This example shows how to set a `type(linalg_state_type)` variable to process output conditions + !! out of a simple division routine. The example is meant to highlight: + !! 1) the different mechanisms that can be used to initialize the `linalg_state` variable providing + !! strings, scalars, or arrays, on input to it; + !! 2) `pure` setup of the error control + use stdlib_linalg_state, only: linalg_state_type, LINALG_VALUE_ERROR, LINALG_SUCCESS, & + linalg_error_handling + implicit none + integer :: info + type(linalg_state_type) :: err + real :: a_div_b + + ! OK + call very_simple_division(0.0,2.0,a_div_b,err) + print *, err%print() + + ! Division by zero + call very_simple_division(1.0,0.0,a_div_b,err) + print *, err%print() + + ! Out of bounds + call very_simple_division(huge(0.0),0.001,a_div_b,err) + print *, err%print() + + contains + + !> Simple division returning an integer flag (LAPACK style) + elemental subroutine very_simple_division(a,b,a_div_b,err) + real, intent(in) :: a,b + real, intent(out) :: a_div_b + type(linalg_state_type), optional, intent(out) :: err + + type(linalg_state_type) :: err0 + real, parameter :: MAXABS = huge(0.0) + character(*), parameter :: this = 'simple division' + + !> Check a + if (b==0.0) then + ! Division by zero + err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'Division by zero trying ',a,'/',b) + elseif (.not.abs(b) Location of the state change + character(len=NAME_LENGTH) :: where_at = repeat(' ',NAME_LENGTH) + + contains + + !> Cleanup + procedure :: destroy => state_destroy + + !> Print error message + procedure :: print => state_print + procedure :: print_msg => state_message + + !> State properties + procedure :: ok => state_is_ok + procedure :: error => state_is_error + + end type linalg_state_type + + !> Comparison operators + interface operator(==) + module procedure state_eq_flag + module procedure flag_eq_state + end interface + interface operator(/=) + module procedure state_neq_flag + module procedure flag_neq_state + end interface + interface operator(<) + module procedure state_lt_flag + module procedure flag_lt_state + end interface + interface operator(<=) + module procedure state_le_flag + module procedure flag_le_state + end interface + interface operator(>) + module procedure state_gt_flag + module procedure flag_gt_state + end interface + interface operator(>=) + module procedure state_ge_flag + module procedure flag_ge_state + end interface + + interface linalg_state_type + module procedure new_state + module procedure new_state_nowhere + end interface linalg_state_type + + contains + + !> Interface to print linalg state flags + pure function linalg_message(flag) result(msg) + integer(ilp),intent(in) :: flag + character(len=:),allocatable :: msg + + select case (flag) + case (LINALG_SUCCESS); msg = 'Success!' + case (LINALG_VALUE_ERROR); msg = 'Value Error' + case (LINALG_ERROR); msg = 'Algebra Error' + case (LINALG_INTERNAL_ERROR); msg = 'Internal Error' + case default; msg = 'ERROR/INVALID FLAG' + end select + + end function linalg_message + + !> Flow control: on output flag present, return it; otherwise, halt on error + pure subroutine linalg_error_handling(ierr,ierr_out) + type(linalg_state_type),intent(in) :: ierr + type(linalg_state_type),optional,intent(out) :: ierr_out + + character(len=:),allocatable :: err_msg + + if (present(ierr_out)) then + ! Return error flag + ierr_out = ierr + elseif (ierr%error()) then + err_msg = ierr%print() + error stop err_msg + end if + + end subroutine linalg_error_handling + + !> Formatted message + pure function state_message(this) result(msg) + class(linalg_state_type),intent(in) :: this + character(len=:),allocatable :: msg + + if (this%state == LINALG_SUCCESS) then + msg = 'Success!' + else + msg = linalg_message(this%state)//': '//trim(this%message) + end if + + end function state_message + + !> Produce a nice error string + pure function state_print(this) result(msg) + class(linalg_state_type),intent(in) :: this + character(len=:),allocatable :: msg + + if (len_trim(this%where_at) > 0) then + msg = '['//trim(this%where_at)//'] returned '//state_message(this) + elseif (this%error()) then + msg = 'Error encountered: '//state_message(this) + else + msg = state_message(this) + end if + + end function state_print + + !> Cleanup the object + elemental subroutine state_destroy(this) + class(linalg_state_type),intent(inout) :: this + + this%state = LINALG_SUCCESS + this%message = repeat(' ',len(this%message)) + this%where_at = repeat(' ',len(this%where_at)) + + end subroutine state_destroy + + !> Check if the current state is successful + elemental logical(lk) function state_is_ok(this) + class(linalg_state_type),intent(in) :: this + state_is_ok = this%state == LINALG_SUCCESS + end function state_is_ok + + !> Check if the current state is an error state + elemental logical(lk) function state_is_error(this) + class(linalg_state_type),intent(in) :: this + state_is_error = this%state /= LINALG_SUCCESS + end function state_is_error + + !> Compare an error state with an integer flag + elemental logical(lk) function state_eq_flag(err,flag) + type(linalg_state_type),intent(in) :: err + integer,intent(in) :: flag + state_eq_flag = err%state == flag + end function state_eq_flag + + !> Compare an integer flag with the error state + elemental logical(lk) function flag_eq_state(flag,err) + integer,intent(in) :: flag + type(linalg_state_type),intent(in) :: err + flag_eq_state = err%state == flag + end function flag_eq_state + + !> Compare the error state with an integer flag + elemental logical(lk) function state_neq_flag(err,flag) + type(linalg_state_type),intent(in) :: err + integer,intent(in) :: flag + state_neq_flag = .not. state_eq_flag(err,flag) + end function state_neq_flag + + !> Compare an integer flag with the error state + elemental logical(lk) function flag_neq_state(flag,err) + integer,intent(in) :: flag + type(linalg_state_type),intent(in) :: err + flag_neq_state = .not. state_eq_flag(err,flag) + end function flag_neq_state + + !> Compare the error state with an integer flag + elemental logical(lk) function state_lt_flag(err,flag) + type(linalg_state_type),intent(in) :: err + integer,intent(in) :: flag + state_lt_flag = err%state < flag + end function state_lt_flag + + !> Compare the error state with an integer flag + elemental logical(lk) function state_le_flag(err,flag) + type(linalg_state_type),intent(in) :: err + integer,intent(in) :: flag + state_le_flag = err%state <= flag + end function state_le_flag + + !> Compare an integer flag with the error state + elemental logical(lk) function flag_lt_state(flag,err) + integer,intent(in) :: flag + type(linalg_state_type),intent(in) :: err + flag_lt_state = err%state < flag + end function flag_lt_state + + !> Compare an integer flag with the error state + elemental logical(lk) function flag_le_state(flag,err) + integer,intent(in) :: flag + type(linalg_state_type),intent(in) :: err + flag_le_state = err%state <= flag + end function flag_le_state + + !> Compare the error state with an integer flag + elemental logical(lk) function state_gt_flag(err,flag) + type(linalg_state_type),intent(in) :: err + integer,intent(in) :: flag + state_gt_flag = err%state > flag + end function state_gt_flag + + !> Compare the error state with an integer flag + elemental logical(lk) function state_ge_flag(err,flag) + type(linalg_state_type),intent(in) :: err + integer,intent(in) :: flag + state_ge_flag = err%state >= flag + end function state_ge_flag + + !> Compare an integer flag with the error state + elemental logical(lk) function flag_gt_state(flag,err) + integer,intent(in) :: flag + type(linalg_state_type),intent(in) :: err + flag_gt_state = err%state > flag + end function flag_gt_state + + !> Compare an integer flag with the error state + elemental logical(lk) function flag_ge_state(flag,err) + integer,intent(in) :: flag + type(linalg_state_type),intent(in) :: err + flag_ge_state = err%state >= flag + end function flag_ge_state + + !> Error creation message, with location location + pure type(linalg_state_type) function new_state(where_at,flag,a1,a2,a3,a4,a5,a6,a7,a8,a9,a10, & + a11,a12,a13,a14,a15,a16,a17,a18,a19,a20) + + !> Location + character(len=*),intent(in) :: where_at + + !> Input error flag + integer,intent(in) :: flag + + !> Optional rank-agnostic arguments + class(*),optional,intent(in),dimension(..) :: a1,a2,a3,a4,a5,a6,a7,a8,a9,a10, & + a11,a12,a13,a14,a15,a16,a17,a18,a19,a20 + + !> Create state with no message + new_state = new_state_nowhere(flag,a1,a2,a3,a4,a5,a6,a7,a8,a9,a10, & + a11,a12,a13,a14,a15,a16,a17,a18,a19,a20) + + !> Add location + if (len_trim(where_at) > 0) new_state%where_at = adjustl(where_at) + + end function new_state + + !> Error creation message, from N input variables (numeric or strings) + pure type(linalg_state_type) function new_state_nowhere(flag,a1,a2,a3,a4,a5,a6,a7,a8,a9,a10, & + a11,a12,a13,a14,a15,a16,a17,a18,a19,a20) & + result(new_state) + + !> Input error flag + integer,intent(in) :: flag + + !> Optional rank-agnostic arguments + class(*),optional,intent(in),dimension(..) :: a1,a2,a3,a4,a5,a6,a7,a8,a9,a10, & + a11,a12,a13,a14,a15,a16,a17,a18,a19,a20 + + ! Init object + call new_state%destroy() + + !> Set error flag + new_state%state = flag + + !> Set chain + new_state%message = "" + call appendr(new_state%message,a1) + call appendr(new_state%message,a2) + call appendr(new_state%message,a3) + call appendr(new_state%message,a4) + call appendr(new_state%message,a5) + call appendr(new_state%message,a6) + call appendr(new_state%message,a7) + call appendr(new_state%message,a8) + call appendr(new_state%message,a9) + call appendr(new_state%message,a10) + call appendr(new_state%message,a11) + call appendr(new_state%message,a12) + call appendr(new_state%message,a13) + call appendr(new_state%message,a14) + call appendr(new_state%message,a15) + call appendr(new_state%message,a16) + call appendr(new_state%message,a17) + call appendr(new_state%message,a18) + call appendr(new_state%message,a19) + call appendr(new_state%message,a20) + + end function new_state_nowhere + + !> Append a generic value to the error flag (rank-agnostic) + pure subroutine appendr(msg,a,prefix) + class(*),optional,intent(in) :: a(..) + character(len=*),intent(inout) :: msg + character,optional,intent(in) :: prefix + + if (present(a)) then + select rank (v=>a) + rank (0) + call append (msg,v,prefix) + rank (1) + call appendv(msg,v) + rank default + msg = trim(msg)//' ' + + end select + endif + + end subroutine appendr + + ! Append a generic value to the error flag + pure subroutine append(msg,a,prefix) + class(*),intent(in) :: a + character(len=*),intent(inout) :: msg + character,optional,intent(in) :: prefix + + character(len=MSG_LENGTH) :: buffer,buffer2 + character(len=2) :: sep + integer :: ls + + ! Do not add separator if this is the first instance + sep = ' ' + ls = merge(1,0,len_trim(msg) > 0) + + if (present(prefix)) then + ls = ls + 1 + sep(ls:ls) = prefix + end if + + select type (aa => a) + + !> String type + type is (character(len=*)) + msg = trim(msg)//sep(:ls)//aa + + !> Numeric types +#:for k1, t1 in KINDS_TYPES + type is (${t1}$) + #:if 'complex' in t1 + write (buffer, FMT_REAL_${k1}$) aa%re + write (buffer2,FMT_REAL_${k1}$) aa%im + msg = trim(msg)//sep(:ls)//'('//trim(adjustl(buffer))//','//trim(adjustl(buffer2))//')' + #:else + #:if 'real' in t1 + write (buffer,FMT_REAL_${k1}$) aa + #:else + write (buffer,'(i0)') aa + #:endif + msg = trim(msg)//sep(:ls)//trim(adjustl(buffer)) + #:endif + +#:endfor + class default + msg = trim(msg)//' ' + + end select + + end subroutine append + + !> Append a generic vector to the error flag + pure subroutine appendv(msg,a) + class(*),intent(in) :: a(:) + character(len=*),intent(inout) :: msg + + integer :: j,ls + character(len=MSG_LENGTH) :: buffer,buffer2,buffer_format + character(len=2) :: sep + + if (size(a) <= 0) return + + ! Default: separate elements with one space + sep = ' ' + ls = 1 + + ! Open bracket + msg = trim(msg)//' [' + + ! Do not call append(msg(aa(j))), it will crash gfortran + select type (aa => a) + + !> Strings (cannot use string_type due to `sequence`) + type is (character(len=*)) + msg = trim(msg)//adjustl(aa(1)) + do j = 2,size(a) + msg = trim(msg)//sep(:ls)//adjustl(aa(j)) + end do + + !> Numeric types +#:for k1, t1 in KINDS_TYPES + type is (${t1}$) + #:if 'complex' in t1 + write (buffer,FMT_REAL_${k1}$) aa(1)%re + write (buffer2,FMT_REAL_${k1}$) aa(1)%im + msg = trim(msg)//'('//trim(adjustl(buffer))//','//trim(adjustl(buffer2))//')' + do j = 2,size(a) + write (buffer,FMT_REAL_${k1}$) aa(j)%re + write (buffer2,FMT_REAL_${k1}$) aa(j)%im + msg = trim(msg)//sep(:ls)//'('//trim(adjustl(buffer))//','//trim(adjustl(buffer2))//')' + end do + #:else + #:if 'real' in t1 + buffer_format = FMT_REAL_${k1}$ + #:else + buffer_format = '(i0)' + #:endif + + write (buffer,buffer_format) aa(1) + msg = trim(msg)//adjustl(buffer) + do j = 2,size(a) + write (buffer,buffer_format) aa(j) + msg = trim(msg)//sep(:ls)//adjustl(buffer) + end do + msg = trim(msg)//sep(:ls)//trim(adjustl(buffer)) + #:endif +#:endfor + class default + msg = trim(msg)//' ' + + end select + + ! Close bracket + msg = trim(msg)//']' + + end subroutine appendv + +end module stdlib_linalg_state diff --git a/test/linalg/test_linalg.fypp b/test/linalg/test_linalg.fypp index 6fdf7f17d..fcb9c6abb 100644 --- a/test/linalg/test_linalg.fypp +++ b/test/linalg/test_linalg.fypp @@ -5,6 +5,7 @@ module test_linalg use testdrive, only : new_unittest, unittest_type, error_type, check, skip_test use stdlib_kinds, only: sp, dp, xdp, qp, int8, int16, int32, int64 use stdlib_linalg, only: diag, eye, trace, outer_product, cross_product, kronecker_product + use stdlib_linalg_state, only: linalg_state_type, LINALG_SUCCESS, linalg_error_handling implicit none @@ -49,9 +50,9 @@ contains new_unittest("trace_int16", test_trace_int16), & new_unittest("trace_int32", test_trace_int32), & new_unittest("trace_int64", test_trace_int64), & - #:for k1, t1 in RCI_KINDS_TYPES + #:for k1, t1 in RCI_KINDS_TYPES new_unittest("kronecker_product_${t1[0]}$${k1}$", test_kronecker_product_${t1[0]}$${k1}$), & - #:endfor + #:endfor new_unittest("outer_product_rsp", test_outer_product_rsp), & new_unittest("outer_product_rdp", test_outer_product_rdp), & new_unittest("outer_product_rqp", test_outer_product_rqp), & @@ -71,7 +72,8 @@ contains new_unittest("cross_product_int8", test_cross_product_int8), & new_unittest("cross_product_int16", test_cross_product_int16), & new_unittest("cross_product_int32", test_cross_product_int32), & - new_unittest("cross_product_int64", test_cross_product_int64) & + new_unittest("cross_product_int64", test_cross_product_int64), & + new_unittest("state_handling", test_state_handling) & ] end subroutine collect_linalg @@ -560,7 +562,7 @@ contains #:for k1, t1 in RCI_KINDS_TYPES - subroutine test_kronecker_product_${t1[0]}$${k1}$(error) + subroutine test_kronecker_product_${t1[0]}$${k1}$(error) !> Error handling type(error_type), allocatable, intent(out) :: error integer, parameter :: m1 = 1, n1 = 2, m2 = 2, n2 = 3 @@ -593,7 +595,7 @@ contains ! Expected: C = [1*B, 2*B] = [[1,2,3, 2,4,6], [2,4,6, 4, 8, 12]] end subroutine test_kronecker_product_${t1[0]}$${k1}$ - #:endfor + #:endfor subroutine test_outer_product_rsp(error) !> Error handling @@ -911,6 +913,73 @@ contains #:endif end subroutine test_cross_product_cqp + subroutine test_state_handling(error) + !> Error handling + type(error_type), allocatable, intent(out) :: error + + type(linalg_state_type) :: state,state_out + + state = linalg_state_type(LINALG_SUCCESS,' 32-bit real: ',1.0_sp) + call check(error, & + state%message==' 32-bit real: 1.00000000E+00', & + "malformed state message with 32-bit reals.") + if (allocated(error)) return + + state = linalg_state_type(LINALG_SUCCESS,' 64-bit real: ',1.0_dp) + call check(error, & + state%message==' 64-bit real: 1.0000000000000000E+000', & + "malformed state message with 64-bit reals.") + if (allocated(error)) return + +#:if WITH_QP + state = linalg_state_type(LINALG_SUCCESS,' 128-bit real: ',1.0_qp) + call check(error, & + state%message==' 128-bit real: 1.00000000000000000000000000000000000E+0000', & + "malformed state message with 128-bit reals.") + if (allocated(error)) return +#:endif + + state = linalg_state_type(LINALG_SUCCESS,' 32-bit complex: ',(1.0_sp,1.0_sp)) + call check(error, & + state%message==' 32-bit complex: (1.00000000E+00,1.00000000E+00)', & + "malformed state message with 32-bit complex: "//trim(state%message)) + if (allocated(error)) return + + state = linalg_state_type(LINALG_SUCCESS,' 64-bit complex: ',(1.0_dp,1.0_dp)) + call check(error, & + state%message==' 64-bit complex: (1.0000000000000000E+000,1.0000000000000000E+000)', & + "malformed state message with 64-bit complex.") + if (allocated(error)) return + +#:if WITH_QP + state = linalg_state_type(LINALG_SUCCESS,'128-bit complex: ',(1.0_qp,1.0_qp)) + call check(error, state%message== & + '128-bit complex: (1.00000000000000000000000000000000000E+0000,1.00000000000000000000000000000000000E+0000)', & + "malformed state message with 128-bit complex.") + +#:endif + + state = linalg_state_type(LINALG_SUCCESS,' 32-bit array: ',[(1.0_sp,0.0_sp),(0.0_sp,1.0_sp)]) + call check(error, state%message== & + ' 32-bit array: [(1.00000000E+00,0.00000000E+00) (0.00000000E+00,1.00000000E+00)]', & + "malformed state message with 32-bit real array.") + if (allocated(error)) return + + !> State flag with location + state = linalg_state_type('test_formats',LINALG_SUCCESS,' 32-bit real: ',1.0_sp) + call check(error, & + state%print()=='[test_formats] returned Success!', & + "malformed state message with 32-bit real and location.") + if (allocated(error)) return + + !> Test error handling procedure + call linalg_error_handling(state,state_out) + call check(error, state%print()==state_out%print(), & + "malformed state message on return from error handling procedure.") + + end subroutine test_state_handling + + pure recursive function catalan_number(n) result(value) integer, intent(in) :: n integer :: value