Skip to content

Commit

Permalink
added symetry test
Browse files Browse the repository at this point in the history
  • Loading branch information
behrica committed Feb 5, 2025
1 parent 549ccab commit 22076f8
Showing 1 changed file with 55 additions and 23 deletions.
78 changes: 55 additions & 23 deletions test/scicloj/ml/tribuo_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,24 @@
(defn make-species-column [datatype categorical?
inference-target?
species-key->val-map]
(let [
meta
(let [meta
{:categorical? categorical?
:name :species
:datatype datatype
:n-elems 150
:inference-target? inference-target?
}]
:inference-target? inference-target?}]
(ds/new-column :species (map species-key->val-map iris-target-raw) meta)))


(defn make-iris-ds [species-column result-datatype]

(->
(assoc (data/iris-ds) :species species-column)
((fn [ds]
(if (some? result-datatype)
(ds/categorical->number ds [:species] [] result-datatype)
ds)))

(ds-mod/set-inference-target :species)))

(defn- validate [ds expected-target-val expected-accuracy]
Expand Down Expand Up @@ -212,22 +210,21 @@
0.65))


(t/deftest not-supported
(t/is (thrown? Exception
(validate
(make-iris-ds
(make-species-column :string
true ;categorical?
true ;inference-target?
{:setosa 1
:versicolor "12"
:virginica "a2a"})
nil)
"12"
0.94)))
)
(t/deftest not-supported
(t/is (thrown? Exception
(validate
(make-iris-ds
(make-species-column :string
true ;categorical?
true ;inference-target?
{:setosa 1
:versicolor "12"
:virginica "a2a"})
nil)
"12"
0.94))))

(ds/new-column :x [:a :b :c])
(ds/new-column :x [:a :b :c])

(defn- verify-evaluate [ds]
(let [make-pipefn (fn [opts]
Expand Down Expand Up @@ -265,14 +262,14 @@
(org.tribuo.Model/deserializeFromStream)
class
.getName)))

(t/is (= "org.tribuo.common.tree.TreeModel"
(->
evaluations flatten first :fit-ctx :model
(ml/thaw-model)
class
.getName)))

(t/is (= 0.8048780487804879
(->> evaluations
flatten
Expand All @@ -284,3 +281,38 @@
(t/deftest sonar-evaluate-2
(verify-evaluate (-> (data/sonar-ds)
(ds/categorical->number [:material]))))


(defn- validate-target-symetry [datatype]
(t/is (= datatype
(->>
(ml/train
(-> (ds/->dataset {:x [1 2 3 4]
:y [:a :b :c :d]})
(ds/categorical->number [:y] [] datatype)
(ds-mod/set-inference-target [:y]))
{:model-type :scicloj.ml.tribuo/classification
:tribuo-components [{:name "trainer"
:type "org.tribuo.classification.dtree.CARTClassificationTrainer"
:properties {:maxDepth "8"}}]
:tribuo-trainer-name "trainer"})
(ml/predict
(-> (ds/->dataset {:x [1 2 3 4]})))
:y
meta
:datatype))))


(t/deftest validate-target-sym
(validate-target-symetry :int8)
(validate-target-symetry :int16)
(validate-target-symetry :int32)
(validate-target-symetry :int64)
(validate-target-symetry :float32)
(validate-target-symetry :float64))






0 comments on commit 22076f8

Please sign in to comment.