Skip to content

Commit

Permalink
Merge pull request #899 from open-AIMS/remove-duplicate-rules
Browse files Browse the repository at this point in the history
Remove duplicate rules
  • Loading branch information
Zapiano authored Nov 26, 2024
2 parents d219e57 + 90a452b commit 646615a
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 14 deletions.
102 changes: 90 additions & 12 deletions src/analysis/rule_extraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ past of the 'positive class' given that the condition is true or false.
struct Rule{V<:Vector{Vector},W<:Vector{Float64}}
    condition::V
    consequent::W

    # Inner constructor: store the condition subclauses ordered by factor name so
    # that rules built from the same subclauses always compare consistently.
    function Rule(condition, consequent)
        ordered_condition = sort(condition; by=first)
        return new{typeof(condition),typeof(consequent)}(ordered_condition, consequent)
    end
end

"""
Expand Down Expand Up @@ -103,18 +109,21 @@ function print_rules(rules::Vector{Rule{Vector{Vector},Vector{Float64}}})::Nothi
end

"""
cluster_rules(clusters::Vector{T}, X::DataFrame, max_rules::T; seed::Int64=123, kwargs...) where {T<:Integer,F<:Real}
cluster_rules(clusters::Union{BitVector,Vector{Bool}}, X::DataFrame, max_rules::T; kwargs...) where {T<:Int64}
cluster_rules(clusters::Vector{T}, X::DataFrame, max_rules::T; seed::Int64=123, remove_duplicates::Bool=true, kwargs...)::Vector{Rule{Vector{Vector},Vector{Float64}}} where {T<:Int64}
cluster_rules(clusters::Union{BitVector,Vector{Bool}}, X::DataFrame, max_rules::T; seed::Int64=123, remove_duplicates::Bool=true, kwargs...)::Vector{Rule{Vector{Vector},Vector{Float64}}} where {T<:Int64}
Use SIRUS package to extract rules from time series clusters based on some summary metric
(default is median). More information about the keyword arguments accepted can be found in
MLJ's doc (https://juliaai.github.io/MLJ.jl/dev/models/StableRulesClassifier_SIRUS/).
# Arguments
- `clusters` : Vector of cluster indexes for each scenario outcome
- `X` : Features to be used as input by SIRUS
- `max_rules` : Maximum number of rules, to be used as input by SIRUS
- `seed` : Seed to be used by RGN
- `clusters` : Vector of cluster indexes for each scenario outcome.
- `X` : Features to be used as input by SIRUS.
- `max_rules` : Maximum number of rules, to be used as input by SIRUS.
- `seed` : Seed to be used by RNG. Defaults to 123.
- `remove_duplicates` : If true, duplicate rules will be removed from resulting ruleset. In
that case, the rule with the highest probability score is kept. Defaults to true.
- `kwargs` : Keyword arguments to be passed to `StableRulesClassifier`.
# Returns
A StableRules object (implemented by SIRUS).
Expand All @@ -129,8 +138,10 @@ A StableRules object (implemented by SIRUS).
Electron. J. Statist. 15 (1) 427 - 505.
https://doi.org//10.1214/20-EJS1792
"""
function cluster_rules(clusters::Vector{T}, X::DataFrame, max_rules::T;
seed::Int64=123, kwargs...) where {T<:Int64}
function cluster_rules(
clusters::Vector{T}, X::DataFrame, max_rules::T;
seed::Int64=123, remove_duplicates::Bool=true, kwargs...
)::Vector{Rule{Vector{Vector},Vector{Float64}}} where {T<:Int64}
# Set seed and Random Number Generator
rng = StableRNG(seed)

Expand All @@ -148,11 +159,78 @@ function cluster_rules(clusters::Vector{T}, X::DataFrame, max_rules::T;
error("Failed fitting SIRUS. Try increasing the number of scenarios/samples.")
end

return rules(mach.fitresult)
if remove_duplicates
return _remove_duplicates(rules(mach.fitresult))
else
return rules(mach.fitresult)
end
end
# Boolean cluster labels are promoted to Int64 so the integer-typed method applies.
function cluster_rules(
    clusters::Union{BitVector,Vector{Bool}}, X::DataFrame, max_rules::T;
    seed::Int64=123, remove_duplicates::Bool=true, kwargs...
)::Vector{Rule{Vector{Vector},Vector{Float64}}} where {T<:Int64}
    int_clusters = Int64.(clusters)
    return cluster_rules(
        int_clusters, X, max_rules;
        seed=seed, remove_duplicates=remove_duplicates, kwargs...
    )
end

"""
    _remove_duplicates(rules::T)::T where {T<:Vector{Rule{Vector{Vector},Vector{Float64}}}}

Identify and remove duplicate rules (if any are found).

Two rules are considered duplicates when they share the same condition subclauses
(factor name and direction), ignoring threshold values. Among duplicates, the rule
with the highest consequent probability is kept; if more than one rule has the same
highest probability, the first one is chosen.

# Arguments
- `rules` : Ruleset possibly containing duplicate rules.

# Returns
A ruleset with duplicate rules removed.
"""
function _remove_duplicates(
    rules::T
)::T where {T<:Vector{Rule{Vector{Vector},Vector{Float64}}}}
    # Fingerprint each rule by its subclauses with threshold values stripped.
    subclauses = join.([_strip_value.(r.condition) for r in rules], "_&_")
    unique_subclauses = unique(subclauses)

    # Fast path: no duplicates found, return the input ruleset unchanged.
    n_rules = length(rules)
    n_unique_rules = length(unique_subclauses)
    if n_unique_rules == n_rules
        return rules
    end

    n_duplicates = n_rules - n_unique_rules
    @warn "$n_duplicates of $n_rules duplicated rules were found and are going to be removed."

    # Use a concretely-typed container (eltype(rules)) instead of the abstract
    # Vector{Rule}: avoids boxing and the implicit conversion to the declared
    # return type T on return.
    unique_rules = Vector{eltype(rules)}(undef, n_unique_rules)
    for (unique_idx, unique_subclause) in enumerate(unique_subclauses)
        duplicate_rules_filter = unique_subclause .== subclauses

        # If current rule has no duplicates, keep it and go to next iteration.
        if sum(duplicate_rules_filter) == 1
            unique_rules[unique_idx] = rules[duplicate_rules_filter][1]
            continue
        end

        # Keep the duplicate with the highest consequent probability
        # (findmax returns the first index in case of ties).
        duplicate_rules = rules[duplicate_rules_filter]
        max_probability_idx = findmax([r.consequent[1] for r in duplicate_rules])[2]
        unique_rules[unique_idx] = duplicate_rules[max_probability_idx]
    end

    return unique_rules
end
function cluster_rules(clusters::Union{BitVector,Vector{Bool}}, X::DataFrame, max_rules::T;
kwargs...) where {T<:Int64}
return cluster_rules(convert.(Int64, clusters), X, max_rules; kwargs...)

"""
    _strip_value(condition_subclause::Vector)

Return a string identifier combining the factor name and direction (the first two
elements) of a rule condition subclause, ignoring its threshold value. Kept as a
separate one-line function to allow/facilitate broadcasting this operation.
"""
function _strip_value(condition_subclause::Vector)
    return string(condition_subclause[1], "__", condition_subclause[2])
end

"""
Expand Down
20 changes: 18 additions & 2 deletions test/analysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,15 @@ function test_rs_w_fig(rs::ADRIA.ResultSet, scens::ADRIA.DataFrame)
scenarios_iv = scens[:, fields_iv]

# Use SIRUS algorithm to extract rules
max_rules = 4
rules_iv = ADRIA.analysis.cluster_rules(target_clusters, scenarios_iv, max_rules)
max_rules = 10
rules_iv = ADRIA.analysis.cluster_rules(
target_clusters, scenarios_iv, max_rules; remove_duplicates=true
)
rules_iv_duplicates = ADRIA.analysis.cluster_rules(
target_clusters, scenarios_iv, max_rules; remove_duplicates=false
)
ADRIA.analysis.print_rules(rules_iv)
ADRIA.analysis.print_rules(rules_iv_duplicates)

# Plot scatters for each rule highlighting the area selected them
rules_scatter_fig = ADRIA.viz.rules_scatter(
Expand All @@ -259,6 +266,15 @@ function test_rs_w_fig(rs::ADRIA.ResultSet, scens::ADRIA.DataFrame)
opts=opts
)

ADRIA.viz.rules_scatter(
rs,
scenarios_iv,
target_clusters,
rules_iv_duplicates;
fig_opts=fig_opts,
opts=opts
)

# Save final figure
# save("rules_scatter.png", rules_scatter_fig)

Expand Down

0 comments on commit 646615a

Please sign in to comment.