Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for CPUThreads and CPUProceses #39

Merged
merged 2 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 8 additions & 22 deletions test/ensembles.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,3 @@
module TestEnsembles

using Test
using Random
using StableRNGs
using MLJEnsembles
using MLJBase
using ..Models
using CategoricalArrays
import Distributions
using StatisticalMeasures

## HELPER FUNCTIONS

@test MLJEnsembles._reducer([1, 2], [3, ]) == [1, 2, 3]
Expand Down Expand Up @@ -187,10 +175,10 @@ predict(ensemble_model, fitresult, MLJEnsembles.selectrows(X, test))

@testset "further test of sample weights" begin
## Note: This testset also indirectly tests for compatibility with the data-front end
# implemented by `KNNClassifier` as calls to `fit`/`predict` on an `Ensemble` model
# implemented by `KNNClassifier` as calls to `fit`/`predict` on an `Ensemble` model
# with `atom=KNNClassifier` would error if the ensemble implementation doesn't handle
# data front-end conversions properly.

rng = StableRNG(123)
N = 20
X = (x = rand(rng, 3N), );
Expand Down Expand Up @@ -224,18 +212,18 @@ predict(ensemble_model, fitresult, MLJEnsembles.selectrows(X, test))
end


## MACHINE TEST
## MACHINE TEST
## (INCLUDES TEST OF UPDATE.
## ALSO INCLUDES COMPATIBILITY TESTS FOR ENSEMBLES WITH ATOM MODELS HAVING A
## ALSO INCLUDES COMPATIBILITY TESTS FOR ENSEMBLES WITH ATOM MODELS HAVING A
## DIFFERENT DATA FRONT-END SEE #16)

@testset "machine tests" begin
@testset_accelerated "machine tests" acceleration begin
N =100
X = (x1=rand(N), x2=rand(N), x3=rand(N))
y = 2X.x1 - X.x2 + 0.05*rand(N)

atom = KNNRegressor(K=7)
ensemble_model = EnsembleModel(model=atom)
ensemble_model = EnsembleModel(; model=atom, acceleration)
ensemble = machine(ensemble_model, X, y)
train, test = partition(eachindex(y), 0.7)
fit!(ensemble, rows=train, verbosity=0)
Expand Down Expand Up @@ -264,15 +252,13 @@ end
atom;
bagging_fraction=0.6,
rng=123,
out_of_bag_measure = [log_loss, brier_score]
out_of_bag_measure = [log_loss, brier_score],
acceleration,
)
ensemble = machine(ensemble_model, X_, y_)
fit!(ensemble)
@test length(ensemble.fitresult.ensemble) == ensemble_model.n

end


end

true
26 changes: 24 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,26 @@
include("_models.jl")
using Distributed
# Thanks to https://stackoverflow.com/a/70895939/5056635 for the exeflags tip.
addprocs(; exeflags="--project=$(Base.active_project())")

@info "nprocs() = $(nprocs())"
import .Threads
@info "nthreads() = $(Threads.nthreads())"

include("test_utilities.jl")
include_everywhere("_models.jl")

@everywhere begin
using Test
using Random
using StableRNGs
using MLJEnsembles
using MLJBase
using ..Models
using CategoricalArrays
import Distributions
using StatisticalMeasures
import Distributed
end

include("ensembles.jl")
include("serialization.jl")

50 changes: 50 additions & 0 deletions test/test_utilities.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
using Test

using ComputationalResources

macro testset_accelerated(name::String, var, ex)
testset_accelerated(name, var, ex)
end
macro testset_accelerated(name::String, var, opts::Expr, ex)
testset_accelerated(name, var, ex; eval(opts)...)
end
function testset_accelerated(name::String, var, ex; exclude=[])
final_ex = quote
local $var = CPU1()
@testset $name $ex
end

resources = AbstractResource[CPUProcesses(), CPUThreads()]

for res in resources
if any(x->typeof(res)<:x, exclude)
push!(final_ex.args, quote
local $var = $res
@testset $(name*" ($(typeof(res).name))") begin
@test_broken false
end
end)
else
push!(final_ex.args, quote
local $var = $res
@testset $(name*" ($(typeof(res).name))") $ex
end)
end
end
# preserve outer location if possible
if ex isa Expr && ex.head === :block && !isempty(ex.args) &&
ex.args[1] isa LineNumberNode
final_ex = Expr(:block, ex.args[1], final_ex)
end
return esc(final_ex)
end

function include_everywhere(filepath)
include(filepath) # Load on Node 1 first, triggering any precompile
if nprocs() > 1
fullpath = joinpath(@__DIR__, filepath)
@sync for p in workers()
@async remotecall_wait(include, p, fullpath)
end
end
end
Loading