Skip to content

Commit

Permalink
linalg eye: allow generalized return type and kind (#902)
Browse files Browse the repository at this point in the history
  • Loading branch information
perazz authored Dec 24, 2024
2 parents 35e7146 + 3d00b1d commit fcb5a50
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 27 deletions.
31 changes: 14 additions & 17 deletions doc/specs/stdlib_linalg.md
Original file line number Diff line number Diff line change
Expand Up @@ -239,37 +239,34 @@ Pure function.

### Description

Construct the identity matrix.
Constructs the identity matrix.

### Syntax

`I = ` [[stdlib_linalg(module):eye(function)]] `(dim1 [, dim2])`
`I = ` [[stdlib_linalg(module):eye(function)]] `(dim1 [, dim2] [, mold])`

### Arguments

`dim1`: Shall be a scalar of default type `integer`.
This is an `intent(in)` argument.

`dim2`: Shall be a scalar of default type `integer`.
This is an `intent(in)` and `optional` argument.
- `dim1`: A scalar of type `integer`. This is an `intent(in)` argument and specifies the number of rows.
- `dim2`: A scalar of type `integer`. This is an optional `intent(in)` argument specifying the number of columns. If not provided, the matrix is square (`dim1 = dim2`).
- `mold`: A scalar of any supported `integer`, `real`, or `complex` type. This is an optional `intent(in)` argument. If provided, the returned identity matrix will have the same type and kind as `mold`. If not provided, the matrix will be of type `real(real64)` by default.

### Return value

Return the identity matrix, i.e. a matrix with ones on the main diagonal and zeros elsewhere. The return value is of type `integer(int8)`.
The use of `int8` was suggested to save storage.
Returns the identity matrix, with ones on the main diagonal and zeros elsewhere.

- By default, the return value is of type `real(real64)`, which is recommended for arithmetic safety.
- If the `mold` argument is provided, the return value will match the type and kind of `mold`, allowing for arbitrary `integer`, `real`, or `complex` return types.

#### Warning
### Example

Since the result of `eye` is of `integer(int8)` type, one should be careful about using it in arithmetic expressions. For example:
```fortran
!> Be careful
A = eye(2,2)/2 !! A == 0.0
!> Recommend
A = eye(2,2)/2.0 !! A == diag([0.5, 0.5])
!> Return default type (real64)
A = eye(2,2)/2 !! A == diag([0.5_dp, 0.5_dp])
!> Return 32-bit complex
A = eye(2,2, mold=(0.0,0.0))/2 !! A == diag([(0.5,0.5), (0.5,0.5)])
```

### Example

```fortran
{!example/linalg/example_eye1.f90!}
```
Expand Down
10 changes: 5 additions & 5 deletions example/linalg/example_eye1.f90
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ program example_eye1
real :: a(3, 3)
real :: b(2, 3) !! Matrix is non-square.
complex :: c(2, 2)
I = eye(2) !! [1,0; 0,1]
A = eye(3) !! [1.0,0.0,0.0; 0.0,1.0,0.0; 0.0,0.0,1.0]
A = eye(3, 3) !! [1.0,0.0,0.0; 0.0,1.0,0.0; 0.0,0.0,1.0]
B = eye(2, 3) !! [1.0,0.0,0.0; 0.0,1.0,0.0]
C = eye(2, 2) !! [(1.0,0.0),(0.0,0.0); (0.0,0.0),(1.0,0.0)]
I = eye(2) !! [1,0; 0,1]
A = eye(3) !! [1.0,0.0,0.0; 0.0,1.0,0.0; 0.0,0.0,1.0]
A = eye(3, 3) !! [1.0,0.0,0.0; 0.0,1.0,0.0; 0.0,0.0,1.0]
B = eye(2, 3) !! [1.0,0.0,0.0; 0.0,1.0,0.0]
C = eye(2, 2) !! [(1.0,0.0),(0.0,0.0); (0.0,0.0),(1.0,0.0)]
C = (1.0, 1.0)*eye(2, 2) !! [(1.0,1.0),(0.0,0.0); (0.0,0.0),(1.0,1.0)]
end program example_eye1
23 changes: 18 additions & 5 deletions src/stdlib_linalg.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,16 @@ module stdlib_linalg
#:endfor
end interface

! Identity matrix
interface eye
!! version: experimental
!!
!! Constructs the identity matrix
!! ([Specification](../page/specs/stdlib_linalg.html#eye-construct-the-identity-matrix))
#:for k1, t1 in RCI_KINDS_TYPES
module procedure eye_${t1[0]}$${k1}$
#:endfor
end interface eye

! Outer product (of two vectors)
interface outer_product
Expand Down Expand Up @@ -1411,24 +1421,27 @@ contains
!>
!> Constructs the identity matrix.
!> ([Specification](../page/specs/stdlib_linalg.html#eye-construct-the-identity-matrix))
pure function eye(dim1, dim2) result(result)
#:for k1, t1 in RCI_KINDS_TYPES
pure function eye_${t1[0]}$${k1}$(dim1, dim2, mold) result(result)
integer, intent(in) :: dim1
integer, intent(in), optional :: dim2
integer(int8), allocatable :: result(:, :)
${t1}$, intent(in) #{if t1 == 'real(dp)'}#, optional #{endif}#:: mold
${t1}$, allocatable :: result(:, :)
integer :: dim2_
integer :: i
dim2_ = optval(dim2, dim1)
allocate(result(dim1, dim2_))
result = 0_int8
result = 0
do i = 1, min(dim1, dim2_)
result(i, i) = 1_int8
result(i, i) = 1
end do
end function eye
end function eye_${t1[0]}$${k1}$
#:endfor
#:for k1, t1 in RCI_KINDS_TYPES
function trace_${t1[0]}$${k1}$(A) result(res)
Expand Down

0 comments on commit fcb5a50

Please sign in to comment.