diff --git a/src/ParallelKernel/allocators.jl b/src/ParallelKernel/allocators.jl index 033b9d04..a07ab3a6 100644 --- a/src/ParallelKernel/allocators.jl +++ b/src/ParallelKernel/allocators.jl @@ -43,7 +43,7 @@ macro rand_threads(args...) check_initialized(); esc(_rand(args...; package=PKG function _zeros(args...; package::Symbol=get_package()) numbertype = get_numbertype() if (package == PKG_CUDA) return :(CUDA.zeros($numbertype, $(args...))) - elseif (package == PKG_THREADS) return :(Base.zeros($numbertype, $(args...))) + elseif (package == PKG_THREADS) return :(ParallelStencil.ParallelKernel._parallel_init(Base.zero, $numbertype, $(args...))) else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).") end end @@ -51,7 +51,7 @@ end function _ones(args...; package::Symbol=get_package()) numbertype = get_numbertype() if (package == PKG_CUDA) return :(CUDA.ones($numbertype, $(args...))) - elseif (package == PKG_THREADS) return :(Base.ones($numbertype, $(args...))) + elseif (package == PKG_THREADS) return :(ParallelStencil.ParallelKernel._parallel_init(Base.one, $numbertype, $(args...))) else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).") end end @@ -59,7 +59,15 @@ end function _rand(args...; package::Symbol=get_package()) numbertype = get_numbertype() if (package == PKG_CUDA) return :(CUDA.CuArray(rand($numbertype, $(args...)))) - elseif (package == PKG_THREADS) return :(Base.rand($numbertype, $(args...))) + elseif (package == PKG_THREADS) return :(ParallelStencil.ParallelKernel._parallel_init(Base.rand, $numbertype, $(args...))) else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).") end end + +function _parallel_init(f::F, numbertype::Type{T}, args...) where {F, T} + arr = Array{numbertype, length(args)}(undef, args...) + Threads.@threads :static for i in eachindex(arr) + @inbounds arr[i] = f(T) + end + return arr +end