Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib_linalg] Update eye function. #481

Merged
merged 6 commits into from
Aug 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions doc/specs/stdlib_linalg.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,30 +101,58 @@ end program demo_diag5

Experimental

### Class

Pure function.

### Description

Construct the identity matrix
Construct the identity matrix.

### Syntax

`I = [[stdlib_linalg(module):eye(function)]](n)`
`I = [[stdlib_linalg(module):eye(function)]](dim1 [, dim2])`

### Arguments

`n`: Shall be a scalar of default type `integer`.
`dim1`: Shall be a scalar of default type `integer`.
This is an `intent(in)` argument.

`dim2`: Shall be a scalar of default type `integer`.
This is an `intent(in)` and `optional` argument.

### Return value

Returns the identity matrix, i.e. a square matrix with ones on the main diagonal and zeros elsewhere. The return value is of type `integer(int8)`.
Return the identity matrix, i.e. a matrix with ones on the main diagonal and zeros elsewhere. The return value is of type `integer(int8)`.
The use of `int8` was suggested to save storage.

#### Warning

Since the result of `eye` is of `integer(int8)` type, one should be careful about using it in arithmetic expressions. For example:
```fortran
real :: A(:,:)
!> Be careful
A = eye(2,2)/2 !! A == 0.0
!> Recommend
A = eye(2,2)/2.0 !! A == diag([0.5, 0.5])
```

### Example

```fortran
program demo_eye1
use stdlib_linalg, only: eye
implicit none
integer :: i(2,2)
real :: a(3,3)
A = eye(3)
real :: b(2,3) !! Matrix is non-square.
complex :: c(2,2)
I = eye(2) !! [1,0; 0,1]
A = eye(3) !! [1.0,0.0,0.0; 0.0,1.0,0.0; 0.0,0.0,1.0]
A = eye(3,3) !! [1.0,0.0,0.0; 0.0,1.0,0.0; 0.0,0.0,1.0]
B = eye(2,3) !! [1.0,0.0,0.0; 0.0,1.0,0.0]
C = eye(2,2) !! [(1.0,0.0),(0.0,0.0); (0.0,0.0),(1.0,0.0)]
C = (1.0,1.0)*eye(2,2) !! [(1.0,1.0),(0.0,0.0); (0.0,0.0),(1.0,1.0)]
end program demo_eye1
```

Expand Down
38 changes: 24 additions & 14 deletions src/stdlib_linalg.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ module stdlib_linalg
!! ([Specification](../page/specs/stdlib_linalg.html))
use stdlib_kinds, only: sp, dp, qp, &
int8, int16, int32, int64
use stdlib_optval, only: optval
implicit none
private

Expand Down Expand Up @@ -82,20 +83,28 @@ module stdlib_linalg

contains

function eye(n) result(res)
!! version: experimental
!!
!! Constructs the identity matrix
!! ([Specification](../page/specs/stdlib_linalg.html#description_1))
integer, intent(in) :: n
integer(int8) :: res(n, n)
integer :: i
res = 0
do i = 1, n
res(i, i) = 1
end do
end function eye
!> Version: experimental
!>
!> Constructs the identity matrix.
!> ([Specification](../page/specs/stdlib_linalg.html#eye-construct-the-identity-matrix))
pure function eye(dim1, dim2) result(result)

integer, intent(in) :: dim1
integer, intent(in), optional :: dim2
integer(int8), allocatable :: result(:, :)

integer :: dim2_
integer :: i

dim2_ = optval(dim2, dim1)
allocate(result(dim1, dim2_))

result = 0_int8
do i = 1, min(dim1, dim2_)
result(i, i) = 1_int8
end do

end function eye

#:for k1, t1 in RCI_KINDS_TYPES
function trace_${t1[0]}$${k1}$(A) result(res)
Expand All @@ -108,4 +117,5 @@ contains
end do
end function trace_${t1[0]}$${k1}$
#:endfor
end module

end module stdlib_linalg
1 change: 1 addition & 0 deletions src/tests/Makefile.manual
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ all test clean:
$(MAKE) -f Makefile.manual --directory=stats $@
$(MAKE) -f Makefile.manual --directory=string $@
$(MAKE) -f Makefile.manual --directory=math $@
$(MAKE) -f Makefile.manual --directory=linalg $@
4 changes: 4 additions & 0 deletions src/tests/linalg/Makefile.manual
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
PROGS_SRC = test_linalg.f90


include ../Makefile.manual.test.mk
11 changes: 9 additions & 2 deletions src/tests/linalg/test_linalg.f90
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,19 @@ subroutine test_eye
integer :: i
write(*,*) "test_eye"

call check(all(eye(3,3) == diag([(1,i=1,3)])), &
msg="all(eye(3,3) == diag([(1,i=1,3)])) failed.",warn=warn)

rye = eye(3,4)
call check(sum(abs(rye(:,1:3) - diag([(1.0_sp,i=1,3)]))) < sptol, &
msg="sum(abs(rye(:,1:3) - diag([(1.0_sp,i=1,3)]))) < sptol failed", warn=warn)

call check(all(eye(5) == diag([(1,i=1,5)])), &
msg="all(eye(5) == diag([(1,i=1,5)] failed.",warn=warn)

rye = eye(6)
call check(sum(rye - diag([(1.0_sp,i=1,6)])) < sptol, &
msg="sum(rye - diag([(1.0_sp,i=1,6)])) < sptol failed.",warn=warn)
call check(sum(abs(rye - diag([(1.0_sp,i=1,6)]))) < sptol, &
msg="sum(abs(rye - diag([(1.0_sp,i=1,6)]))) < sptol failed.",warn=warn)

cye = eye(7)
call check(abs(trace(cye) - cmplx(7.0_sp,0.0_sp,kind=sp)) < sptol, &
Expand Down