Skip to content

Commit

Permalink
Merge pull request #764 from Ivanou34/meshgrid
Browse files Browse the repository at this point in the history
Addition of the `meshgrid` subroutine in `stdlib_math`
  • Loading branch information
jvdp1 authored Mar 5, 2024
2 parents 90a3e9c + 76aaad5 commit 8476d65
Show file tree
Hide file tree
Showing 8 changed files with 297 additions and 1 deletion.
56 changes: 56 additions & 0 deletions doc/specs/stdlib_math.md
Original file line number Diff line number Diff line change
Expand Up @@ -554,3 +554,59 @@ When both `prepend` and `append` are not present, the result `y` has one fewer e
```fortran
{!example/math/example_diff.f90!}
```

### `meshgrid` subroutine

#### Description

Computes a list of coordinate matrices from coordinate vectors.

For $n \geq 1$ coordinate vectors $(x_1, x_2, ..., x_n)$ of sizes $(s_1, s_2, ..., s_n)$, `meshgrid` computes $n$ coordinate matrices $(X_1, X_2, ..., X_n)$ with identical shape corresponding to the selected indexing:
- Cartesian indexing (default behavior): the shape of the coordinate matrices is $(s_2, s_1, s_3, s_4, ... s_n)$.
- matrix indexing: the shape of the coordinate matrices is $(s_1, s_2, s_3, s_4, ... s_n)$.

#### Syntax

For a 2D problem in Cartesian indexing:
`call [[stdlib_math(module):meshgrid(interface)]](x, y, xm, ym)`

For a 3D problem in Cartesian indexing:
`call [[stdlib_math(module):meshgrid(interface)]](x, y, z, xm, ym, zm)`

For a 3D problem in matrix indexing:
`call [[stdlib_math(module):meshgrid(interface)]](x, y, z, xm, ym, zm, indexing="ij")`

The subroutine can be called in `n`-dimensional situations, as long as `n` is inferior to the maximum allowed array rank.

#### Status

Experimental.

#### Class

Subroutine.

#### Arguments

For a `n`-dimensional problem, with `n >= 1`:

`x1, x2, ..., xn`: The coordinate vectors.
Shall be `real/integer` and `rank-1` arrays.
These arguments are `intent(in)`.

`xm1, xm2, ..., xmn`: The coordinate matrices.
Shall be arrays of type `real` or `integer` of adequate shape:
- for Cartesian indexing, the shape of the coordinate matrices must be `[size(x2), size(x1), size(x3), ..., size(xn)]`.
- for matrix indexing, the shape of the coordinate matrices must be `[size(x1), size(x2), size(x3), ..., size(xn)]`.

These argument are `intent(out)`.

`indexing`: the selected indexing.
Shall be an `integer` equal to `stdlib_meshgrid_xy` for Cartesian indexing (default), or `stdlib_meshgrid_ij` for matrix indexing. `stdlib_meshgrid_xy` and `stdlib_meshgrid_ij` are public constants defined in the module.
This argument is `intent(in)` and `optional`, and is equal to `stdlib_meshgrid_xy` by default.

#### Example

```fortran
{!example/math/example_meshgrid.f90!}
```
1 change: 1 addition & 0 deletions example/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ ADD_EXAMPLE(math_argd)
ADD_EXAMPLE(math_arg)
ADD_EXAMPLE(math_argpi)
ADD_EXAMPLE(math_is_close)
ADD_EXAMPLE(meshgrid)
37 changes: 37 additions & 0 deletions example/math/example_meshgrid.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
program example_meshgrid

use stdlib_math, only: meshgrid, linspace, stdlib_meshgrid_ij
use stdlib_kinds, only: sp

implicit none

integer, parameter :: nx = 3, ny = 2
real(sp) :: x(nx), y(ny), &
xm_cart(ny, nx), ym_cart(ny, nx), &
xm_mat(nx, ny), ym_mat(nx, ny)

x = linspace(0_sp, 1_sp, nx)
y = linspace(0_sp, 1_sp, ny)

call meshgrid(x, y, xm_cart, ym_cart)
print *, "xm_cart = "
call print_2d_array(xm_cart)
print *, "ym_cart = "
call print_2d_array(ym_cart)

call meshgrid(x, y, xm_mat, ym_mat, indexing=stdlib_meshgrid_ij)
print *, "xm_mat = "
call print_2d_array(xm_mat)
print *, "ym_mat = "
call print_2d_array(ym_mat)

contains
subroutine print_2d_array(array)
real(sp), intent(in) :: array(:, :)
integer :: i

do i = 1, size(array, dim=1)
print *, array(i, :)
end do
end subroutine
end program example_meshgrid
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ set(fppFiles
stdlib_math_is_close.fypp
stdlib_math_all_close.fypp
stdlib_math_diff.fypp
stdlib_math_meshgrid.fypp
stdlib_str2num.fypp
stdlib_string_type.fypp
stdlib_string_type_constructor.fypp
Expand Down
30 changes: 29 additions & 1 deletion src/stdlib_math.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ module stdlib_math
public :: EULERS_NUMBER_QP
#:endif
public :: DEFAULT_LINSPACE_LENGTH, DEFAULT_LOGSPACE_BASE, DEFAULT_LOGSPACE_LENGTH
public :: arange, arg, argd, argpi, is_close, all_close, diff
public :: stdlib_meshgrid_ij, stdlib_meshgrid_xy
public :: arange, arg, argd, argpi, is_close, all_close, diff, meshgrid

integer, parameter :: DEFAULT_LINSPACE_LENGTH = 100
integer, parameter :: DEFAULT_LOGSPACE_LENGTH = 50
Expand All @@ -32,6 +33,9 @@ module stdlib_math
real(kind=${k1}$), parameter :: PI_${k1}$ = acos(-1.0_${k1}$)
#:endfor

!> Values for optional argument `indexing` of `meshgrid`
integer, parameter :: stdlib_meshgrid_xy = 0, stdlib_meshgrid_ij = 1

interface clip
#:for k1, t1 in IR_KINDS_TYPES
module procedure clip_${k1}$
Expand Down Expand Up @@ -382,6 +386,30 @@ module stdlib_math
#:endfor
end interface diff


!> Version: experimental
!>
!> Computes a list of coordinate matrices from coordinate vectors.
!> ([Specification](../page/specs/stdlib_math.html#meshgrid))
interface meshgrid
#:set RANKS = range(1, MAXRANK + 1)
#:for k1, t1 in IR_KINDS_TYPES
#:for rank in RANKS
#:set RName = rname("meshgrid", rank, t1, k1)
module subroutine ${RName}$(&
${"".join(f"x{i}, " for i in range(1, rank + 1))}$ &
${"".join(f"xm{i}, " for i in range(1, rank + 1))}$ &
indexing &
)
#:for i in range(1, rank + 1)
${t1}$, intent(in) :: x${i}$(:)
${t1}$, intent(out) :: xm${i}$ ${ranksuffix(rank)}$
#:endfor
integer, intent(in), optional :: indexing
end subroutine ${RName}$
#:endfor
#:endfor
end interface meshgrid
contains

#:for k1, t1 in IR_KINDS_TYPES
Expand Down
50 changes: 50 additions & 0 deletions src/stdlib_math_meshgrid.fypp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#:include "common.fypp"
#:set IR_KINDS_TYPES = INT_KINDS_TYPES + REAL_KINDS_TYPES
#:set RANKS = range(1, MAXRANK + 1)

#:def meshgrid_loop(indices)
#:for j in reversed(indices)
do i${j}$ = 1, size(x${j}$)
#:endfor
#:for j in indices
xm${j}$(${"".join(f"i{j}," for j in indices).removesuffix(",")}$) = &
x${j}$(i${j}$)
#:endfor
#:for j in indices
end do
#:endfor
#:enddef

submodule(stdlib_math) stdlib_math_meshgrid

use stdlib_error, only: error_stop

contains

#:for k1, t1 in IR_KINDS_TYPES
#:for rank in RANKS
#:if rank == 1
#:set XY_INDICES = [1]
#:set IJ_INDICES = [1]
#:else
#:set XY_INDICES = [2, 1] + [j for j in range(3, rank + 1)]
#:set IJ_INDICES = [1, 2] + [j for j in range(3, rank + 1)]
#:endif
#: set RName = rname("meshgrid", rank, t1, k1)
module procedure ${RName}$

integer :: ${"".join(f"i{j}," for j in range(1, rank + 1)).removesuffix(",")}$

select case (optval(indexing, stdlib_meshgrid_xy))
case (stdlib_meshgrid_xy)
$:meshgrid_loop(XY_INDICES)
case (stdlib_meshgrid_ij)
$:meshgrid_loop(IJ_INDICES)
case default
call error_stop("ERROR (meshgrid): unexpected indexing.")
end select
end procedure
#:endfor
#:endfor

end submodule
2 changes: 2 additions & 0 deletions test/math/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
set(
fppFiles
"test_stdlib_math.fypp"
"test_meshgrid.fypp"
)
fypp_f90("${fyppFlags}" "${fppFiles}" outFiles)

ADDTEST(stdlib_math)
ADDTEST(linspace)
ADDTEST(logspace)
ADDTEST(meshgrid)
121 changes: 121 additions & 0 deletions test/math/test_meshgrid.fypp
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
! SPDX-Identifier: MIT

#:include "common.fypp"
#:set IR_KINDS_TYPES = INT_KINDS_TYPES + REAL_KINDS_TYPES
#:set RANKS = range(1, MAXRANK + 1)
#:set INDEXINGS = ["default", "xy", "ij"]

#:def OPTIONAL_PART_IN_SIGNATURE(indexing)
#:if indexing in ("xy", "ij")
${f', stdlib_meshgrid_{indexing}'}$
#:endif
#:enddef

module test_meshgrid
use testdrive, only : new_unittest, unittest_type, error_type, check
use stdlib_math, only: meshgrid, stdlib_meshgrid_ij, stdlib_meshgrid_xy
use stdlib_kinds, only: int8, int16, int32, int64, sp, dp, xdp, qp
implicit none

public :: collect_meshgrid

contains

!> Collect all exported unit tests
subroutine collect_meshgrid(testsuite)
!> Collection of tests
type(unittest_type), allocatable, intent(out) :: testsuite(:)

testsuite = [ &
#:for k1, t1 in IR_KINDS_TYPES
#:for rank in RANKS
#:for INDEXING in INDEXINGS
#: set RName = rname(f"meshgrid_{INDEXING}", rank, t1, k1)
new_unittest("${RName}$", test_${RName}$), &
#:endfor
#:endfor
#:endfor
new_unittest("dummy", test_dummy) &
]

end subroutine collect_meshgrid

#:for k1, t1 in IR_KINDS_TYPES
#:for rank in RANKS
#:for INDEXING in INDEXINGS
#:if rank == 1
#:set INDICES = [1]
#:else
#:if INDEXING in ("default", "xy")
#:set INDICES = [2, 1] + [j for j in range(3, rank + 1)]
#:elif INDEXING == "ij"
#:set INDICES = [1, 2] + [j for j in range(3, rank + 1)]
#:endif
#:endif
#:set RName = rname(f"meshgrid_{INDEXING}", rank, t1, k1)
#:set GRIDSHAPE = "".join("length," for j in range(rank)).removesuffix(",")
subroutine test_${RName}$(error)
!> Error handling
type(error_type), allocatable, intent(out) :: error
integer, parameter :: length = 3
${t1}$ :: ${"".join(f"x{j}(length)," for j in range(1, rank + 1)).removesuffix(",")}$
${t1}$ :: ${"".join(f"xm{j}({GRIDSHAPE})," for j in range(1, rank + 1)).removesuffix(",")}$
${t1}$ :: ${"".join(f"xm{j}_exact({GRIDSHAPE})," for j in range(1, rank + 1)).removesuffix(",")}$
integer :: i
integer :: ${"".join(f"i{j}," for j in range(1, rank + 1)).removesuffix(",")}$
${t1}$, parameter :: ZERO = 0
! valid test case
#:for index in range(1, rank + 1)
x${index}$ = [(i, i = length * ${index - 1}$ + 1, length * ${index}$)]
#:endfor
#:for j in range(1, rank + 1)
xm${j}$_exact = reshape( &
[${"".join("(" for dummy in range(rank)) + f"x{j}(i{j})" + "".join(f", i{index} = 1, size(x{index}))" for index in INDICES)}$], &
shape=[${GRIDSHAPE}$] &
)
#:endfor
call meshgrid( &
${"".join(f"x{j}," for j in range(1, rank + 1))}$ &
${"".join(f"xm{j}," for j in range(1, rank + 1)).removesuffix(",")}$ &
${OPTIONAL_PART_IN_SIGNATURE(INDEXING)}$ )
#:for j in range(1, rank + 1)
call check(error, maxval(abs(xm${j}$ - xm${j}$_exact)), ZERO)
if (allocated(error)) return
#:endfor
end subroutine test_${RName}$
#:endfor
#:endfor
#:endfor

subroutine test_dummy(error)
!> Error handling
type(error_type), allocatable, intent(out) :: error
end subroutine

end module test_meshgrid

program tester
use, intrinsic :: iso_fortran_env, only : error_unit
use testdrive, only : run_testsuite, new_testsuite, testsuite_type
use test_meshgrid, only : collect_meshgrid
implicit none
integer :: stat, is
type(testsuite_type), allocatable :: testsuites(:)
character(len=*), parameter :: fmt = '("#", *(1x, a))'

stat = 0

testsuites = [ &
new_testsuite("meshgrid", collect_meshgrid) &
]

do is = 1, size(testsuites)
write(error_unit, fmt) "Testing:", testsuites(is)%name
call run_testsuite(testsuites(is)%collect, error_unit, stat)
end do

if (stat > 0) then
write(error_unit, '(i0, 1x, a)') stat, "test(s) failed!"
error stop
end if
end program tester

0 comments on commit 8476d65

Please sign in to comment.