Skip to content

Commit

Permalink
Merge pull request #481 from fortran-fans/update_eye
Browse files Browse the repository at this point in the history
[stdlib_linalg] Update eye function.
  • Loading branch information
milancurcic authored Aug 23, 2021
2 parents 5089a40 + 2e4d681 commit d3b45ed
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 21 deletions.
38 changes: 33 additions & 5 deletions doc/specs/stdlib_linalg.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,30 +101,58 @@ end program demo_diag5

Experimental

### Class

Pure function.

### Description

Construct the identity matrix
Construct the identity matrix.

### Syntax

`I = [[stdlib_linalg(module):eye(function)]](n)`
`I = [[stdlib_linalg(module):eye(function)]](dim1 [, dim2])`

### Arguments

`n`: Shall be a scalar of default type `integer`.
`dim1`: Shall be a scalar of default type `integer`.
This is an `intent(in)` argument.

`dim2`: Shall be a scalar of default type `integer`.
This is an `intent(in)` and `optional` argument.

### Return value

Returns the identity matrix, i.e. a square matrix with ones on the main diagonal and zeros elsewhere. The return value is of type `integer(int8)`.
Return the identity matrix, i.e. a matrix with ones on the main diagonal and zeros elsewhere. The return value is of type `integer(int8)`.
The use of `int8` was suggested to save storage.

#### Warning

Since the result of `eye` is of `integer(int8)` type, one should be careful about using it in arithmetic expressions. For example:
```fortran
real :: A(:,:)
!> Be careful
A = eye(2,2)/2 !! A == 0.0
!> Recommend
A = eye(2,2)/2.0 !! A == diag([0.5, 0.5])
```

### Example

```fortran
program demo_eye1
use stdlib_linalg, only: eye
implicit none
integer :: i(2,2)
real :: a(3,3)
A = eye(3)
real :: b(2,3) !! Matrix is non-square.
complex :: c(2,2)
I = eye(2) !! [1,0; 0,1]
A = eye(3) !! [1.0,0.0,0.0; 0.0,1.0,0.0; 0.0,0.0,1.0]
A = eye(3,3) !! [1.0,0.0,0.0; 0.0,1.0,0.0; 0.0,0.0,1.0]
B = eye(2,3) !! [1.0,0.0,0.0; 0.0,1.0,0.0]
C = eye(2,2) !! [(1.0,0.0),(0.0,0.0); (0.0,0.0),(1.0,0.0)]
C = (1.0,1.0)*eye(2,2) !! [(1.0,1.0),(0.0,0.0); (0.0,0.0),(1.0,1.0)]
end program demo_eye1
```

Expand Down
38 changes: 24 additions & 14 deletions src/stdlib_linalg.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ module stdlib_linalg
!! ([Specification](../page/specs/stdlib_linalg.html))
use stdlib_kinds, only: sp, dp, qp, &
int8, int16, int32, int64
use stdlib_optval, only: optval
implicit none
private

Expand Down Expand Up @@ -82,20 +83,28 @@ module stdlib_linalg

contains

function eye(n) result(res)
!! version: experimental
!!
!! Constructs the identity matrix
!! ([Specification](../page/specs/stdlib_linalg.html#description_1))
integer, intent(in) :: n
integer(int8) :: res(n, n)
integer :: i
res = 0
do i = 1, n
res(i, i) = 1
end do
end function eye
!> Version: experimental
!>
!> Constructs the identity matrix.
!> ([Specification](../page/specs/stdlib_linalg.html#eye-construct-the-identity-matrix))
pure function eye(dim1, dim2) result(result)

integer, intent(in) :: dim1
integer, intent(in), optional :: dim2
integer(int8), allocatable :: result(:, :)

integer :: dim2_
integer :: i

dim2_ = optval(dim2, dim1)
allocate(result(dim1, dim2_))

result = 0_int8
do i = 1, min(dim1, dim2_)
result(i, i) = 1_int8
end do

end function eye

#:for k1, t1 in RCI_KINDS_TYPES
function trace_${t1[0]}$${k1}$(A) result(res)
Expand All @@ -108,4 +117,5 @@ contains
end do
end function trace_${t1[0]}$${k1}$
#:endfor
end module

end module stdlib_linalg
1 change: 1 addition & 0 deletions src/tests/Makefile.manual
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ all test clean:
$(MAKE) -f Makefile.manual --directory=stats $@
$(MAKE) -f Makefile.manual --directory=string $@
$(MAKE) -f Makefile.manual --directory=math $@
$(MAKE) -f Makefile.manual --directory=linalg $@
4 changes: 4 additions & 0 deletions src/tests/linalg/Makefile.manual
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
PROGS_SRC = test_linalg.f90


include ../Makefile.manual.test.mk
11 changes: 9 additions & 2 deletions src/tests/linalg/test_linalg.f90
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,19 @@ subroutine test_eye
integer :: i
write(*,*) "test_eye"

call check(all(eye(3,3) == diag([(1,i=1,3)])), &
msg="all(eye(3,3) == diag([(1,i=1,3)])) failed.",warn=warn)

rye = eye(3,4)
call check(sum(abs(rye(:,1:3) - diag([(1.0_sp,i=1,3)]))) < sptol, &
msg="sum(abs(rye(:,1:3) - diag([(1.0_sp,i=1,3)]))) < sptol failed", warn=warn)

call check(all(eye(5) == diag([(1,i=1,5)])), &
msg="all(eye(5) == diag([(1,i=1,5)] failed.",warn=warn)

rye = eye(6)
call check(sum(rye - diag([(1.0_sp,i=1,6)])) < sptol, &
msg="sum(rye - diag([(1.0_sp,i=1,6)])) < sptol failed.",warn=warn)
call check(sum(abs(rye - diag([(1.0_sp,i=1,6)]))) < sptol, &
msg="sum(abs(rye - diag([(1.0_sp,i=1,6)]))) < sptol failed.",warn=warn)

cye = eye(7)
call check(abs(trace(cye) - cmplx(7.0_sp,0.0_sp,kind=sp)) < sptol, &
Expand Down

0 comments on commit d3b45ed

Please sign in to comment.