diff --git a/doc/specs/stdlib_linalg.md b/doc/specs/stdlib_linalg.md index f7569e0c6..29eb9878d 100644 --- a/doc/specs/stdlib_linalg.md +++ b/doc/specs/stdlib_linalg.md @@ -160,6 +160,36 @@ Returns a rank-2 array equal to `u v^T` (where `u, v` are considered column vect {!example/linalg/example_outer_product.f90!} ``` +## `cross_product` - Computes the cross product of two vectors + +### Status + +Experimental + +### Description + +Computes the cross product of two vectors + +### Syntax + +`c = [[stdlib_linalg(module):cross_product(interface)]](a, b)` + +### Arguments + +`a`: Shall be a rank-1 and size-3 array + +`b`: Shall be a rank-1 and size-3 array + +### Return value + +Returns a rank-1 and size-3 array which is perpendicular to both `a` and `b`. + +### Example + +```fortran +{!example/linalg/example_cross_product.f90!} +``` + ## `is_square` - Checks if a matrix is square ### Status diff --git a/example/linalg/example_cross_product.f90 b/example/linalg/example_cross_product.f90 new file mode 100644 index 000000000..e546647f4 --- /dev/null +++ b/example/linalg/example_cross_product.f90 @@ -0,0 +1,9 @@ +program demo_cross_product + use stdlib_linalg, only: cross_product + implicit none + real :: a(3), b(3), c(3) + a = [1., 0., 0.] + b = [0., 1., 0.] + c = cross_product(a, b) + !c = [0., 0., 1.] +end program demo_cross_product diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 6f1fd0a18..8f512af56 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -22,6 +22,7 @@ set(fppFiles stdlib_linalg.fypp stdlib_linalg_diag.fypp stdlib_linalg_outer_product.fypp + stdlib_linalg_cross_product.fypp stdlib_optval.fypp stdlib_selection.fypp stdlib_sorting.fypp diff --git a/src/stdlib_linalg.fypp b/src/stdlib_linalg.fypp index bc1017f0a..cfa43d3d9 100644 --- a/src/stdlib_linalg.fypp +++ b/src/stdlib_linalg.fypp @@ -14,6 +14,7 @@ module stdlib_linalg public :: eye public :: trace public :: outer_product + public :: cross_product public :: is_square public :: is_diagonal public :: is_symmetric @@ -93,6 +94,21 @@ module stdlib_linalg end interface outer_product + ! Cross product (of two vectors) + interface cross_product + !! version: experimental + !! + !! Computes the cross product of two vectors, returning a rank-1 and size-3 array + !! ([Specification](../page/specs/stdlib_linalg.html#cross_product-computes-the-cross-product-of-two-3-d-vectors)) + #:for k1, t1 in RCI_KINDS_TYPES + pure module function cross_product_${t1[0]}$${k1}$(a, b) result(res) + ${t1}$, intent(in) :: a(3), b(3) + ${t1}$ :: res(3) + end function cross_product_${t1[0]}$${k1}$ + #:endfor + end interface cross_product + + ! Check for squareness interface is_square !! version: experimental diff --git a/src/stdlib_linalg_cross_product.fypp b/src/stdlib_linalg_cross_product.fypp new file mode 100644 index 000000000..46d9e736a --- /dev/null +++ b/src/stdlib_linalg_cross_product.fypp @@ -0,0 +1,21 @@ +#:include "common.fypp" +#:set RCI_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES + INT_KINDS_TYPES +submodule (stdlib_linalg) stdlib_linalg_cross_product + + implicit none + +contains + + #:for k1, t1 in RCI_KINDS_TYPES + pure module function cross_product_${t1[0]}$${k1}$(a, b) result(res) + ${t1}$, intent(in) :: a(3), b(3) + ${t1}$ :: res(3) + + res(1) = a(2) * b(3) - a(3) * b(2) + res(2) = a(3) * b(1) - a(1) * b(3) + res(3) = a(1) * b(2) - a(2) * b(1) + + end function cross_product_${t1[0]}$${k1}$ + #:endfor + +end submodule diff --git a/test/linalg/test_linalg.fypp b/test/linalg/test_linalg.fypp index f74cbff6b..2ffd2d7de 100644 --- a/test/linalg/test_linalg.fypp +++ b/test/linalg/test_linalg.fypp @@ -3,7 +3,7 @@ module test_linalg use testdrive, only : new_unittest, unittest_type, error_type, check, skip_test use stdlib_kinds, only: sp, dp, xdp, qp, int8, int16, int32, int64 - use stdlib_linalg, only: diag, eye, trace, outer_product + use stdlib_linalg, only: diag, eye, trace, outer_product, cross_product implicit none @@ -57,7 +57,17 @@ contains new_unittest("outer_product_int8", test_outer_product_int8), & new_unittest("outer_product_int16", test_outer_product_int16), & new_unittest("outer_product_int32", test_outer_product_int32), & - new_unittest("outer_product_int64", test_outer_product_int64) & + new_unittest("outer_product_int64", test_outer_product_int64), & + new_unittest("cross_product_rsp", test_cross_product_rsp), & + new_unittest("cross_product_rdp", test_cross_product_rdp), & + new_unittest("cross_product_rqp", test_cross_product_rqp), & + new_unittest("cross_product_csp", test_cross_product_csp), & + new_unittest("cross_product_cdp", test_cross_product_cdp), & + new_unittest("cross_product_cqp", test_cross_product_cqp), & + new_unittest("cross_product_int8", test_cross_product_int8), & + new_unittest("cross_product_int16", test_cross_product_int16), & + new_unittest("cross_product_int32", test_cross_product_int32), & + new_unittest("cross_product_int64", test_cross_product_int64) & ] end subroutine collect_linalg @@ -702,6 +712,163 @@ contains "all(abs(diff) == 0) failed.") end subroutine test_outer_product_int64 + subroutine test_cross_product_int8(error) + !> Error handling + type(error_type), allocatable, intent(out) :: error + + integer, parameter :: n = 3 + integer(int8) :: u(n), v(n), expected(n), diff(n) + + u = [1,0,0] + v = [0,1,0] + expected = [0,0,1] + diff = expected - cross_product(u,v) + call check(error, all(abs(diff) == 0), & + "cross_product(u,v) == expected failed.") + end subroutine test_cross_product_int8 + + subroutine test_cross_product_int16(error) + !> Error handling + type(error_type), allocatable, intent(out) :: error + + integer, parameter :: n = 3 + integer(int16) :: u(n), v(n), expected(n), diff(n) + + u = [1,0,0] + v = [0,1,0] + expected = [0,0,1] + diff = expected - cross_product(u,v) + call check(error, all(abs(diff) == 0), & + "cross_product(u,v) == expected failed.") + end subroutine test_cross_product_int16 + + subroutine test_cross_product_int32(error) + !> Error handling + type(error_type), allocatable, intent(out) :: error + + integer, parameter :: n = 3 + integer(int32) :: u(n), v(n), expected(n), diff(n) + write(*,*) "test_cross_product_int32" + u = [1,0,0] + v = [0,1,0] + expected = [0,0,1] + diff = expected - cross_product(u,v) + call check(error, all(abs(diff) == 0), & + "cross_product(u,v) == expected failed.") + end subroutine test_cross_product_int32 + + subroutine test_cross_product_int64(error) + !> Error handling + type(error_type), allocatable, intent(out) :: error + + integer, parameter :: n = 3 + integer(int64) :: u(n), v(n), expected(n), diff(n) + write(*,*) "test_cross_product_int64" + u = [1,0,0] + v = [0,1,0] + expected = [0,0,1] + diff = expected - cross_product(u,v) + call check(error, all(abs(diff) == 0), & + "cross_product(u,v) == expected failed.") + end subroutine test_cross_product_int64 + + subroutine test_cross_product_rsp(error) + !> Error handling + type(error_type), allocatable, intent(out) :: error + + integer, parameter :: n = 3 + real(sp) :: u(n), v(n), expected(n), diff(n) + write(*,*) "test_cross_product_rsp" + u = [1.1_sp,2.5_sp,2.4_sp] + v = [0.5_sp,1.5_sp,2.5_sp] + expected = [2.65_sp,-1.55_sp,0.4_sp] + diff = expected - cross_product(u,v) + call check(error, all(abs(diff) < sptol), & + "all(abs(cross_product(u,v)-expected)) < sptol failed.") + end subroutine test_cross_product_rsp + + subroutine test_cross_product_rdp(error) + !> Error handling + type(error_type), allocatable, intent(out) :: error + + integer, parameter :: n = 3 + real(dp) :: u(n), v(n), expected(n), diff(n) + write(*,*) "test_cross_product_rdp" + u = [1.1_dp,2.5_dp,2.4_dp] + v = [0.5_dp,1.5_dp,2.5_dp] + expected = [2.65_dp,-1.55_dp,0.4_dp] + diff = expected - cross_product(u,v) + call check(error, all(abs(diff) < dptol), & + "all(abs(cross_product(u,v)-expected)) < dptol failed.") + end subroutine test_cross_product_rdp + + subroutine test_cross_product_rqp(error) + !> Error handling + type(error_type), allocatable, intent(out) :: error + +#:if WITH_QP + integer, parameter :: n = 3 + real(qp) :: u(n), v(n), expected(n), diff(n) + write(*,*) "test_cross_product_rqp" + u = [1.1_qp,2.5_qp,2.4_qp] + v = [0.5_qp,1.5_qp,2.5_qp] + expected = [2.65_qp,-1.55_qp,0.4_qp] + diff = expected - cross_product(u,v) + call check(error, all(abs(diff) < qptol), & + "all(abs(cross_product(u,v)-expected)) < qptol failed.") +#:else + call skip_test(error, "Quadruple precision is not enabled") +#:endif + end subroutine test_cross_product_rqp + + subroutine test_cross_product_csp(error) + !> Error handling + type(error_type), allocatable, intent(out) :: error + + integer, parameter :: n = 3 + complex(sp) :: u(n), v(n), expected(n), diff(n) + write(*,*) "test_cross_product_csp" + u = [cmplx(0,1,sp),cmplx(1,0,sp),cmplx(0,0,sp)] + v = [cmplx(1,1,sp),cmplx(0,0,sp),cmplx(1,0,sp)] + expected = [cmplx(1,0,sp),cmplx(0,-1,sp),cmplx(-1,-1,sp)] + diff = expected - cross_product(u,v) + call check(error, all(abs(diff) < sptol), & + "all(abs(cross_product(u,v)-expected)) < sptol failed.") + end subroutine test_cross_product_csp + + subroutine test_cross_product_cdp(error) + !> Error handling + type(error_type), allocatable, intent(out) :: error + + integer, parameter :: n = 3 + complex(dp) :: u(n), v(n), expected(n), diff(n) + write(*,*) "test_cross_product_cdp" + u = [cmplx(0,1,dp),cmplx(1,0,dp),cmplx(0,0,dp)] + v = [cmplx(1,1,dp),cmplx(0,0,dp),cmplx(1,0,dp)] + expected = [cmplx(1,0,dp),cmplx(0,-1,dp),cmplx(-1,-1,dp)] + diff = expected - cross_product(u,v) + call check(error, all(abs(diff) < dptol), & + "all(abs(cross_product(u,v)-expected)) < dptol failed.") + end subroutine test_cross_product_cdp + + subroutine test_cross_product_cqp(error) + !> Error handling + type(error_type), allocatable, intent(out) :: error + +#:if WITH_QP + integer, parameter :: n = 3 + complex(qp) :: u(n), v(n), expected(n), diff(n) + write(*,*) "test_cross_product_cqp" + u = [cmplx(0,1,qp),cmplx(1,0,qp),cmplx(0,0,qp)] + v = [cmplx(1,1,qp),cmplx(0,0,qp),cmplx(1,0,qp)] + expected = [cmplx(1,0,qp),cmplx(0,-1,qp),cmplx(-1,-1,qp)] + diff = expected - cross_product(u,v) + call check(error, all(abs(diff) < qptol), & + "all(abs(cross_product(u,v)-expected)) < qptol failed.") +#:else + call skip_test(error, "Quadruple precision is not enabled") +#:endif + end subroutine test_cross_product_cqp pure recursive function catalan_number(n) result(value) integer, intent(in) :: n