diff --git a/doc/specs/stdlib_linalg.md b/doc/specs/stdlib_linalg.md index cab16279c..bbff70a68 100644 --- a/doc/specs/stdlib_linalg.md +++ b/doc/specs/stdlib_linalg.md @@ -206,3 +206,78 @@ program demo_outer_product !A = reshape([3., 6., 9., 4., 8., 12.], [3,2]) end program demo_outer_product ``` + +## `zeros/ones` + +### Description + +`zeros` creates a rank-1 or rank-2 `array` of the given shape, filled completely with `0` `integer` type values. +`ones` creates a rank-1 or rank-2 `array` of the given shape, filled completely with `1` `integer` type values. + +### Status + +Experimental + +### Class + +Pure function. + +### Syntax + +For rank-1 array: +`result = [[stdlib_linalg(module):zeros(interface)]](dim)` +`result = [[stdlib_linalg(module):ones(interface)]](dim)` + +For rank-2 array: +`result = [[stdlib_linalg(module):zeros(interface)]](dim1, dim2)` +`result = [[stdlib_linalg(module):ones(interface)]](dim1, dim2)` + + +### Arguments + +`dim/dim1`: Shall be an `integer` type. +This is an `intent(in)` argument. + +`dim2`: Shall be an `integer` type. +This is an `intent(in)` argument. + +### Return value + +Returns a rank-1 or rank-2 `array` of the given shape, filled completely with either `0` or `1` `integer` type values. + +#### Warning + +Since the result of `ones` is of `integer` type, one should be careful about using it in arithmetic expressions. For example: +```fortran +real :: A(:,:) + +!> Be careful +A = ones(2,2)/2 !! A = 1/2 = 0.0 + +!> Recommend +A = ones(2,2)/2.0 !! A = 1/2.0 = 0.5 +``` + +### Example + +```fortran +program demo + use stdlib_linalg, only: zeros, ones + implicit none + real, allocatable :: A(:,:) + integer :: iA(2) + complex :: cA(2), cB(2,3) + + A = zeros(2,2) !! [0.0,0.0; 0.0,0.0] (Same as `reshape(spread(0,1,2*2),[2,2])`) + A = ones(4,4) !! [1.0,1.0,1.0,1.0; 1.0,1.0,1.0,1.0; 1.0,1.0,1.0,1.0; 1.0,1.0,1.0 1.0] + A = 2.0*ones(2,2) !! [2.0,2.0; 2.0,2.0] + + print *, reshape(ones(2*3*4),[2,3,4]) !! Same as `reshape(spread(1,1,2*3*4),[2,3,4])` + + iA = ones(2) !! [1,1] (Same as `spread(1,1,2)`) + cA = ones(2) !! [(1.0,0.0),(1.0,0.0)] + cA = (1.0,1.0)*ones(2) !! [(1.0,1.0),(1.0,1.0)] + cB = ones(2,3) !! [(1.0,0.0),(1.0,0.0),(1.0,0.0); (1.0,0.0),(1.0,0.0),(1.0,0.0)] + +end program demo +``` diff --git a/src/Makefile.manual b/src/Makefile.manual index a12f81255..2d036d409 100644 --- a/src/Makefile.manual +++ b/src/Makefile.manual @@ -81,7 +81,8 @@ stdlib_io.o: \ stdlib_optval.o \ stdlib_kinds.o stdlib_linalg.o: \ - stdlib_kinds.o + stdlib_kinds.o \ + stdlib_string_type.o stdlib_linalg_diag.o: \ stdlib_linalg.o \ stdlib_kinds.o diff --git a/src/stdlib_linalg.fypp b/src/stdlib_linalg.fypp index 5e0388c0b..c9e021231 100644 --- a/src/stdlib_linalg.fypp +++ b/src/stdlib_linalg.fypp @@ -4,7 +4,8 @@ module stdlib_linalg !!Provides a support for various linear algebra procedures !! ([Specification](../page/specs/stdlib_linalg.html)) use stdlib_kinds, only: sp, dp, qp, & - int8, int16, int32, int64 + int8, int16, int32, int64, lk, c_bool + use stdlib_string_type, only: string_type implicit none private @@ -12,6 +13,7 @@ module stdlib_linalg public :: eye public :: trace public :: outer_product + public :: zeros, ones interface diag !! version: experimental @@ -80,6 +82,26 @@ module stdlib_linalg #:endfor end interface outer_product + !> Version: experimental + !> + !> `ones` creates a rank-1 or rank-2 array of the given shape, + !> filled completely with `1` `integer` type values. + !> ([Specification](../page/specs/stdlib_linalg.html#zerosones)) + interface ones + procedure :: ones_1_default + procedure :: ones_2_default + end interface ones + + !> Version: experimental + !> + !> `zeros` creates a rank-1 or rank-2 array of the given shape, + !> filled completely with `0` `integer` type values. + !> ([Specification](../page/specs/stdlib_linalg.html#zerosones)) + interface zeros + procedure :: zeros_1_default + procedure :: zeros_2_default + end interface zeros + contains function eye(n) result(res) @@ -108,4 +130,41 @@ contains end do end function trace_${t1[0]}$${k1}$ #:endfor -end module + + !> `ones` creates a rank-1 array, filled completely with `1` `integer` type values. + pure function ones_1_default(dim) result(result) + integer, intent(in) :: dim + integer, allocatable :: result(:) + + allocate(result(dim), source=1) + + end function ones_1_default + + !> `ones` creates a rank-2 array, filled completely with `1` `integer` type values. + pure function ones_2_default(dim1, dim2) result(result) + integer, intent(in) :: dim1, dim2 + integer, allocatable :: result(:, :) + + allocate(result(dim1, dim2), source=1) + + end function ones_2_default + + !> `zeros` creates a rank-1 array, filled completely with `0` `integer` type values. + pure function zeros_1_default(dim) result(result) + integer, intent(in) :: dim + integer, allocatable :: result(:) + + allocate(result(dim), source=0) + + end function zeros_1_default + + !> `zeros` creates a rank-2 array, filled completely with `0` `integer` type values. + pure function zeros_2_default(dim1, dim2) result(result) + integer, intent(in) :: dim1, dim2 + integer, allocatable :: result(:, :) + + allocate(result(dim1, dim2), source=0) + + end function zeros_2_default + +end module stdlib_linalg diff --git a/src/tests/Makefile.manual b/src/tests/Makefile.manual index 7ab184016..c29170e24 100644 --- a/src/tests/Makefile.manual +++ b/src/tests/Makefile.manual @@ -11,3 +11,4 @@ all test clean: $(MAKE) -f Makefile.manual --directory=stats $@ $(MAKE) -f Makefile.manual --directory=string $@ $(MAKE) -f Makefile.manual --directory=math $@ + $(MAKE) -f Makefile.manual --directory=linalg $@ diff --git a/src/tests/linalg/CMakeLists.txt b/src/tests/linalg/CMakeLists.txt index f1098405b..4ddd4b5cf 100644 --- a/src/tests/linalg/CMakeLists.txt +++ b/src/tests/linalg/CMakeLists.txt @@ -1,2 +1,3 @@ ADDTEST(linalg) +ADDTEST(linalg_ones_zeros) diff --git a/src/tests/linalg/Makefile.manual b/src/tests/linalg/Makefile.manual new file mode 100644 index 000000000..616db4875 --- /dev/null +++ b/src/tests/linalg/Makefile.manual @@ -0,0 +1,4 @@ +PROGS_SRC = test_linalg_ones_zeros.f90 + + +include ../Makefile.manual.test.mk diff --git a/src/tests/linalg/test_linalg_ones_zeros.f90 b/src/tests/linalg/test_linalg_ones_zeros.f90 new file mode 100644 index 000000000..57e82b9a1 --- /dev/null +++ b/src/tests/linalg/test_linalg_ones_zeros.f90 @@ -0,0 +1,81 @@ +!> SPDX-Identifier: MIT +module test_linalg_ones_zeros + + use stdlib_linalg, only: zeros, ones + use stdlib_error, only: check + use stdlib_string_type + implicit none + + logical, parameter :: warn = .false. + +contains + + !> `zeros` tests + subroutine test_linalg_zeros_integer + call check(all(zeros(2) == [0, 0]), msg="all(zeros(2)==[0, 0] failed", warn=warn) + call check(all(zeros(2, 2) == reshape([0, 0, 0, 0], [2, 2])), & + msg="all(zeros(2,2)==reshape([0, 0, 0, 0],[2,2]) failed", warn=warn) + end subroutine test_linalg_zeros_integer + + subroutine test_linalg_zeros_real + real, allocatable :: rA(:), rB(:, :) + rA = zeros(2) + call check(all(rA == spread(0.0_4, 1, 2)), msg="all(rA == spread(0.0_4,1,2)) failed", warn=warn) + rB = zeros(2, 2) + call check(all(rB == reshape(spread(0.0_4, 1, 2*2), [2, 2])), & + msg="all(rB == reshape(spread(0.0_4, 1,2*2),[2,2])) failed", warn=warn) + end subroutine test_linalg_zeros_real + + subroutine test_linalg_zeros_complex + complex, allocatable :: cA(:), cB(:, :) + cA = zeros(2) + call check(all(cA == spread((0.0_4, 0.0_4), 1, 2)), msg="all(cA == spread((0.0_4,0.0_4),1,2)) failed", warn=warn) + cB = zeros(2, 2) + call check(all(cB == reshape(spread((0.0_4, 0.0_4), 1, 2*2), [2, 2])), & + msg="all(cB == reshape(spread((0.0_4,0.0_4), 1, 2*2), [2, 2])) failed", warn=warn) + end subroutine test_linalg_zeros_complex + + !> `ones` tests + subroutine test_linalg_ones_integer + call check(all(ones(2) == [1, 1]), msg="all(ones(2)==[1, 1] failed", warn=warn) + call check(all(ones(2, 2) == reshape([1, 1, 1, 1], [2, 2])), & + msg="all(ones(2,2)==reshape([1, 1, 1, 1],[2,2])) failed", warn=warn) + end subroutine test_linalg_ones_integer + + subroutine test_linalg_ones_real + real, allocatable :: rA(:), rB(:, :) + rA = ones(2) + call check(all(rA == spread(1.0_4, 1, 2)), msg="all(rA == spread(1.0_4,1,2)) failed", warn=warn) + rB = ones(2, 2) + call check(all(rB == reshape(spread(1.0_4, 1, 2*2), [2, 2])), & + msg="all(rB == reshape(spread(1.0_4, 1, 2*2), [2, 2])) failed", warn=warn) + end subroutine test_linalg_ones_real + + subroutine test_linalg_ones_complex + complex, allocatable :: cA(:), cB(:, :) + cA = ones(2) + call check(all(cA == spread((1.0_4, 0.0_4), 1, 2)), msg="all(cA == spread((1.0_4,0.0_4),1,2)) failed", warn=warn) + cB = ones(2, 2) + call check(all(cB == reshape(spread((1.0_4, 0.0_4), 1, 2*2), [2, 2])), & + msg="all(cB == reshape(spread((1.0_4, 0.0_4), 1, 2*2), [2, 2])) failed", warn=warn) + end subroutine test_linalg_ones_complex + +end module test_linalg_ones_zeros + +program tester + + use test_linalg_ones_zeros + + print *, "`zeros` tests" + call test_linalg_zeros_integer + call test_linalg_zeros_real + call test_linalg_zeros_complex + + print *, "`ones` tests" + call test_linalg_ones_integer + call test_linalg_ones_real + call test_linalg_ones_complex + + print *, "All tests in `test_linalg_ones_zeros` passed" + +end program tester