diff --git a/doc/specs/stdlib_linalg.md b/doc/specs/stdlib_linalg.md index cab16279c..7ca3b9198 100644 --- a/doc/specs/stdlib_linalg.md +++ b/doc/specs/stdlib_linalg.md @@ -101,21 +101,41 @@ end program demo_diag5 Experimental +### Class + +Pure function. + ### Description -Construct the identity matrix +Construct the identity matrix. ### Syntax -`I = [[stdlib_linalg(module):eye(function)]](n)` +`I = [[stdlib_linalg(module):eye(function)]](dim1 [, dim2])` ### Arguments -`n`: Shall be a scalar of default type `integer`. +`dim1`: Shall be a scalar of default type `integer`. +This is an `intent(in)` argument. + +`dim2`: Shall be a scalar of default type `integer`. +This is an `intent(in)` and `optional` argument. ### Return value -Returns the identity matrix, i.e. a square matrix with ones on the main diagonal and zeros elsewhere. The return value is of type `integer(int8)`. +Return the identity matrix, i.e. a matrix with ones on the main diagonal and zeros elsewhere. The return value is of type `integer(int8)`. +The use of `int8` was suggested to save storage. + +#### Warning + +Since the result of `eye` is of `integer(int8)` type, one should be careful about using it in arithmetic expressions. For example: +```fortran +real :: A(:,:) +!> Be careful +A = eye(2,2)/2 !! A == 0.0 +!> Recommend +A = eye(2,2)/2.0 !! A == diag([0.5, 0.5]) +``` ### Example @@ -123,8 +143,16 @@ Returns the identity matrix, i.e. a square matrix with ones on the main diagonal program demo_eye1 use stdlib_linalg, only: eye implicit none + integer :: i(2,2) real :: a(3,3) - A = eye(3) + real :: b(2,3) !! Matrix is non-square. + complex :: c(2,2) + I = eye(2) !! [1,0; 0,1] + A = eye(3) !! [1.0,0.0,0.0; 0.0,1.0,0.0; 0.0,0.0,1.0] + A = eye(3,3) !! [1.0,0.0,0.0; 0.0,1.0,0.0; 0.0,0.0,1.0] + B = eye(2,3) !! [1.0,0.0,0.0; 0.0,1.0,0.0] + C = eye(2,2) !! [(1.0,0.0),(0.0,0.0); (0.0,0.0),(1.0,0.0)] + C = (1.0,1.0)*eye(2,2) !! [(1.0,1.0),(0.0,0.0); (0.0,0.0),(1.0,1.0)] end program demo_eye1 ``` diff --git a/src/stdlib_linalg.fypp b/src/stdlib_linalg.fypp index 5e0388c0b..3faaeb9af 100644 --- a/src/stdlib_linalg.fypp +++ b/src/stdlib_linalg.fypp @@ -5,6 +5,7 @@ module stdlib_linalg !! ([Specification](../page/specs/stdlib_linalg.html)) use stdlib_kinds, only: sp, dp, qp, & int8, int16, int32, int64 + use stdlib_optval, only: optval implicit none private @@ -82,20 +83,28 @@ module stdlib_linalg contains - function eye(n) result(res) - !! version: experimental - !! - !! Constructs the identity matrix - !! ([Specification](../page/specs/stdlib_linalg.html#description_1)) - integer, intent(in) :: n - integer(int8) :: res(n, n) - integer :: i - res = 0 - do i = 1, n - res(i, i) = 1 - end do - end function eye + !> Version: experimental + !> + !> Constructs the identity matrix. + !> ([Specification](../page/specs/stdlib_linalg.html#eye-construct-the-identity-matrix)) + pure function eye(dim1, dim2) result(result) + + integer, intent(in) :: dim1 + integer, intent(in), optional :: dim2 + integer(int8), allocatable :: result(:, :) + + integer :: dim2_ + integer :: i + dim2_ = optval(dim2, dim1) + allocate(result(dim1, dim2_)) + + result = 0_int8 + do i = 1, min(dim1, dim2_) + result(i, i) = 1_int8 + end do + + end function eye #:for k1, t1 in RCI_KINDS_TYPES function trace_${t1[0]}$${k1}$(A) result(res) @@ -108,4 +117,5 @@ contains end do end function trace_${t1[0]}$${k1}$ #:endfor -end module + +end module stdlib_linalg diff --git a/src/tests/Makefile.manual b/src/tests/Makefile.manual index 7ab184016..c29170e24 100644 --- a/src/tests/Makefile.manual +++ b/src/tests/Makefile.manual @@ -11,3 +11,4 @@ all test clean: $(MAKE) -f Makefile.manual --directory=stats $@ $(MAKE) -f Makefile.manual --directory=string $@ $(MAKE) -f Makefile.manual --directory=math $@ + $(MAKE) -f Makefile.manual --directory=linalg $@ diff --git a/src/tests/linalg/Makefile.manual b/src/tests/linalg/Makefile.manual new file mode 100644 index 000000000..a95f600c7 --- /dev/null +++ b/src/tests/linalg/Makefile.manual @@ -0,0 +1,4 @@ +PROGS_SRC = test_linalg.f90 + + +include ../Makefile.manual.test.mk diff --git a/src/tests/linalg/test_linalg.f90 b/src/tests/linalg/test_linalg.f90 index cc8d0db68..7583f0585 100644 --- a/src/tests/linalg/test_linalg.f90 +++ b/src/tests/linalg/test_linalg.f90 @@ -81,12 +81,19 @@ subroutine test_eye integer :: i write(*,*) "test_eye" + call check(all(eye(3,3) == diag([(1,i=1,3)])), & + msg="all(eye(3,3) == diag([(1,i=1,3)])) failed.",warn=warn) + + rye = eye(3,4) + call check(sum(abs(rye(:,1:3) - diag([(1.0_sp,i=1,3)]))) < sptol, & + msg="sum(abs(rye(:,1:3) - diag([(1.0_sp,i=1,3)]))) < sptol failed", warn=warn) + call check(all(eye(5) == diag([(1,i=1,5)])), & msg="all(eye(5) == diag([(1,i=1,5)] failed.",warn=warn) rye = eye(6) - call check(sum(rye - diag([(1.0_sp,i=1,6)])) < sptol, & - msg="sum(rye - diag([(1.0_sp,i=1,6)])) < sptol failed.",warn=warn) + call check(sum(abs(rye - diag([(1.0_sp,i=1,6)]))) < sptol, & + msg="sum(abs(rye - diag([(1.0_sp,i=1,6)]))) < sptol failed.",warn=warn) cye = eye(7) call check(abs(trace(cye) - cmplx(7.0_sp,0.0_sp,kind=sp)) < sptol, &