Skip to content

Commit

Permalink
Merge pull request #614 from open-AIMS/rsa-large-samples
Browse files Browse the repository at this point in the history
Support RSA with large sample sizes
  • Loading branch information
Rosejoycrocker authored Dec 4, 2023
2 parents f21337f + 83028fa commit 97ece67
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions src/analysis/sensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ Get quantile value for a given categorical variable.
- `factor_name` : Contains true where the factor is categorical and false otherwise
- `steps` : Number of steps for defining bins
"""
function _get_cat_quantile(foi_spec::DataFrame, factor_name::Symbol, steps::Vector{Float64})
fact_idx = foi_spec.fieldname .== factor_name
function _get_cat_quantile(foi_spec::DataFrame, factor_name::Symbol, steps::Vector{Float64})::Vector{Float64}
fact_idx::BitVector = foi_spec.fieldname .== factor_name
lb = foi_spec.lower_bound[fact_idx][1]
ub = foi_spec.upper_bound[fact_idx][1]

Expand Down Expand Up @@ -392,7 +392,8 @@ function tsa(rs::ResultSet, y::AbstractMatrix{<:Real})::NamedDimsArray
end

"""
rsa(X::DataFrame, y::Vector{<:Real}, model_spec::DataFrame; S=10)::NamedDimsArray
rsa(X::DataFrame, y::Vector{<:Real}, model_spec::DataFrame; S::Int64=10)::NamedDimsArray
rsa(rs::ResultSet, y::AbstractVector{<:Real}, factors::Vector{Symbol}; S::Int64=10)::NamedDimsArray
rsa(rs::ResultSet, y::AbstractArray{<:Real}; S::Int64=10)::NamedDimsArray
Perform Regional Sensitivity Analysis.
Expand Down Expand Up @@ -424,6 +425,7 @@ Note: Values of type `missing` indicate a lack of samples in the region.
- `X` : scenario specification
- `y` : scenario outcomes
- `model_spec` : Model specification, as extracted by `ADRIA.model_spec(domain)` or from a `ResultSet`
- `factors` : Specific model factors to examine
- `S` : number of bins to slice factor space into (default: 10)
# Returns
Expand Down Expand Up @@ -454,7 +456,7 @@ function rsa(
)::NamedDimsArray
N, D = size(X)

X_di = @MVector zeros(N)
X_di = zeros(N)
sel = trues(N)
factors = Symbol.(names(X))

Expand All @@ -465,7 +467,7 @@ function rsa(
S = _category_bins(S, foi_spec[is_cat, :])
end

X_q = @MVector zeros(S + 1)
X_q = zeros(S + 1)
r_s = zeros(Union{Missing,Float64}, S, D)
seq = collect(0.0:(1 / S):1.0)

Expand Down Expand Up @@ -508,7 +510,17 @@ end
function rsa(
rs::ResultSet, y::AbstractVector{<:Real}; S::Int64=10
)::NamedDimsArray
return rsa(rs.inputs[:, Not(:RCP)], vec(y), rs.model_spec; S=S)
return rsa(rs.inputs[!, Not(:RCP)], y, rs.model_spec; S=S)
end
function rsa(
rs::ResultSet, y::AbstractVector{<:Real}, factors::Vector{Symbol}; S::Int64=10
)::NamedDimsArray
return rsa(
rs.inputs[!, Not(:RCP)][!, factors],
y,
rs.model_spec[rs.model_spec.fieldname .∈ [factors], :];
S=S
)
end


Expand Down

0 comments on commit 97ece67

Please sign in to comment.