Search in sources :

Example 6 with GbdtClassifier

use of com.alibaba.alink.pipeline.classification.GbdtClassifier in project Alink by alibaba.

the class GridSearchTVSplitTest method findBest.

/**
 * Runs a train/validation-split grid search over {@code GbdtClassifier.NUM_TREES} (values 1 and 2),
 * scored by binary-classification accuracy, and asserts that transforming the input with the
 * resulting model yields exactly {@code testArray.length} rows.
 */
@Test
public void findBest() {
    final String labelCol = colNames[2];
    final String detailCol = "pred_detail";

    // Base estimator whose NUM_TREES parameter is tuned below.
    GbdtClassifier estimator = new GbdtClassifier()
        .setFeatureCols(colNames[0], colNames[1])
        .setLabelCol(labelCol)
        .setMinSamplesPerLeaf(1)
        .setPredictionCol("pred")
        .setPredictionDetailCol(detailCol);

    ParamGrid candidates = new ParamGrid()
        .addGrid(estimator, GbdtClassifier.NUM_TREES, new Integer[] { 1, 2 });

    // Candidates are ranked by accuracy, with "1" as the positive label.
    BinaryClassificationTuningEvaluator scorer = new BinaryClassificationTuningEvaluator()
        .setTuningBinaryClassMetric(TuningBinaryClassMetric.ACCURACY)
        .setLabelCol(labelCol)
        .setPositiveLabelValueString("1")
        .setPredictionDetailCol(detailCol);

    GridSearchTVSplit search = new GridSearchTVSplit()
        .setEstimator(estimator)
        .setParamGrid(candidates)
        .setTuningEvaluator(scorer);

    GridSearchTVSplitModel tuned = search.fit(memSourceBatchOp);
    int predictedRows = tuned.transform(memSourceBatchOp).collect().size();
    Assert.assertEquals(testArray.length, predictedRows);
}
Also used : GbdtClassifier(com.alibaba.alink.pipeline.classification.GbdtClassifier) Test(org.junit.Test)

Example 7 with GbdtClassifier

use of com.alibaba.alink.pipeline.classification.GbdtClassifier in project Alink by alibaba.

the class Chap11 method c_8.

/**
 * Fits a GBDT classifier (100 trees, max depth 5, 256 bins) on the train-sample file, applies it to
 * the test file, and lazily prints binary-classification metrics under the title
 * "GBDT with Stratified Sample". Every column of the test data except the label column is used
 * as a feature.
 *
 * @throws Exception propagated from the operator calls made in this method
 */
static void c_8() throws Exception {
    AkSourceBatchOp testSource = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE);
    AkSourceBatchOp trainSource = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_SAMPLE_FILE);

    // Feature columns are taken from the test data's schema, minus the label.
    String[] featureCols = ArrayUtils.removeElement(testSource.getColNames(), LABEL_COL_NAME);

    GbdtClassifier classifier = new GbdtClassifier()
        .setNumTrees(100)
        .setMaxDepth(5)
        .setMaxBins(256)
        .setFeatureCols(featureCols)
        .setLabelCol(LABEL_COL_NAME)
        .setPredictionCol(PREDICTION_COL_NAME)
        .setPredictionDetailCol(PRED_DETAIL_COL_NAME);

    BatchOperator<?> predictions = classifier.fit(trainSource).transform(testSource);

    EvalBinaryClassBatchOp metricsSink = new EvalBinaryClassBatchOp()
        .setLabelCol(LABEL_COL_NAME)
        .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
        .lazyPrintMetrics("GBDT with Stratified Sample");
    predictions.link(metricsSink);

    BatchOperator.execute();
}
Also used : AkSourceBatchOp(com.alibaba.alink.operator.batch.source.AkSourceBatchOp) GbdtClassifier(com.alibaba.alink.pipeline.classification.GbdtClassifier) EvalBinaryClassBatchOp(com.alibaba.alink.operator.batch.evaluation.EvalBinaryClassBatchOp)

Example 8 with GbdtClassifier

use of com.alibaba.alink.pipeline.classification.GbdtClassifier in project Alink by alibaba.

the class Chap12 method c_4.

/**
 * Trains three one-vs-rest multi-class models (3 classes each) on the training file, wrapping
 * logistic regression, GBDT and linear SVM in turn, applies each to the test file, and lazily
 * prints multi-class metrics under a per-model title.
 *
 * @throws Exception propagated from the operator calls made in this method
 */
static void c_4() throws Exception {
    AkSourceBatchOp trainSource = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE);
    AkSourceBatchOp testSource = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE);

    // --- one-vs-rest over logistic regression ---
    LogisticRegression logistic = new LogisticRegression()
        .setFeatureCols(FEATURE_COL_NAMES)
        .setLabelCol(LABEL_COL_NAME)
        .setPredictionCol(PREDICTION_COL_NAME);
    BatchOperator<?> logisticPredictions = new OneVsRest()
        .setClassifier(logistic)
        .setNumClass(3)
        .fit(trainSource)
        .transform(testSource);
    logisticPredictions.link(
        new EvalMultiClassBatchOp()
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME)
            .lazyPrintMetrics("OneVsRest_LogisticRegression"));

    // --- one-vs-rest over GBDT ---
    GbdtClassifier gbdt = new GbdtClassifier()
        .setFeatureCols(FEATURE_COL_NAMES)
        .setLabelCol(LABEL_COL_NAME)
        .setPredictionCol(PREDICTION_COL_NAME);
    BatchOperator<?> gbdtPredictions = new OneVsRest()
        .setClassifier(gbdt)
        .setNumClass(3)
        .fit(trainSource)
        .transform(testSource);
    gbdtPredictions.link(
        new EvalMultiClassBatchOp()
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME)
            .lazyPrintMetrics("OneVsRest_GBDT"));

    // --- one-vs-rest over linear SVM ---
    LinearSvm svm = new LinearSvm()
        .setFeatureCols(FEATURE_COL_NAMES)
        .setLabelCol(LABEL_COL_NAME)
        .setPredictionCol(PREDICTION_COL_NAME);
    BatchOperator<?> svmPredictions = new OneVsRest()
        .setClassifier(svm)
        .setNumClass(3)
        .fit(trainSource)
        .transform(testSource);
    svmPredictions.link(
        new EvalMultiClassBatchOp()
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME)
            .lazyPrintMetrics("OneVsRest_LinearSvm"));

    BatchOperator.execute();
}
Also used : EvalMultiClassBatchOp(com.alibaba.alink.operator.batch.evaluation.EvalMultiClassBatchOp) AkSourceBatchOp(com.alibaba.alink.operator.batch.source.AkSourceBatchOp) OneVsRest(com.alibaba.alink.pipeline.classification.OneVsRest) GbdtClassifier(com.alibaba.alink.pipeline.classification.GbdtClassifier) LinearSvm(com.alibaba.alink.pipeline.classification.LinearSvm) LogisticRegression(com.alibaba.alink.pipeline.classification.LogisticRegression)

Example 9 with GbdtClassifier

use of com.alibaba.alink.pipeline.classification.GbdtClassifier in project Alink by alibaba.

the class Chap20 method c_2.

/**
 * Tunes a GBDT classifier with a 20-iteration random search on an 80% train split (F1 as the
 * tuning metric), evaluates the fitted search model on the test file with "1" as the positive
 * label, and prints the elapsed time measured by a {@code Stopwatch}.
 *
 * <p>The search samples number of trees from {50, 100}, max depth via
 * {@code ValueDist.randInteger(4, 10)}, max bins from {64, 128, 256, 512} and learning rate from
 * {0.3, 0.1, 0.01}.
 *
 * @throws Exception propagated from the operator calls made in this method
 */
static void c_2() throws Exception {
    Stopwatch sw = new Stopwatch();
    sw.start();
    AlinkGlobalConfiguration.setPrintProcessInfo(true);

    // Wildcard-parameterized instead of the raw BatchOperator type, so generic checks stay enabled.
    BatchOperator<?> train_sample = new AkSourceBatchOp().setFilePath(Chap11.DATA_DIR + Chap11.TRAIN_SAMPLE_FILE);
    BatchOperator<?> test_data = new AkSourceBatchOp().setFilePath(Chap11.DATA_DIR + Chap11.TEST_FILE);

    // Every column of the training sample except the label is a feature.
    final String[] featuresColNames = ArrayUtils.removeElement(train_sample.getColNames(), Chap11.LABEL_COL_NAME);

    GbdtClassifier gbdt = new GbdtClassifier()
        .setFeatureCols(featuresColNames)
        .setLabelCol(Chap11.LABEL_COL_NAME)
        .setPredictionCol(Chap11.PREDICTION_COL_NAME)
        .setPredictionDetailCol(Chap11.PRED_DETAIL_COL_NAME);

    // Distributions the random search draws hyper-parameters from.
    ParamDist paramDist = new ParamDist()
        .addDist(gbdt, GbdtClassifier.NUM_TREES, ValueDist.randArray(new Integer[] { 50, 100 }))
        .addDist(gbdt, GbdtClassifier.MAX_DEPTH, ValueDist.randInteger(4, 10))
        .addDist(gbdt, GbdtClassifier.MAX_BINS, ValueDist.randArray(new Integer[] { 64, 128, 256, 512 }))
        .addDist(gbdt, GbdtClassifier.LEARNING_RATE, ValueDist.randArray(new Double[] { 0.3, 0.1, 0.01 }));

    BinaryClassificationTuningEvaluator tuningEvaluator = new BinaryClassificationTuningEvaluator()
        .setLabelCol(Chap11.LABEL_COL_NAME)
        .setPredictionDetailCol(Chap11.PRED_DETAIL_COL_NAME)
        .setTuningBinaryClassMetric(TuningBinaryClassMetric.F1);

    RandomSearchTVSplit randomSearch = new RandomSearchTVSplit()
        .setNumIter(20)
        .setTrainRatio(0.8)
        .setEstimator(gbdt)
        .setParamDist(paramDist)
        .setTuningEvaluator(tuningEvaluator)
        .enableLazyPrintTrainInfo();

    RandomSearchTVSplitModel bestModel = randomSearch.fit(train_sample);

    bestModel.transform(test_data).link(
        new EvalBinaryClassBatchOp()
            .setPositiveLabelValueString("1")
            .setLabelCol(Chap11.LABEL_COL_NAME)
            .setPredictionDetailCol(Chap11.PRED_DETAIL_COL_NAME)
            .lazyPrintMetrics());

    BatchOperator.execute();
    sw.stop();
    System.out.println(sw.getElapsedTimeSpan());
}
Also used : ParamDist(com.alibaba.alink.pipeline.tuning.ParamDist) AkSourceBatchOp(com.alibaba.alink.operator.batch.source.AkSourceBatchOp) GbdtClassifier(com.alibaba.alink.pipeline.classification.GbdtClassifier) Stopwatch(com.alibaba.alink.common.utils.Stopwatch) RandomSearchTVSplit(com.alibaba.alink.pipeline.tuning.RandomSearchTVSplit) BinaryClassificationTuningEvaluator(com.alibaba.alink.pipeline.tuning.BinaryClassificationTuningEvaluator) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) RandomSearchTVSplitModel(com.alibaba.alink.pipeline.tuning.RandomSearchTVSplitModel) EvalBinaryClassBatchOp(com.alibaba.alink.operator.batch.evaluation.EvalBinaryClassBatchOp)

Aggregations

GbdtClassifier (com.alibaba.alink.pipeline.classification.GbdtClassifier): 9; Test (org.junit.Test): 5; AkSourceBatchOp (com.alibaba.alink.operator.batch.source.AkSourceBatchOp): 3; BatchOperator (com.alibaba.alink.operator.batch.BatchOperator): 2; EvalBinaryClassBatchOp (com.alibaba.alink.operator.batch.evaluation.EvalBinaryClassBatchOp): 2; Stopwatch (com.alibaba.alink.common.utils.Stopwatch): 1; EvalMultiClassBatchOp (com.alibaba.alink.operator.batch.evaluation.EvalMultiClassBatchOp): 1; CsvSourceBatchOp (com.alibaba.alink.operator.batch.source.CsvSourceBatchOp): 1; PipelineModel (com.alibaba.alink.pipeline.PipelineModel): 1; LinearSvm (com.alibaba.alink.pipeline.classification.LinearSvm): 1; LogisticRegression (com.alibaba.alink.pipeline.classification.LogisticRegression): 1; OneVsRest (com.alibaba.alink.pipeline.classification.OneVsRest): 1; BinaryClassificationTuningEvaluator (com.alibaba.alink.pipeline.tuning.BinaryClassificationTuningEvaluator): 1; ParamDist (com.alibaba.alink.pipeline.tuning.ParamDist): 1; RandomSearchTVSplit (com.alibaba.alink.pipeline.tuning.RandomSearchTVSplit): 1; RandomSearchTVSplitModel (com.alibaba.alink.pipeline.tuning.RandomSearchTVSplitModel): 1