Skip to content

Commit

Permalink
remove generics from model interfaces and results to make manipulatio…
Browse files Browse the repository at this point in the history
…n much clearer, generics are now hidden inside abstract base classes
  • Loading branch information
padreati committed May 30, 2020
1 parent ca375e0 commit 0e07d53
Show file tree
Hide file tree
Showing 41 changed files with 180 additions and 185 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ notebooks or to document the idea you are working on. You have to install jupyte
on that you can follow the instruction from [here](https://github.com/SpencerPark/IJava#installing). The following notation is
specific to IJava kernel jupyter notation.

%mavenRepo oss-sonatype-snapshots https://oss.sonatype.org/content/repositories/snapshots/
%maven io.github.padreati:rapaio:2.2.2

The last option to use the library is do download the release files from this repository. If you use IntelliJ Idea IDE, you can use
Expand Down
8 changes: 3 additions & 5 deletions src/rapaio/experiment/ml/classifier/boost/AdaBoostSAMME.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@
* <p>
* User: Aurelian Tutuianu <[email protected]>
*/
public class AdaBoostSAMME
extends AbstractClassifierModel<AdaBoostSAMME, ClassifierResult<AdaBoostSAMME>>
implements Printable {
public class AdaBoostSAMME extends AbstractClassifierModel<AdaBoostSAMME, ClassifierResult> implements Printable {

private static final long serialVersionUID = -9154973036108114765L;
private static final double delta_error = 10e-10;
Expand Down Expand Up @@ -199,8 +197,8 @@ private boolean learnRound(Frame df) {
}

@Override
protected ClassifierResult<AdaBoostSAMME> corePredict(Frame df, boolean withClasses, boolean withDistributions) {
ClassifierResult<AdaBoostSAMME> fit = ClassifierResult.build(this, df, withClasses, true);
protected ClassifierResult corePredict(Frame df, boolean withClasses, boolean withDistributions) {
ClassifierResult fit = ClassifierResult.build(this, df, withClasses, true);
for (int i = 0; i < h.size(); i++) {
ClassifierResult hp = h.get(i).predict(df, true, false);
for (int j = 0; j < df.rowCount(); j++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
* Created by <a href="mailto:[email protected]">Aurelian Tutuianu</a> at 12/12/14.
*/
public class GBTClassifierModel
extends AbstractClassifierModel<GBTClassifierModel, ClassifierResult<GBTClassifierModel>>
extends AbstractClassifierModel<GBTClassifierModel, ClassifierResult>
implements Printable {

private static final long serialVersionUID = -2979235364091072967L;
Expand Down Expand Up @@ -191,22 +191,22 @@ private void buildAdditionalTree(Frame df, Var w, DMatrix yk) {
tree.fit(train, sample.weights, "##tt##");
trees.get(k).add(tree);

RegressionResult<RTree> rr = tree.predict(df, false);
RegressionResult rr = tree.predict(df, false);
for (int i = 0; i < df.rowCount(); i++) {
f.set(k, i, f.get(k, i) + shrinkage * rr.firstPrediction().getDouble(i));
}
}
}

@Override
public ClassifierResult<GBTClassifierModel> corePredict(Frame df, boolean withClasses, boolean withDistributions) {
ClassifierResult<GBTClassifierModel> cr = ClassifierResult.build(this, df, withClasses, withDistributions);
public ClassifierResult corePredict(Frame df, boolean withClasses, boolean withDistributions) {
ClassifierResult cr = ClassifierResult.build(this, df, withClasses, withDistributions);

DMatrix p_f = SolidDMatrix.empty(K, df.rowCount());

for (int k = 0; k < K; k++) {
for (RTree tree : trees.get(k)) {
RegressionResult<RTree> rr = tree.predict(df, false);
RegressionResult rr = tree.predict(df, false);
for (int i = 0; i < df.rowCount(); i++) {
p_f.set(k, i, p_f.get(k, i) + shrinkage * rr.firstPrediction().getDouble(i));
}
Expand Down
15 changes: 6 additions & 9 deletions src/rapaio/experiment/ml/classifier/ensemble/CForest.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
import rapaio.ml.common.Capabilities;
import rapaio.ml.common.VarSelector;
import rapaio.ml.eval.metric.Confusion;
import rapaio.printer.Printable;
import rapaio.printer.Printer;
import rapaio.printer.opt.POption;
import rapaio.util.Pair;
Expand All @@ -75,9 +74,7 @@
* <p>
* Created by <a href="mailto:[email protected]">Aurelian Tutuianu</a> on 4/16/15.
*/
public class CForest
extends AbstractClassifierModel<CForest, ClassifierResult<CForest>>
implements Printable {
public class CForest extends AbstractClassifierModel<CForest, ClassifierResult> {

private static final long serialVersionUID = -145958939373105497L;

Expand Down Expand Up @@ -462,12 +459,12 @@ private Pair<ClassifierModel, VarInt> buildWeakPredictor(Frame df, Var weights)
}

@Override
protected ClassifierResult<CForest> corePredict(Frame df, boolean withClasses, boolean withDensities) {
ClassifierResult<CForest> cp = ClassifierResult.build(this, df, true, true);
var treeFits = predictors.stream().parallel()
.map(pred -> pred.predict(df, baggingMode.needsClass(), baggingMode.needsDensity()))
protected ClassifierResult corePredict(Frame df, boolean withClasses, boolean withDensities) {
ClassifierResult cp = ClassifierResult.build(this, df, true, true);
List<ClassifierResult> treeFits = predictors.stream().parallel()
.map(pred -> (ClassifierResult) pred.predict(df, baggingMode.needsClass(), baggingMode.needsDensity()))
.collect(Collectors.toList());
baggingMode.computeDensity(firstTargetLevels(), new ArrayList<>(treeFits), cp.firstClasses(), cp.firstDensity());
baggingMode.computeDensity(firstTargetLevels(), treeFits, cp.firstClasses(), cp.firstDensity());
return cp;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
* <p>
* Created by <a href="mailto:[email protected]">Aurelian Tutuianu</a> on 9/30/15.
*/
public class CBinaryLogisticStacking extends AbstractClassifierModel<CBinaryLogisticStacking, ClassifierResult<CBinaryLogisticStacking>> implements Printable {
public class CBinaryLogisticStacking extends AbstractClassifierModel<CBinaryLogisticStacking, ClassifierResult> implements Printable {

private static final long serialVersionUID = -9087871586729573030L;

Expand Down Expand Up @@ -164,7 +164,7 @@ protected PredSetup preparePredict(Frame df, boolean withClasses, boolean withDi
}

@Override
protected ClassifierResult<CBinaryLogisticStacking> corePredict(Frame df, boolean withClasses, boolean withDistributions) {
protected ClassifierResult corePredict(Frame df, boolean withClasses, boolean withDistributions) {
return ClassifierResult.copy(this, df, withClasses, withDistributions, log.predict(df));
}
}
4 changes: 2 additions & 2 deletions src/rapaio/experiment/ml/classifier/meta/CStacking.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
* <p>
* Created by <a href="mailto:[email protected]">Aurelian Tutuianu</a> on 9/30/15.
*/
public class CStacking extends AbstractClassifierModel<CStacking, ClassifierResult<CStacking>> implements Printable {
public class CStacking extends AbstractClassifierModel<CStacking, ClassifierResult> implements Printable {

private static final long serialVersionUID = -9087871586729573030L;

Expand Down Expand Up @@ -155,7 +155,7 @@ protected PredSetup baseFit(Frame df, boolean withClasses, boolean withDistribut
}

@Override
protected ClassifierResult<CStacking> corePredict(Frame df, boolean withClasses, boolean withDistributions) {
protected ClassifierResult corePredict(Frame df, boolean withClasses, boolean withDistributions) {
return ClassifierResult.copy(this, df, withClasses, withDistributions, stacker.predict(df));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
* Created by <a href="mailto:[email protected]">Aurelian Tutuianu</a> on 11/11/15.
*/
public class CStepwiseSelection
extends AbstractClassifierModel<CStepwiseSelection, ClassifierResult<CStepwiseSelection>> implements Printable {
extends AbstractClassifierModel<CStepwiseSelection, ClassifierResult> implements Printable {

private static final long serialVersionUID = 2642562123626893974L;
ClassifierModel best;
Expand Down Expand Up @@ -224,7 +224,7 @@ protected boolean coreFit(Frame df, Var weights) {
}

@Override
protected ClassifierResult<CStepwiseSelection> corePredict(Frame df, boolean withClasses, boolean withDistributions) {
protected ClassifierResult corePredict(Frame df, boolean withClasses, boolean withDistributions) {
return ClassifierResult.copy(this, df, withClasses, withDistributions, best.predict(df, withClasses, withDistributions));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
*/
@Deprecated
public class SplitClassifierModel
extends AbstractClassifierModel<SplitClassifierModel, ClassifierResult<SplitClassifierModel>> implements Printable {
extends AbstractClassifierModel<SplitClassifierModel, ClassifierResult> implements Printable {

private static final long serialVersionUID = 3332377951136731541L;

Expand Down Expand Up @@ -126,15 +126,15 @@ public boolean coreFit(Frame df, Var weights) {
}

@Override
public ClassifierResult<SplitClassifierModel> corePredict(Frame df, boolean withClasses, boolean withDensities) {
public ClassifierResult corePredict(Frame df, boolean withClasses, boolean withDensities) {

ClassifierResult<SplitClassifierModel> pred = ClassifierResult.build(this, df, withClasses, withDensities);
ClassifierResult pred = ClassifierResult.build(this, df, withClasses, withDensities);
df.stream().forEach(spot -> {
for (Split split : splits) {
if (split.predicate.test(spot)) {

Frame f = MappedFrame.byRow(df, spot.row());
ClassifierResult<ClassifierModel> p = split.classifierModel.predict(f, withClasses, withDensities);
ClassifierResult p = split.classifierModel.predict(f, withClasses, withDensities);

if (withClasses) {
for (String targetVar : targetNames()) {
Expand Down
6 changes: 3 additions & 3 deletions src/rapaio/experiment/ml/classifier/svm/BinarySMO.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
/**
* Class for building a binary support vector machine.
*/
public class BinarySMO extends AbstractClassifierModel<BinarySMO, ClassifierResult<BinarySMO>> implements Serializable, Printable {
public class BinarySMO extends AbstractClassifierModel<BinarySMO, ClassifierResult> implements Serializable, Printable {

private static final long serialVersionUID = 1208515184777030598L;

Expand Down Expand Up @@ -442,8 +442,8 @@ protected boolean coreFit(Frame df, Var weights) {


@Override
protected ClassifierResult<BinarySMO> corePredict(Frame df, boolean withClasses, boolean withDistributions) {
ClassifierResult<BinarySMO> cr = ClassifierResult.build(this, df, withClasses, withDistributions);
protected ClassifierResult corePredict(Frame df, boolean withClasses, boolean withDistributions) {
ClassifierResult cr = ClassifierResult.build(this, df, withClasses, withDistributions);
for (int i = 0; i < df.rowCount(); i++) {
double pred = predict(df, i);

Expand Down
6 changes: 3 additions & 3 deletions src/rapaio/experiment/ml/classifier/tree/CTree.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
*
* @author <a href="mailto:[email protected]">Aurelian Tutuianu</a>
*/
public class CTree extends AbstractClassifierModel<CTree, ClassifierResult<CTree>> implements Printable {
public class CTree extends AbstractClassifierModel<CTree, ClassifierResult> implements Printable {

private static final long serialVersionUID = 1203926824359387358L;
private static final Map<VType, CTreeTest> DEFAULT_TEST_MAP;
Expand Down Expand Up @@ -500,8 +500,8 @@ private void buildIndexMap(CTreeNode node, HashMap<Integer, Integer> indexMap) {
}

@Override
protected ClassifierResult<CTree> corePredict(Frame df, boolean withClasses, boolean withDensities) {
ClassifierResult<CTree> prediction = ClassifierResult.build(this, df, withClasses, withDensities);
protected ClassifierResult corePredict(Frame df, boolean withClasses, boolean withDensities) {
ClassifierResult prediction = ClassifierResult.build(this, df, withClasses, withDensities);
for (int i = 0; i < df.rowCount(); i++) {
Pair<String, DensityVector<String>> res = predictPoint(this, root, i, df);
String label = res._1;
Expand Down
4 changes: 2 additions & 2 deletions src/rapaio/experiment/ml/classifier/tree/CTreeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ public CTreeCandidate computeCandidate(CTree c, Frame df, Var weights, String te
int[] rows = new int[df.rowCount()];
int len = 0;
for (int i = 0; i < df.rowCount(); i++) {
if(!df.isMissing(i, testNameIndex)) {
if (!df.isMissing(i, testNameIndex)) {
rows[len++] = i;
dt.increment(1, dt.colIndex().getIndex(df, targetName, i), weights.getDouble(i));
}
Expand Down Expand Up @@ -250,7 +250,7 @@ public CTreeCandidate computeCandidate(CTree c, Frame df, Var weights, String te

double[] rowCounts = counts.rowTotals();
for (int i = 1; i < df.levels(testName).size(); i++) {
if (rowCounts[i] < c.minCount())
if (rowCounts[i - 1] < c.minCount())
continue;

String testLabel = df.rvar(testName).levels().get(i);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
import rapaio.ml.regression.RegressionResult;
import rapaio.ml.regression.simple.L2RegressionModel;
import rapaio.ml.regression.tree.RTree;
import rapaio.printer.Printable;
import rapaio.printer.Printer;
import rapaio.printer.opt.POption;

Expand All @@ -61,8 +60,7 @@
* User: Aurelian Tutuianu <[email protected]>
*/
@Deprecated
public class GBTRegressionModel extends AbstractRegressionModel<GBTRegressionModel, RegressionResult<GBTRegressionModel>>
implements Printable {
public class GBTRegressionModel extends AbstractRegressionModel<GBTRegressionModel, RegressionResult> {

private static final long serialVersionUID = 4559540258922653130L;

Expand Down Expand Up @@ -197,9 +195,9 @@ protected boolean coreFit(Frame df, Var weights) {
}

@Override
protected RegressionResult<GBTRegressionModel> corePredict(final Frame df, final boolean withResiduals) {
RegressionResult<GBTRegressionModel> pred = RegressionResult.build(this, df, withResiduals);
RegressionResult<GBTRegressionModel> initPred = initRegressionModel.predict(df, false);
protected RegressionResult corePredict(final Frame df, final boolean withResiduals) {
RegressionResult pred = RegressionResult.build(this, df, withResiduals);
RegressionResult initPred = initRegressionModel.predict(df, false);
for (int i = 0; i < df.rowCount(); i++) {
pred.firstPrediction().setDouble(i, initPred.firstPrediction().getDouble(i));
}
Expand Down
8 changes: 3 additions & 5 deletions src/rapaio/experiment/ml/regression/ensemble/RForest.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import rapaio.ml.regression.RegressionModel;
import rapaio.ml.regression.RegressionResult;
import rapaio.ml.regression.tree.RTree;
import rapaio.printer.Printable;
import rapaio.printer.Printer;
import rapaio.printer.opt.POption;

Expand All @@ -50,8 +49,7 @@
/**
* Created by <a href="mailto:[email protected]">Aurelian Tutuianu</a> at 1/15/15.
*/
public class RForest extends AbstractRegressionModel<RForest, RegressionResult<RForest>>
implements Printable {
public class RForest extends AbstractRegressionModel<RForest, RegressionResult> {

private static final long serialVersionUID = -3926256335736143438L;

Expand Down Expand Up @@ -129,8 +127,8 @@ public List<RegressionModel> getRegressors() {
}

@Override
protected RegressionResult<RForest> corePredict(Frame df, boolean withResiduals) {
RegressionResult<RForest> fit = RegressionResult.build(this, df, withResiduals);
protected RegressionResult corePredict(Frame df, boolean withResiduals) {
RegressionResult fit = RegressionResult.build(this, df, withResiduals);
List<VarDouble> results = regressors
.parallelStream()
.map(r -> r.predict(df, false).firstPrediction())
Expand Down
2 changes: 1 addition & 1 deletion src/rapaio/experiment/ml/regression/tree/GBTRtree.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
/**
* Created by <a href="mailto:[email protected]">Aurelian Tutuianu</a> on 5/21/19.
*/
public interface GBTRtree<M extends RegressionModel<M, R>, R extends RegressionResult<M>> extends RegressionModel<M, R> {
public interface GBTRtree<M extends RegressionModel, R extends RegressionResult> extends RegressionModel {

void boostUpdate(Frame x, Var y, Var fx, GBTRegressionLoss lossFunction);
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@
/**
* Created by <a href="mailto:[email protected]">Aurelian Tutuianu</a> on 4/16/19.
*/
public class NestedBoostingRTree extends AbstractRegressionModel<NestedBoostingRTree, RegressionResult<NestedBoostingRTree>>
implements GBTRtree<NestedBoostingRTree, RegressionResult<NestedBoostingRTree>> {
public class NestedBoostingRTree extends AbstractRegressionModel<NestedBoostingRTree, RegressionResult>
implements GBTRtree<NestedBoostingRTree, RegressionResult> {

private static final long serialVersionUID = 1864784340491461993L;
private int minCount = 5;
Expand Down
8 changes: 4 additions & 4 deletions src/rapaio/experiment/ml/regression/tree/SmoothRTree.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@
/**
* Created by <a href="mailto:[email protected]">Aurelian Tutuianu</a> on 6/19/19.
*/
public class SmoothRTree extends AbstractRegressionModel<SmoothRTree, RegressionResult<SmoothRTree>>
implements GBTRtree<SmoothRTree, RegressionResult<SmoothRTree>> {
public class SmoothRTree extends AbstractRegressionModel<SmoothRTree, RegressionResult>
implements GBTRtree<SmoothRTree, RegressionResult> {

private static final long serialVersionUID = 5062591010395009141L;

Expand Down Expand Up @@ -168,8 +168,8 @@ protected boolean coreFit(Frame df, Var weights) {
}

@Override
protected RegressionResult<SmoothRTree> corePredict(Frame df, boolean withResiduals) {
RegressionResult<SmoothRTree> prediction = RegressionResult.build(this, df, withResiduals);
protected RegressionResult corePredict(Frame df, boolean withResiduals) {
RegressionResult prediction = RegressionResult.build(this, df, withResiduals);
for (int i = 0; i < df.rowCount(); i++) {
prediction.firstPrediction().setDouble(i, root.predict(df, i, this, 1.0));
}
Expand Down
15 changes: 7 additions & 8 deletions src/rapaio/ml/classifier/AbstractClassifierModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@
*
* @author <a href="mailto:[email protected]">Aurelian Tutuianu</a>
*/
public abstract class AbstractClassifierModel<M extends ClassifierModel<M, R>, R extends ClassifierResult<M>>
implements ClassifierModel<M, R> {
public abstract class AbstractClassifierModel<M extends ClassifierModel, R extends ClassifierResult> implements ClassifierModel {

private static final long serialVersionUID = -6866948033065091047L;

Expand Down Expand Up @@ -122,18 +121,18 @@ public BiConsumer<M, Integer> runningHook() {
}

@Override
public M withRunningHook(BiConsumer<M, Integer> runningHook) {
this.runningHook = runningHook;
return (M) this;
public <T extends ClassifierModel> T withRunningHook(BiConsumer<? extends ClassifierModel, Integer> runningHook) {
this.runningHook = (BiConsumer<M, Integer>) runningHook;
return (T)this;
}

public BiFunction<M, Integer, Boolean> stoppingHook() {
return stoppingHook;
}

public M withStoppingHook(BiFunction<M, Integer, Boolean> stoppingHook) {
this.stoppingHook = stoppingHook;
return (M) this;
public <T extends ClassifierModel> T withStoppingHook(BiFunction<? extends ClassifierModel, Integer, Boolean> stoppingHook) {
this.stoppingHook = (BiFunction<M, Integer, Boolean>) stoppingHook;
return (T) this;
}

@Override
Expand Down
Loading

0 comments on commit 0e07d53

Please sign in to comment.