Skip to content

Commit

Permalink
Merge pull request #9 from MP-Gadget/r2r
Browse files Browse the repository at this point in the history
Basic support for R2R transforms.

Fixes: #8
  • Loading branch information
sbird authored Aug 20, 2024
2 parents ed8af3c + f16f776 commit aaf907a
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 16 deletions.
58 changes: 42 additions & 16 deletions pfft/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ cdef extern from 'pfft.h':
int _PFFT_PADDED_R2C "PFFT_PADDED_R2C"
int _PFFT_PADDED_C2R "PFFT_PADDED_C2R"

int _FFTW_R2HC "FFTW_R2HC"
int _FFTW_HC2R "FFTW_HC2R"

void pfft_init()
void pfftf_init()
void pfft_cleanup()
Expand Down Expand Up @@ -97,7 +100,7 @@ cdef extern from 'pfft.h':
pfft_plan pfft_plan_r2r(
int rnk_n, numpy.intp_t *n, void * input, void * output,
cMPI.MPI_Comm ccart,
int sign, unsigned pfft_flags)
int * kinds, unsigned pfft_flags)

pfft_plan pfftf_plan_dft(
int rnk_n, numpy.intp_t *n, void * input, void * output,
Expand All @@ -117,7 +120,7 @@ cdef extern from 'pfft.h':
pfft_plan pfftf_plan_r2r(
int rnk_n, numpy.intp_t *n, void * input, void * output,
cMPI.MPI_Comm ccart,
int sign, unsigned pfft_flags)
int * kinds, unsigned pfft_flags)

int pfft_create_procmesh(int rnk_n, cMPI.MPI_Comm comm, int *np,
cMPI.MPI_Comm * ccart)
Expand Down Expand Up @@ -287,9 +290,11 @@ class Type(int):
inverses = { Type.C2C : Type.C2C,
Type.R2C : Type.C2R,
Type.C2R : Type.R2C,
Type.R2R : Type.R2R,
Type.C2CF : Type.C2CF,
Type.R2CF : Type.C2RF,
Type.C2RF : Type.R2CF,
Type.R2RF : Type.R2RF,
}
return inverses[self]

Expand All @@ -314,17 +319,23 @@ ctypedef pfft_plan (*pfft_plan_func) (
int rnk_n, numpy.intp_t *n, void * input, void * output,
cMPI.MPI_Comm ccart,
int sign, unsigned pfft_flags)
cdef pfft_plan_func PFFT_PLAN_FUNC [8]

ctypedef pfft_plan (*pfft_plan_func_r2r) (
int rnk_n, numpy.intp_t *n, void * input, void * output,
cMPI.MPI_Comm ccart,
int * kinds, unsigned pfft_flags)

cdef void * PFFT_PLAN_FUNC [8]

PFFT_PLAN_FUNC[:] = [
<pfft_plan_func> pfft_plan_dft,
<pfft_plan_func> pfft_plan_dft_r2c,
<pfft_plan_func> pfft_plan_dft_c2r,
<pfft_plan_func> pfft_plan_r2r,
<pfft_plan_func> pfftf_plan_dft,
<pfft_plan_func> pfftf_plan_dft_r2c,
<pfft_plan_func> pfftf_plan_dft_c2r,
<pfft_plan_func> pfftf_plan_r2r,
<void*> pfft_plan_dft,
<void*> pfft_plan_dft_r2c,
<void*> pfft_plan_dft_c2r,
<void*> pfft_plan_r2r,
<void*> pfftf_plan_dft,
<void*> pfftf_plan_dft_r2c,
<void*> pfftf_plan_dft_c2r,
<void*> pfftf_plan_r2r,
]

ctypedef void (*pfft_free_plan_func) (void * plan)
Expand Down Expand Up @@ -868,7 +879,9 @@ cdef class Plan(object):

self.flags = Flags(flags)

cdef pfft_plan_func func = PFFT_PLAN_FUNC[self.type]
cdef pfft_plan_func plan_func = <pfft_plan_func> PFFT_PLAN_FUNC[self.type]
cdef pfft_plan_func_r2r plan_func_r2r = <pfft_plan_func_r2r> PFFT_PLAN_FUNC[self.type]

cdef numpy.intp_t [::1] n_ = numpy.array(n, dtype='intp')
if o is None:
o = i
Expand All @@ -885,10 +898,23 @@ cdef class Plan(object):
raise NotImplementedError("out place non-padded r2c / c2r does not preserve input.(%s) " % repr(self.flags)
+ "Provide PFFT_DESTROY_INPUT as a flag and deal with this quirk.")

self.plan = func(n_.shape[0], &n_[0], i.ptr, o.ptr,
procmesh.ccart,
self.direction,
flags)
cdef int [::1] kinds = numpy.zeros(len(n), dtype='int32')

if direction == Direction.FORWARD:
kinds[...] = _FFTW_R2HC
else:
kinds[...] = _FFTW_HC2R

if self.type in (Type.R2R, Type.R2RF):
self.plan = plan_func_r2r(n_.shape[0], &n_[0], i.ptr, o.ptr,
procmesh.ccart,
&kinds[0],
flags)
else:
self.plan = plan_func(n_.shape[0], &n_[0], i.ptr, o.ptr,
procmesh.ccart,
self.direction,
flags)
if not self.plan:
raise ValueError("Plan is not created")

Expand Down
14 changes: 14 additions & 0 deletions pfft/tests/test_pfft.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,20 @@ def test_leak(comm):
#FIXME: check with @mpip if this is correct.
i = buffer.view_input()

@MPITest([1, 4])
def test_no_segfault_r2r(comm):
procmesh = pfft.ProcMesh(np=[comm.size], comm=comm)

partition = pfft.Partition(pfft.Type.PFFT_R2R, [32, 32],
procmesh, flags=pfft.Flags.PFFT_ESTIMATE)

buffer1 = pfft.LocalBuffer(partition)
buffer2 = pfft.LocalBuffer(partition)

plan = pfft.Plan(partition, pfft.Direction.PFFT_FORWARD, buffer1, buffer2)
plan.execute(buffer1, buffer2)


@MPITest([4])
def test_2d_on_2d_c2c(comm):
procmesh = pfft.ProcMesh(np=[2, 2], comm=comm)
Expand Down

0 comments on commit aaf907a

Please sign in to comment.