diff --git a/src/analysis/sensitivity.jl b/src/analysis/sensitivity.jl index 038da0b55..db6c96788 100644 --- a/src/analysis/sensitivity.jl +++ b/src/analysis/sensitivity.jl @@ -71,8 +71,8 @@ Get quantile value for a given categorical variable. - `factor_name` : Contains true where the factor is categorical and false otherwise - `steps` : Number of steps for defining bins """ -function _get_cat_quantile(foi_spec::DataFrame, factor_name::Symbol, steps::Vector{Float64}) - fact_idx = foi_spec.fieldname .== factor_name +function _get_cat_quantile(foi_spec::DataFrame, factor_name::Symbol, steps::Vector{Float64})::Vector{Float64} + fact_idx::BitVector = foi_spec.fieldname .== factor_name lb = foi_spec.lower_bound[fact_idx][1] ub = foi_spec.upper_bound[fact_idx][1] @@ -392,7 +392,8 @@ function tsa(rs::ResultSet, y::AbstractMatrix{<:Real})::NamedDimsArray end """ - rsa(X::DataFrame, y::Vector{<:Real}, model_spec::DataFrame; S=10)::NamedDimsArray + rsa(X::DataFrame, y::Vector{<:Real}, model_spec::DataFrame; S::Int64=10)::NamedDimsArray + rsa(rs::ResultSet, y::AbstractVector{<:Real}, factors::Vector{Symbol}; S::Int64=10)::NamedDimsArray rsa(rs::ResultSet, y::AbstractArray{<:Real}; S::Int64=10)::NamedDimsArray Perform Regional Sensitivity Analysis. @@ -424,6 +425,7 @@ Note: Values of type `missing` indicate a lack of samples in the region. - `X` : scenario specification - `y` : scenario outcomes - `model_spec` : Model specification, as extracted by `ADRIA.model_spec(domain)` or from a `ResultSet` +- `factors` : Specific model factors to examine - `S` : number of bins to slice factor space into (default: 10) # Returns @@ -454,7 +456,7 @@ function rsa( )::NamedDimsArray N, D = size(X) - X_di = @MVector zeros(N) + X_di = zeros(N) sel = trues(N) factors = Symbol.(names(X)) @@ -465,7 +467,7 @@ function rsa( S = _category_bins(S, foi_spec[is_cat, :]) end - X_q = @MVector zeros(S + 1) + X_q = zeros(S + 1) r_s = zeros(Union{Missing,Float64}, S, D) seq = collect(0.0:(1 / S):1.0) @@ -508,7 +510,17 @@ end function rsa( rs::ResultSet, y::AbstractVector{<:Real}; S::Int64=10 )::NamedDimsArray - return rsa(rs.inputs[:, Not(:RCP)], vec(y), rs.model_spec; S=S) + return rsa(rs.inputs[!, Not(:RCP)], y, rs.model_spec; S=S) +end +function rsa( + rs::ResultSet, y::AbstractVector{<:Real}, factors::Vector{Symbol}; S::Int64=10 +)::NamedDimsArray + return rsa( + rs.inputs[!, Not(:RCP)][!, factors], + y, + rs.model_spec[rs.model_spec.fieldname .∈ [factors], :]; + S=S + ) end