-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
remove generics from model interfaces and results to make manipulatio…
…n much clearer, generics are now hidden inside abstract base classes
- Loading branch information
Showing
41 changed files
with
180 additions
and
185 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -51,9 +51,7 @@ | |
* <p> | ||
* User: Aurelian Tutuianu <[email protected]> | ||
*/ | ||
public class AdaBoostSAMME | ||
extends AbstractClassifierModel<AdaBoostSAMME, ClassifierResult<AdaBoostSAMME>> | ||
implements Printable { | ||
public class AdaBoostSAMME extends AbstractClassifierModel<AdaBoostSAMME, ClassifierResult> implements Printable { | ||
|
||
private static final long serialVersionUID = -9154973036108114765L; | ||
private static final double delta_error = 10e-10; | ||
|
@@ -199,8 +197,8 @@ private boolean learnRound(Frame df) { | |
} | ||
|
||
@Override | ||
protected ClassifierResult<AdaBoostSAMME> corePredict(Frame df, boolean withClasses, boolean withDistributions) { | ||
ClassifierResult<AdaBoostSAMME> fit = ClassifierResult.build(this, df, withClasses, true); | ||
protected ClassifierResult corePredict(Frame df, boolean withClasses, boolean withDistributions) { | ||
ClassifierResult fit = ClassifierResult.build(this, df, withClasses, true); | ||
for (int i = 0; i < h.size(); i++) { | ||
ClassifierResult hp = h.get(i).predict(df, true, false); | ||
for (int j = 0; j < df.rowCount(); j++) { | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -52,7 +52,7 @@ | |
* Created by <a href="mailto:[email protected]">Aurelian Tutuianu</a> at 12/12/14. | ||
*/ | ||
public class GBTClassifierModel | ||
extends AbstractClassifierModel<GBTClassifierModel, ClassifierResult<GBTClassifierModel>> | ||
extends AbstractClassifierModel<GBTClassifierModel, ClassifierResult> | ||
implements Printable { | ||
|
||
private static final long serialVersionUID = -2979235364091072967L; | ||
|
@@ -191,22 +191,22 @@ private void buildAdditionalTree(Frame df, Var w, DMatrix yk) { | |
tree.fit(train, sample.weights, "##tt##"); | ||
trees.get(k).add(tree); | ||
|
||
RegressionResult<RTree> rr = tree.predict(df, false); | ||
RegressionResult rr = tree.predict(df, false); | ||
for (int i = 0; i < df.rowCount(); i++) { | ||
f.set(k, i, f.get(k, i) + shrinkage * rr.firstPrediction().getDouble(i)); | ||
} | ||
} | ||
} | ||
|
||
@Override | ||
public ClassifierResult<GBTClassifierModel> corePredict(Frame df, boolean withClasses, boolean withDistributions) { | ||
ClassifierResult<GBTClassifierModel> cr = ClassifierResult.build(this, df, withClasses, withDistributions); | ||
public ClassifierResult corePredict(Frame df, boolean withClasses, boolean withDistributions) { | ||
ClassifierResult cr = ClassifierResult.build(this, df, withClasses, withDistributions); | ||
|
||
DMatrix p_f = SolidDMatrix.empty(K, df.rowCount()); | ||
|
||
for (int k = 0; k < K; k++) { | ||
for (RTree tree : trees.get(k)) { | ||
RegressionResult<RTree> rr = tree.predict(df, false); | ||
RegressionResult rr = tree.predict(df, false); | ||
for (int i = 0; i < df.rowCount(); i++) { | ||
p_f.set(k, i, p_f.get(k, i) + shrinkage * rr.firstPrediction().getDouble(i)); | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -54,7 +54,6 @@ | |
import rapaio.ml.common.Capabilities; | ||
import rapaio.ml.common.VarSelector; | ||
import rapaio.ml.eval.metric.Confusion; | ||
import rapaio.printer.Printable; | ||
import rapaio.printer.Printer; | ||
import rapaio.printer.opt.POption; | ||
import rapaio.util.Pair; | ||
|
@@ -75,9 +74,7 @@ | |
* <p> | ||
* Created by <a href="mailto:[email protected]">Aurelian Tutuianu</a> on 4/16/15. | ||
*/ | ||
public class CForest | ||
extends AbstractClassifierModel<CForest, ClassifierResult<CForest>> | ||
implements Printable { | ||
public class CForest extends AbstractClassifierModel<CForest, ClassifierResult> { | ||
|
||
private static final long serialVersionUID = -145958939373105497L; | ||
|
||
|
@@ -462,12 +459,12 @@ private Pair<ClassifierModel, VarInt> buildWeakPredictor(Frame df, Var weights) | |
} | ||
|
||
@Override | ||
protected ClassifierResult<CForest> corePredict(Frame df, boolean withClasses, boolean withDensities) { | ||
ClassifierResult<CForest> cp = ClassifierResult.build(this, df, true, true); | ||
var treeFits = predictors.stream().parallel() | ||
.map(pred -> pred.predict(df, baggingMode.needsClass(), baggingMode.needsDensity())) | ||
protected ClassifierResult corePredict(Frame df, boolean withClasses, boolean withDensities) { | ||
ClassifierResult cp = ClassifierResult.build(this, df, true, true); | ||
List<ClassifierResult> treeFits = predictors.stream().parallel() | ||
.map(pred -> (ClassifierResult) pred.predict(df, baggingMode.needsClass(), baggingMode.needsDensity())) | ||
.collect(Collectors.toList()); | ||
baggingMode.computeDensity(firstTargetLevels(), new ArrayList<>(treeFits), cp.firstClasses(), cp.firstDensity()); | ||
baggingMode.computeDensity(firstTargetLevels(), treeFits, cp.firstClasses(), cp.firstDensity()); | ||
return cp; | ||
} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,7 +53,7 @@ | |
* <p> | ||
* Created by <a href="mailto:[email protected]">Aurelian Tutuianu</a> on 9/30/15. | ||
*/ | ||
public class CBinaryLogisticStacking extends AbstractClassifierModel<CBinaryLogisticStacking, ClassifierResult<CBinaryLogisticStacking>> implements Printable { | ||
public class CBinaryLogisticStacking extends AbstractClassifierModel<CBinaryLogisticStacking, ClassifierResult> implements Printable { | ||
|
||
private static final long serialVersionUID = -9087871586729573030L; | ||
|
||
|
@@ -164,7 +164,7 @@ protected PredSetup preparePredict(Frame df, boolean withClasses, boolean withDi | |
} | ||
|
||
@Override | ||
protected ClassifierResult<CBinaryLogisticStacking> corePredict(Frame df, boolean withClasses, boolean withDistributions) { | ||
protected ClassifierResult corePredict(Frame df, boolean withClasses, boolean withDistributions) { | ||
return ClassifierResult.copy(this, df, withClasses, withDistributions, log.predict(df)); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,7 +53,7 @@ | |
* <p> | ||
* Created by <a href="mailto:[email protected]">Aurelian Tutuianu</a> on 9/30/15. | ||
*/ | ||
public class CStacking extends AbstractClassifierModel<CStacking, ClassifierResult<CStacking>> implements Printable { | ||
public class CStacking extends AbstractClassifierModel<CStacking, ClassifierResult> implements Printable { | ||
|
||
private static final long serialVersionUID = -9087871586729573030L; | ||
|
||
|
@@ -155,7 +155,7 @@ protected PredSetup baseFit(Frame df, boolean withClasses, boolean withDistribut | |
} | ||
|
||
@Override | ||
protected ClassifierResult<CStacking> corePredict(Frame df, boolean withClasses, boolean withDistributions) { | ||
protected ClassifierResult corePredict(Frame df, boolean withClasses, boolean withDistributions) { | ||
return ClassifierResult.copy(this, df, withClasses, withDistributions, stacker.predict(df)); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,7 +53,7 @@ | |
* Created by <a href="mailto:[email protected]">Aurelian Tutuianu</a> on 11/11/15. | ||
*/ | ||
public class CStepwiseSelection | ||
extends AbstractClassifierModel<CStepwiseSelection, ClassifierResult<CStepwiseSelection>> implements Printable { | ||
extends AbstractClassifierModel<CStepwiseSelection, ClassifierResult> implements Printable { | ||
|
||
private static final long serialVersionUID = 2642562123626893974L; | ||
ClassifierModel best; | ||
|
@@ -224,7 +224,7 @@ protected boolean coreFit(Frame df, Var weights) { | |
} | ||
|
||
@Override | ||
protected ClassifierResult<CStepwiseSelection> corePredict(Frame df, boolean withClasses, boolean withDistributions) { | ||
protected ClassifierResult corePredict(Frame df, boolean withClasses, boolean withDistributions) { | ||
return ClassifierResult.copy(this, df, withClasses, withDistributions, best.predict(df, withClasses, withDistributions)); | ||
} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -68,7 +68,7 @@ | |
* | ||
* @author <a href="mailto:[email protected]">Aurelian Tutuianu</a> | ||
*/ | ||
public class CTree extends AbstractClassifierModel<CTree, ClassifierResult<CTree>> implements Printable { | ||
public class CTree extends AbstractClassifierModel<CTree, ClassifierResult> implements Printable { | ||
|
||
private static final long serialVersionUID = 1203926824359387358L; | ||
private static final Map<VType, CTreeTest> DEFAULT_TEST_MAP; | ||
|
@@ -500,8 +500,8 @@ private void buildIndexMap(CTreeNode node, HashMap<Integer, Integer> indexMap) { | |
} | ||
|
||
@Override | ||
protected ClassifierResult<CTree> corePredict(Frame df, boolean withClasses, boolean withDensities) { | ||
ClassifierResult<CTree> prediction = ClassifierResult.build(this, df, withClasses, withDensities); | ||
protected ClassifierResult corePredict(Frame df, boolean withClasses, boolean withDensities) { | ||
ClassifierResult prediction = ClassifierResult.build(this, df, withClasses, withDensities); | ||
for (int i = 0; i < df.rowCount(); i++) { | ||
Pair<String, DensityVector<String>> res = predictPoint(this, root, i, df); | ||
String label = res._1; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,7 +45,6 @@ | |
import rapaio.ml.regression.RegressionResult; | ||
import rapaio.ml.regression.simple.L2RegressionModel; | ||
import rapaio.ml.regression.tree.RTree; | ||
import rapaio.printer.Printable; | ||
import rapaio.printer.Printer; | ||
import rapaio.printer.opt.POption; | ||
|
||
|
@@ -61,8 +60,7 @@ | |
* User: Aurelian Tutuianu <[email protected]> | ||
*/ | ||
@Deprecated | ||
public class GBTRegressionModel extends AbstractRegressionModel<GBTRegressionModel, RegressionResult<GBTRegressionModel>> | ||
implements Printable { | ||
public class GBTRegressionModel extends AbstractRegressionModel<GBTRegressionModel, RegressionResult> { | ||
|
||
private static final long serialVersionUID = 4559540258922653130L; | ||
|
||
|
@@ -197,9 +195,9 @@ protected boolean coreFit(Frame df, Var weights) { | |
} | ||
|
||
@Override | ||
protected RegressionResult<GBTRegressionModel> corePredict(final Frame df, final boolean withResiduals) { | ||
RegressionResult<GBTRegressionModel> pred = RegressionResult.build(this, df, withResiduals); | ||
RegressionResult<GBTRegressionModel> initPred = initRegressionModel.predict(df, false); | ||
protected RegressionResult corePredict(final Frame df, final boolean withResiduals) { | ||
RegressionResult pred = RegressionResult.build(this, df, withResiduals); | ||
RegressionResult initPred = initRegressionModel.predict(df, false); | ||
for (int i = 0; i < df.rowCount(); i++) { | ||
pred.firstPrediction().setDouble(i, initPred.firstPrediction().getDouble(i)); | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,7 +37,6 @@ | |
import rapaio.ml.regression.RegressionModel; | ||
import rapaio.ml.regression.RegressionResult; | ||
import rapaio.ml.regression.tree.RTree; | ||
import rapaio.printer.Printable; | ||
import rapaio.printer.Printer; | ||
import rapaio.printer.opt.POption; | ||
|
||
|
@@ -50,8 +49,7 @@ | |
/** | ||
* Created by <a href="mailto:[email protected]">Aurelian Tutuianu</a> at 1/15/15. | ||
*/ | ||
public class RForest extends AbstractRegressionModel<RForest, RegressionResult<RForest>> | ||
implements Printable { | ||
public class RForest extends AbstractRegressionModel<RForest, RegressionResult> { | ||
|
||
private static final long serialVersionUID = -3926256335736143438L; | ||
|
||
|
@@ -129,8 +127,8 @@ public List<RegressionModel> getRegressors() { | |
} | ||
|
||
@Override | ||
protected RegressionResult<RForest> corePredict(Frame df, boolean withResiduals) { | ||
RegressionResult<RForest> fit = RegressionResult.build(this, df, withResiduals); | ||
protected RegressionResult corePredict(Frame df, boolean withResiduals) { | ||
RegressionResult fit = RegressionResult.build(this, df, withResiduals); | ||
List<VarDouble> results = regressors | ||
.parallelStream() | ||
.map(r -> r.predict(df, false).firstPrediction()) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,7 +36,7 @@ | |
/** | ||
* Created by <a href="mailto:[email protected]">Aurelian Tutuianu</a> on 5/21/19. | ||
*/ | ||
public interface GBTRtree<M extends RegressionModel<M, R>, R extends RegressionResult<M>> extends RegressionModel<M, R> { | ||
public interface GBTRtree<M extends RegressionModel, R extends RegressionResult> extends RegressionModel { | ||
|
||
void boostUpdate(Frame x, Var y, Var fx, GBTRegressionLoss lossFunction); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,8 +47,8 @@ | |
/** | ||
* Created by <a href="mailto:[email protected]">Aurelian Tutuianu</a> on 4/16/19. | ||
*/ | ||
public class NestedBoostingRTree extends AbstractRegressionModel<NestedBoostingRTree, RegressionResult<NestedBoostingRTree>> | ||
implements GBTRtree<NestedBoostingRTree, RegressionResult<NestedBoostingRTree>> { | ||
public class NestedBoostingRTree extends AbstractRegressionModel<NestedBoostingRTree, RegressionResult> | ||
implements GBTRtree<NestedBoostingRTree, RegressionResult> { | ||
|
||
private static final long serialVersionUID = 1864784340491461993L; | ||
private int minCount = 5; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,8 +48,8 @@ | |
/** | ||
* Created by <a href="mailto:[email protected]">Aurelian Tutuianu</a> on 6/19/19. | ||
*/ | ||
public class SmoothRTree extends AbstractRegressionModel<SmoothRTree, RegressionResult<SmoothRTree>> | ||
implements GBTRtree<SmoothRTree, RegressionResult<SmoothRTree>> { | ||
public class SmoothRTree extends AbstractRegressionModel<SmoothRTree, RegressionResult> | ||
implements GBTRtree<SmoothRTree, RegressionResult> { | ||
|
||
private static final long serialVersionUID = 5062591010395009141L; | ||
|
||
|
@@ -168,8 +168,8 @@ protected boolean coreFit(Frame df, Var weights) { | |
} | ||
|
||
@Override | ||
protected RegressionResult<SmoothRTree> corePredict(Frame df, boolean withResiduals) { | ||
RegressionResult<SmoothRTree> prediction = RegressionResult.build(this, df, withResiduals); | ||
protected RegressionResult corePredict(Frame df, boolean withResiduals) { | ||
RegressionResult prediction = RegressionResult.build(this, df, withResiduals); | ||
for (int i = 0; i < df.rowCount(); i++) { | ||
prediction.firstPrediction().setDouble(i, root.predict(df, i, this, 1.0)); | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -52,8 +52,7 @@ | |
* | ||
* @author <a href="mailto:[email protected]">Aurelian Tutuianu</a> | ||
*/ | ||
public abstract class AbstractClassifierModel<M extends ClassifierModel<M, R>, R extends ClassifierResult<M>> | ||
implements ClassifierModel<M, R> { | ||
public abstract class AbstractClassifierModel<M extends ClassifierModel, R extends ClassifierResult> implements ClassifierModel { | ||
|
||
private static final long serialVersionUID = -6866948033065091047L; | ||
|
||
|
@@ -122,18 +121,18 @@ public BiConsumer<M, Integer> runningHook() { | |
} | ||
|
||
@Override | ||
public M withRunningHook(BiConsumer<M, Integer> runningHook) { | ||
this.runningHook = runningHook; | ||
return (M) this; | ||
public <T extends ClassifierModel> T withRunningHook(BiConsumer<? extends ClassifierModel, Integer> runningHook) { | ||
this.runningHook = (BiConsumer<M, Integer>) runningHook; | ||
return (T)this; | ||
} | ||
|
||
public BiFunction<M, Integer, Boolean> stoppingHook() { | ||
return stoppingHook; | ||
} | ||
|
||
public M withStoppingHook(BiFunction<M, Integer, Boolean> stoppingHook) { | ||
this.stoppingHook = stoppingHook; | ||
return (M) this; | ||
public <T extends ClassifierModel> T withStoppingHook(BiFunction<? extends ClassifierModel, Integer, Boolean> stoppingHook) { | ||
this.stoppingHook = (BiFunction<M, Integer, Boolean>) stoppingHook; | ||
return (T) this; | ||
} | ||
|
||
@Override | ||
|
Oops, something went wrong.