From 71fc5d5bba2c19f312b3a6de75d82d81e06aaeb1 Mon Sep 17 00:00:00 2001 From: sriharshakandala Date: Thu, 24 Oct 2024 11:24:33 -0700 Subject: [PATCH] Use single kernel for pointwise functions --- Project.toml | 1 + examples/Manifest.toml | 62 ++-- .../microphysics/microphysics_wrappers.jl | 275 +++++++++++------- 3 files changed, 208 insertions(+), 130 deletions(-) diff --git a/Project.toml b/Project.toml index bef81424e3..3b7fd8efea 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63" ArtifactWrappers = "a14bc488-3040-4b00-9dc1-f6467924858a" Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" AtmosphericProfilesLibrary = "86bc3604-9858-485a-bdbe-831ec50de11d" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" ClimaCore = "d414da3d-4745-48bb-8d80-42e94e092884" ClimaDiagnostics = "1ecacbb8-0713-4841-9a07-eb5aa8a2d53f" diff --git a/examples/Manifest.toml b/examples/Manifest.toml index 9ef3bf6853..0e566f3deb 100644 --- a/examples/Manifest.toml +++ b/examples/Manifest.toml @@ -57,9 +57,9 @@ version = "0.1.38" [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" +git-tree-sha1 = "d80af0733c99ea80575f612813fa6aa71022d33a" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "4.0.4" +version = "4.1.0" weakdeps = ["StaticArrays"] [deps.Adapt.extensions] @@ -210,9 +210,9 @@ version = "1.21.6+0" [[deps.Bzip2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "9e2a6b69137e6969bab0152632dcb3bc108c8bdd" +git-tree-sha1 = "8873e196c2eb87962a2048b3b8e08946535864a1" uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" -version = "1.0.8+1" +version = "1.0.8+2" [[deps.CEnum]] git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" @@ -319,7 +319,7 @@ version = "0.5.7" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" [[deps.ClimaAtmos]] -deps = ["Adapt", "ArgParse", "ArtifactWrappers", "Artifacts", 
"AtmosphericProfilesLibrary", "ClimaComms", "ClimaCore", "ClimaDiagnostics", "ClimaParams", "ClimaTimeSteppers", "ClimaUtilities", "CloudMicrophysics", "Dates", "DiffEqBase", "FastGaussQuadrature", "Insolation", "Interpolations", "LazyArtifacts", "LinearAlgebra", "Logging", "NCDatasets", "NVTX", "RRTMGP", "Random", "SciMLBase", "StaticArrays", "Statistics", "SurfaceFluxes", "Thermodynamics", "UnrolledUtilities", "YAML"] +deps = ["Adapt", "ArgParse", "ArtifactWrappers", "Artifacts", "AtmosphericProfilesLibrary", "CUDA", "ClimaComms", "ClimaCore", "ClimaDiagnostics", "ClimaParams", "ClimaTimeSteppers", "ClimaUtilities", "CloudMicrophysics", "Dates", "DiffEqBase", "FastGaussQuadrature", "Insolation", "Interpolations", "LazyArtifacts", "LinearAlgebra", "Logging", "NCDatasets", "NVTX", "RRTMGP", "Random", "SciMLBase", "StaticArrays", "Statistics", "SurfaceFluxes", "Thermodynamics", "UnrolledUtilities", "YAML"] path = ".." uuid = "b2c96348-7fb7-4fe0-8da9-78d88439e717" version = "0.27.7" @@ -383,9 +383,9 @@ version = "0.2.8" [[deps.ClimaParams]] deps = ["TOML"] -git-tree-sha1 = "ca82603622e2df9dbf14716e43589b76e9705840" +git-tree-sha1 = "489c5655993c62fb34293908a6b0877e32f183ee" uuid = "5c42b081-d73a-476f-9059-fd94b934656c" -version = "0.10.15" +version = "0.10.16" [[deps.ClimaReproducibilityTests]] deps = ["OrderedCollections", "PrettyTables"] @@ -411,9 +411,9 @@ version = "0.7.38" [[deps.ClimaUtilities]] deps = ["Artifacts", "Dates"] -git-tree-sha1 = "24bd6d5066404af09215c372d0ffea56ed849206" +git-tree-sha1 = "9b783e099151e9e14c1063e736135145dcfda451" uuid = "b3f4f4ca-9299-4f7f-bd9b-81e1242a7513" -version = "0.1.15" +version = "0.1.16" weakdeps = ["Adapt", "CUDA", "ClimaComms", "ClimaCore", "ClimaCoreTempestRemap", "Interpolations", "NCDatasets"] [deps.ClimaUtilities.extensions] @@ -464,9 +464,9 @@ version = "0.4.0" [[deps.ColorSchemes]] deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] -git-tree-sha1 = 
"b5278586822443594ff615963b0c09755771b3e0" +git-tree-sha1 = "13951eb68769ad1cd460cdb2e64e5e95f1bf123d" uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.26.0" +version = "3.27.0" [[deps.ColorTypes]] deps = ["FixedPointNumbers", "Random"] @@ -602,9 +602,9 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" [[deps.DelaunayTriangulation]] deps = ["AdaptivePredicates", "EnumX", "ExactPredicates", "PrecompileTools", "Random"] -git-tree-sha1 = "668bb97ea6df5e654e6288d87d2243591fe68665" +git-tree-sha1 = "89df54fbe66e5872d91d8c2cd3a375f660c3fd64" uuid = "927a84f5-c5f4-47a5-9785-b46e178433df" -version = "1.6.0" +version = "1.6.1" [[deps.DelimitedFiles]] deps = ["Mmap"] @@ -1179,9 +1179,9 @@ version = "1.0.0" [[deps.JET]] deps = ["CodeTracking", "InteractiveUtils", "JuliaInterpreter", "LoweredCodeUtils", "MacroTools", "Pkg", "PrecompileTools", "Preferences", "Test"] -git-tree-sha1 = "b2cb92e1fa8c1f33b1eb997e195dca442f53440b" +git-tree-sha1 = "5c5ac91e775b585864015c5c1703cee283071a47" uuid = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" -version = "0.9.11" +version = "0.9.12" [deps.JET.extensions] JETCthulhuExt = "Cthulhu" @@ -1193,9 +1193,9 @@ version = "0.9.11" [[deps.JLD2]] deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"] -git-tree-sha1 = "aeab5c68eb2cf326619bf71235d8f4561c62fe22" +git-tree-sha1 = "783c1be5213a09609b23237a0c9e5dfd258ae6f2" uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" -version = "0.5.5" +version = "0.5.7" [[deps.JLLWrappers]] deps = ["Artifacts", "Preferences"] @@ -1280,10 +1280,10 @@ uuid = "88015f11-f218-50d7-93a8-a6af411a945d" version = "4.0.0+0" [[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] -git-tree-sha1 = "4ad43cb0a4bb5e5b1506e1d1f48646d7e0c80363" +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Unicode"] +git-tree-sha1 = "d422dfd9707bec6617335dc2ea3c5172a87d5908" uuid = 
"929cbde3-209d-540e-8aea-75f648917ca0" -version = "9.1.2" +version = "9.1.3" weakdeps = ["BFloat16s"] [deps.LLVM.extensions] @@ -2063,10 +2063,10 @@ uuid = "94e857df-77ce-4151-89e5-788b33177be4" version = "0.1.0" [[deps.SciMLBase]] -deps = ["ADTypes", "Accessors", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "Expronicon", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "SciMLStructures", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] -git-tree-sha1 = "50ed64cd5ad79b0bef71fdb6a11d10c3448bfef0" +deps = ["ADTypes", "Accessors", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "Expronicon", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "SciMLStructures", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface"] +git-tree-sha1 = "26fea1911818cd480400f1a2b7f6b32c3cc3836a" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -version = "2.56.1" +version = "2.56.4" [deps.SciMLBase.extensions] SciMLBaseChainRulesCoreExt = "ChainRulesCore" @@ -2089,9 +2089,9 @@ version = "2.56.1" [[deps.SciMLOperators]] deps = ["Accessors", "ArrayInterface", "DocStringExtensions", "LinearAlgebra", "MacroTools"] -git-tree-sha1 = "e39c5f217f9aca640c8e27ab21acf557a3967db5" +git-tree-sha1 = "ef388ca9e4921ec5614ce714f8aa59a5cd33d867" uuid = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" -version = "0.3.10" +version = "0.3.11" weakdeps = ["SparseArrays", "StaticArraysCore"] [deps.SciMLOperators.extensions] @@ -2219,9 +2219,9 @@ weakdeps = ["OffsetArrays", "StaticArrays"] [[deps.StaticArrays]] deps = 
["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50" +git-tree-sha1 = "777657803913ffc7e8cc20f0fd04b634f871af8f" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.7" +version = "1.9.8" weakdeps = ["ChainRulesCore", "Statistics"] [deps.StaticArrays.extensions] @@ -2373,9 +2373,9 @@ version = "1.0.2" [[deps.Thermodynamics]] deps = ["DocStringExtensions", "KernelAbstractions", "Random", "RootSolvers"] -git-tree-sha1 = "8c2afc6dbb2bdac698a5b05816b7521630dea034" +git-tree-sha1 = "5de9f9f6019165cedb04e365a9f277a518ac5aaf" uuid = "b60c26fb-14c3-4610-9d3e-2d17fe7ff00c" -version = "0.12.8" +version = "0.12.9" weakdeps = ["ClimaParams"] [deps.Thermodynamics.extensions] @@ -2445,9 +2445,9 @@ uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8" version = "0.1.5" [[deps.UnrolledUtilities]] -git-tree-sha1 = "1244bd810bfd53c539ef5463b01dd9b808c2e5da" +git-tree-sha1 = "5caf11dfadeee25daafa7caabb3f252a977ffe72" uuid = "0fe1646c-419e-43be-ac14-22321958931b" -version = "0.1.5" +version = "0.1.6" weakdeps = ["StaticArrays"] [deps.UnrolledUtilities.extensions] diff --git a/src/parameterized_tendencies/microphysics/microphysics_wrappers.jl b/src/parameterized_tendencies/microphysics/microphysics_wrappers.jl index ac118ef08d..836501413a 100644 --- a/src/parameterized_tendencies/microphysics/microphysics_wrappers.jl +++ b/src/parameterized_tendencies/microphysics/microphysics_wrappers.jl @@ -6,6 +6,8 @@ import CloudMicrophysics.Microphysics1M as CM1 import CloudMicrophysics.Microphysics2M as CM2 import CloudMicrophysics.MicrophysicsNonEq as CMNe import CloudMicrophysics.Parameters as CMP +using CUDA +using NVTX # define some aliases and functions to make the code more readable const Iₗ = TD.internal_energy_liquid @@ -136,7 +138,7 @@ The specific humidity source terms are defined as defined as Δmᵢ / (m_dry + m where i stands for total, rain or snow. 
Also returns the total energy source term due to the microphysics processes. """ -function compute_precipitation_sources!( +NVTX.@annotate function compute_precipitation_sources!( Sᵖ, Sᵖ_snow, Sqₜᵖ, @@ -153,108 +155,183 @@ function compute_precipitation_sources!( thp, ) FT = eltype(thp) - # @. Sqₜᵖ = FT(0) should work after fixing - # https://github.com/CliMA/ClimaCore.jl/issues/1786 - @. Sqₜᵖ = ρ * FT(0) - @. Sqᵣᵖ = ρ * FT(0) - @. Sqₛᵖ = ρ * FT(0) - @. Seₜᵖ = ρ * FT(0) - - #! format: off - # rain autoconversion: q_liq -> q_rain - @. Sᵖ = ifelse( - mp.Ndp <= 0, - CM1.conv_q_liq_to_q_rai(mp.pr.acnv1M, qₗ(thp, ts), true), - CM2.conv_q_liq_to_q_rai(mp.var, qₗ(thp, ts), ρ, mp.Ndp), - ) - @. Sᵖ = min(limit(qₗ(thp, ts), dt, 5), Sᵖ) - @. Sqₜᵖ -= Sᵖ - @. Sqᵣᵖ += Sᵖ - @. Seₜᵖ -= Sᵖ * (Iₗ(thp, ts) + Φ) - - # snow autoconversion assuming no supersaturation: q_ice -> q_snow - @. Sᵖ = min( - limit(qᵢ(thp, ts), dt, 5), - CM1.conv_q_ice_to_q_sno_no_supersat(mp.ps.acnv1M, qᵢ(thp, ts), true), - ) - @. Sqₜᵖ -= Sᵖ - @. Sqₛᵖ += Sᵖ - @. Seₜᵖ -= Sᵖ * (Iᵢ(thp, ts) + Φ) - - # accretion: q_liq + q_rain -> q_rain - @. Sᵖ = min( - limit(qₗ(thp, ts), dt, 5), - CM1.accretion(mp.cl, mp.pr, mp.tv.rain, mp.ce, qₗ(thp, ts), qᵣ, ρ), + device = ClimaComms.device(Sᵖ) + dims = Base.size(Fields.field_values(Sᵖ)) + fvargs = + Fields.field_values.(( + Sᵖ, + Sᵖ_snow, + Sqₜᵖ, + Sqᵣᵖ, + Sqₛᵖ, + Seₜᵖ, + ρ, + qᵣ, + qₛ, + ts, + Φ, + )) + args = (dt, mp, thp) + pointwise_dispatch( + device, + dims, + compute_precipitation_sources_kernel!, + fvargs..., + args..., ) - @. Sqₜᵖ -= Sᵖ - @. Sqᵣᵖ += Sᵖ - @. Seₜᵖ -= Sᵖ * (Iₗ(thp, ts) + Φ) - - # accretion: q_ice + q_snow -> q_snow - @. Sᵖ = min( - limit(qᵢ(thp, ts), dt, 5), - CM1.accretion(mp.ci, mp.ps, mp.tv.snow, mp.ce, qᵢ(thp, ts), qₛ, ρ), - ) - @. Sqₜᵖ -= Sᵖ - @. Sqₛᵖ += Sᵖ - @. Seₜᵖ -= Sᵖ * (Iᵢ(thp, ts) + Φ) + return nothing +end - # accretion: q_liq + q_sno -> q_sno or q_rai - # sink of cloud water via accretion cloud water + snow - @. 
"""
    compute_precipitation_sources_kernel!(
        Sᵖ, Sᵖ_snow, Sqₜᵖ, Sqᵣᵖ, Sqₛᵖ, Seₜᵖ,
        ρ, qᵣ, qₛ, ts, Φ, dt, mp, thp, idx,
    )

Pointwise version of the 1-moment precipitation source computation: evaluates
all source terms for the single point `idx` and stores them in the output
arrays.

Outputs (written at `idx`):
 - `Sqₜᵖ`, `Sqᵣᵖ`, `Sqₛᵖ`, `Seₜᵖ` - accumulated total water, rain, snow and
   energy sources (reset to zero before accumulation)
 - `Sᵖ`, `Sᵖ_snow` - scratch outputs; on exit they hold the last individual
   source term that was computed (rain-snow accretion and the snow part of
   the cloud water-snow accretion, respectively)

Inputs (read at `idx`): `ρ`, `qᵣ`, `qₛ`, `ts`, `Φ`.
Scalars / parameter structs: `dt`, `mp`, `thp`.

All inputs are loaded once and the sources are accumulated in local variables,
so every output array is written exactly once per point. The order of the
floating point operations is unchanged.
NOTE(review): this assumes that the input arrays do not alias the output
arrays and that all source terms have the element type of the outputs -
confirm against the callers.

Always returns `nothing`.
"""
@inline function compute_precipitation_sources_kernel!(
    Sᵖ,
    Sᵖ_snow,
    Sqₜᵖ,
    Sqᵣᵖ,
    Sqₛᵖ,
    Seₜᵖ,
    ρ,
    qᵣ,
    qₛ,
    ts,
    Φ,
    dt,
    mp,
    thp,
    idx,
)
    FT = eltype(thp)
    @inbounds begin
        # load the inputs for this point once
        ts_p = ts[idx]
        ρ_p = ρ[idx]
        Φ_p = Φ[idx]
        q_rai = qᵣ[idx]
        q_sno = qₛ[idx]

        # thermodynamic quantities that are reused by several processes
        q_liq = qₗ(thp, ts_p)
        q_ice = qᵢ(thp, ts_p)
        I_liq = Iₗ(thp, ts_p)
        I_ice = Iᵢ(thp, ts_p)
        L_f = Lf(thp, ts_p)
        T = Tₐ(thp, ts_p)
        T_frz = mp.ps.T_freeze

        # local accumulators with the element type of the output arrays
        S_tot = zero(Sqₜᵖ[idx])
        S_rai = zero(Sqᵣᵖ[idx])
        S_sno = zero(Sqₛᵖ[idx])
        S_e = zero(Seₜᵖ[idx])

        #! format: off
        # rain autoconversion: q_liq -> q_rain
        S = ifelse(
            mp.Ndp <= 0,
            CM1.conv_q_liq_to_q_rai(mp.pr.acnv1M, q_liq, true),
            CM2.conv_q_liq_to_q_rai(mp.var, q_liq, ρ_p, mp.Ndp),
        )
        S = min(limit(q_liq, dt, 5), S)
        S_tot -= S
        S_rai += S
        S_e -= S * (I_liq + Φ_p)

        # snow autoconversion assuming no supersaturation: q_ice -> q_snow
        S = min(
            limit(q_ice, dt, 5),
            CM1.conv_q_ice_to_q_sno_no_supersat(mp.ps.acnv1M, q_ice, true),
        )
        S_tot -= S
        S_sno += S
        S_e -= S * (I_ice + Φ_p)

        # accretion: q_liq + q_rain -> q_rain
        S = min(
            limit(q_liq, dt, 5),
            CM1.accretion(mp.cl, mp.pr, mp.tv.rain, mp.ce, q_liq, q_rai, ρ_p),
        )
        S_tot -= S
        S_rai += S
        S_e -= S * (I_liq + Φ_p)

        # accretion: q_ice + q_snow -> q_snow
        S = min(
            limit(q_ice, dt, 5),
            CM1.accretion(mp.ci, mp.ps, mp.tv.snow, mp.ce, q_ice, q_sno, ρ_p),
        )
        S_tot -= S
        S_sno += S
        S_e -= S * (I_ice + Φ_p)

        # accretion: q_liq + q_sno -> q_sno or q_rai
        # sink of cloud water via accretion cloud water + snow
        S = min(
            limit(q_liq, dt, 5),
            CM1.accretion(mp.cl, mp.ps, mp.tv.snow, mp.ce, q_liq, q_sno, ρ_p),
        )
        # if T < T_freeze cloud droplets freeze to become snow
        # else the snow melts and both cloud water and snow become rain
        α = cᵥₗ(thp) / L_f * (T - T_frz)
        S_snow = ifelse(
            T < T_frz,
            S,
            FT(-1) * min(S * α, limit(q_sno, dt, 5)),
        )
        S_sno += S_snow
        S_tot -= S
        S_rai += ifelse(T < T_frz, FT(0), S - S_snow)
        S_e -= ifelse(
            T < T_frz,
            S * (I_ice + Φ_p),
            S * (I_liq + Φ_p) - S_snow * (I_liq - I_ice),
        )

        # accretion: q_ice + q_rai -> q_sno
        S = min(
            limit(q_ice, dt, 5),
            CM1.accretion(mp.ci, mp.pr, mp.tv.rain, mp.ce, q_ice, q_rai, ρ_p),
        )
        S_tot -= S
        S_sno += S
        S_e -= S * (I_ice + Φ_p)
        # sink of rain via accretion cloud ice - rain
        S = min(
            limit(q_rai, dt, 5),
            CM1.accretion_rain_sink(mp.pr, mp.ci, mp.tv.rain, mp.ce, q_ice, q_rai, ρ_p),
        )
        S_rai -= S
        S_sno += S
        S_e += S * L_f

        # accretion: q_rai + q_sno -> q_rai or q_sno
        S = ifelse(
            T < T_frz,
            min(
                limit(q_rai, dt, 5),
                CM1.accretion_snow_rain(mp.ps, mp.pr, mp.tv.rain, mp.tv.snow, mp.ce, q_sno, q_rai, ρ_p),
            ),
            -min(
                limit(q_sno, dt, 5),
                CM1.accretion_snow_rain(mp.pr, mp.ps, mp.tv.snow, mp.tv.rain, mp.ce, q_rai, q_sno, ρ_p),
            ),
        )
        S_sno += S
        S_rai -= S
        S_e += S * L_f
        #! format: on

        # single store per output array
        Sᵖ[idx] = S
        Sᵖ_snow[idx] = S_snow
        Sqₜᵖ[idx] = S_tot
        Sqᵣᵖ[idx] = S_rai
        Sqₛᵖ[idx] = S_sno
        Seₜᵖ[idx] = S_e
    end
    return nothing
end
"""
    pointwise_dispatch(device, dims, pointwisefn!, args...; max_threads = 768)

Call `pointwisefn!(args..., idx)` once for every point of a data layout of
size `dims = (NI, NJ, _, NV, NH)`, where `idx` is
`CartesianIndex(i, j, 1, v, h)`.

On a `ClimaComms.CUDADevice` all points are processed by a single CUDA kernel
launch. Each thread block covers the `NI × NJ` nodal points of one horizontal
element for a chunk of vertical levels:

 - threads: `(NI, NJ, nvthreads)`, where `nvthreads` is limited by
   `max_threads ÷ (NI * NJ)`, by the number of levels `NV` and by the CUDA
   limit of 64 threads in the z-dimension of a block
 - blocks: `(NH, nvblocks)`; `NH` is mapped to the x-dimension of the grid
   because the y-dimension is limited to 65535 blocks

`max_threads` is the maximum number of threads per block. It must not exceed
the limit supported by the device and the kernel.

Nothing is launched when any of the extents is zero. Returns `nothing`.
"""
function pointwise_dispatch(
    device::ClimaComms.CUDADevice,
    dims,
    pointwisefn!,
    args...;
    max_threads = 768,
)
    NI, NJ, _, NV, NH = dims
    # launching a kernel with an empty grid or block is an error
    iszero(NI * NJ * NV * NH) && return nothing
    @assert NI * NJ ≤ max_threads "NI * NJ = $(NI * NJ) exceeds max_threads = $max_threads"
    # at most 64 threads are allowed in the z-dimension of a block, and
    # threads beyond NV would only idle
    nvthreads = Int(clamp(fld(max_threads, NI * NJ), 1, min(NV, 64)))
    nvblocks = Int(cld(NV, nvthreads))
    CUDA.@cuda(
        always_inline = true,
        threads = (NI, NJ, nvthreads),
        blocks = (NH, nvblocks),
        pointwise_cuda_kernel!(pointwisefn!, NV, args...)
    )
    return nothing
end

"""
    pointwise_dispatch(device::ClimaComms.AbstractCPUDevice, dims, pointwisefn!, args...)

CPU fallback: visits the same indices as the CUDA method in a serial loop.
Keyword arguments (e.g. `max_threads`) are accepted and ignored.
NOTE(review): assumes that `args` can be indexed with a `CartesianIndex` on the
CPU in the same way as inside the CUDA kernel - confirm.
"""
function pointwise_dispatch(
    device::ClimaComms.AbstractCPUDevice,
    dims,
    pointwisefn!,
    args...;
    kwargs...,
)
    NI, NJ, _, NV, NH = dims
    for h in 1:NH, v in 1:NV, j in 1:NJ, i in 1:NI
        pointwisefn!(args..., CartesianIndex(i, j, 1, v, h))
    end
    return nothing
end

"""
    pointwise_cuda_kernel!(pointwisefn!, NV, args...)

CUDA kernel launched by `pointwise_dispatch`. Maps the thread and block indices
to the index `CartesianIndex(i, j, 1, v, h)` and calls
`pointwisefn!(args..., idx)` for it:

 - `threadIdx().x`, `threadIdx().y` - nodal indices `i`, `j`
 - `threadIdx().z`, `blockIdx().y` - vertical level `v`
 - `blockIdx().x` - horizontal element `h`

Threads with `v > NV` (last vertical block) do nothing.
"""
function pointwise_cuda_kernel!(pointwisefn!, NV, args...)
    tid = threadIdx()
    bid = blockIdx()
    # global vertical level handled by this thread
    v = tid.z + (bid.y - 1) * blockDim().z
    if v ≤ NV
        idx = CartesianIndex(tid.x, tid.y, 1, v, bid.x)
        pointwisefn!(args..., idx)
    end
    return nothing
end

"""