
Example 1 with GbdtClassifier

Use of com.alibaba.alink.pipeline.classification.GbdtClassifier in project Alink by alibaba.

From the class GridSearchCVTest, method findBestMulti.

@Test
public void findBestMulti() {
    // colNames, memSourceBatchOp and testArray are fields of the enclosing test class (not shown).
    GbdtClassifier gbdtClassifier = new GbdtClassifier()
        .setFeatureCols(colNames[0], colNames[1]).setLabelCol(colNames[2])
        .setMinSamplesPerLeaf(1).setPredictionCol("pred").setPredictionDetailCol("pred_detail");
    // Candidate values for the number of trees.
    ParamGrid grid = new ParamGrid().addGrid(gbdtClassifier, GbdtClassifier.NUM_TREES, new Integer[] { 1, 2 });
    // Two-fold cross-validation, scored by multiclass accuracy.
    GridSearchCV gridSearchCV = new GridSearchCV().setEstimator(gbdtClassifier).setParamGrid(grid).setNumFolds(2)
        .setTuningEvaluator(new MultiClassClassificationTuningEvaluator()
            .setTuningMultiClassMetric(TuningMultiClassMetric.ACCURACY)
            .setLabelCol(colNames[2]).setPredictionDetailCol("pred_detail"));
    GridSearchCVModel model = gridSearchCV.fit(memSourceBatchOp);
    Assert.assertEquals(testArray.length, model.transform(memSourceBatchOp).collect().size());
}
Also used: GbdtClassifier(com.alibaba.alink.pipeline.classification.GbdtClassifier) Test(org.junit.Test)
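GridSearchCV scores every candidate parameter set by k-fold cross-validation: with setNumFolds(2) the data is split into two folds, and each candidate is trained on one fold and validated on the other, once per fold. The plain-Java sketch below shows one simple way to assign rows to folds. It is not Alink's implementation; the class and method names are hypothetical, and the round-robin assignment without shuffling is a simplifying assumption.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical illustration of k-fold cross-validation splits (not Alink code).
final class KFoldSketch {
    // Returns, for each fold k, the row indices held out for validation in that fold.
    // Rows are assigned round-robin: row i goes to fold i % numFolds (no shuffling).
    static List<List<Integer>> validationFolds(int numRows, int numFolds) {
        if (numFolds < 2 || numFolds > numRows) {
            throw new IllegalArgumentException("need 2 <= numFolds <= numRows");
        }
        List<List<Integer>> folds = new ArrayList<>();
        for (int k = 0; k < numFolds; k++) {
            folds.add(new ArrayList<>());
        }
        for (int i = 0; i < numRows; i++) {
            folds.get(i % numFolds).add(i);
        }
        return folds;
    }

    // Rows used for training in fold k: every row that is not held out in fold k.
    static List<Integer> trainingRows(int numRows, int numFolds, int k) {
        List<Integer> train = new ArrayList<>();
        for (int i = 0; i < numRows; i++) {
            if (i % numFolds != k) {
                train.add(i);
            }
        }
        return train;
    }

    public static void main(String[] args) {
        // With 6 rows and numFolds = 2 (as in setNumFolds(2)), each candidate is fit twice.
        System.out.println(validationFolds(6, 2)); // [[0, 2, 4], [1, 3, 5]]
        System.out.println(trainingRows(6, 2, 0));  // [1, 3, 5]
    }
}
```

Production implementations usually shuffle rows before assigning folds so that any ordering in the source data does not leak into the folds.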

Example 2 with GbdtClassifier

Use of com.alibaba.alink.pipeline.classification.GbdtClassifier in project Alink by alibaba.

From the class GridSearchTVSplitTest, method findBestMulti.

@Test
public void findBestMulti() {
    // colNames, memSourceBatchOp and testArray are fields of the enclosing test class (not shown).
    GbdtClassifier gbdtClassifier = new GbdtClassifier()
        .setFeatureCols(colNames[0], colNames[1]).setLabelCol(colNames[2])
        .setMinSamplesPerLeaf(1).setPredictionCol("pred").setPredictionDetailCol("pred_detail");
    // Candidate values for the number of trees.
    ParamGrid grid = new ParamGrid().addGrid(gbdtClassifier, GbdtClassifier.NUM_TREES, new Integer[] { 1, 2 });
    // Single train/validation split, scored by multiclass accuracy.
    GridSearchTVSplit gridSearchTVSplit = new GridSearchTVSplit().setEstimator(gbdtClassifier).setParamGrid(grid)
        .setTuningEvaluator(new MultiClassClassificationTuningEvaluator()
            .setTuningMultiClassMetric(TuningMultiClassMetric.ACCURACY)
            .setLabelCol(colNames[2]).setPredictionDetailCol("pred_detail"));
    GridSearchTVSplitModel model = gridSearchTVSplit.fit(memSourceBatchOp);
    Assert.assertEquals(testArray.length, model.transform(memSourceBatchOp).collect().size());
}
Also used: GbdtClassifier(com.alibaba.alink.pipeline.classification.GbdtClassifier) Test(org.junit.Test)
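GridSearchTVSplit, as its name says, evaluates candidates on a train/validation split rather than on k folds, so each candidate is fit once. The sketch below shows a seeded random split in plain Java. It is not Alink's implementation: the class name, the trainRatio parameter and the seed are assumptions made for illustration only.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical illustration of a train/validation split (not Alink code).
final class TrainValidationSplitSketch {
    // Shuffles row indices with a fixed seed, then puts the first
    // floor(numRows * trainRatio) of them into the training part.
    // Returns [trainingRows, validationRows].
    static List<List<Integer>> split(int numRows, double trainRatio, long seed) {
        if (trainRatio <= 0.0 || trainRatio >= 1.0) {
            throw new IllegalArgumentException("trainRatio must be in (0, 1)");
        }
        List<Integer> rows = new ArrayList<>();
        for (int i = 0; i < numRows; i++) {
            rows.add(i);
        }
        Collections.shuffle(rows, new Random(seed));
        int numTrain = (int) Math.floor(numRows * trainRatio);
        List<List<Integer>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(rows.subList(0, numTrain)));       // training rows
        parts.add(new ArrayList<>(rows.subList(numTrain, numRows))); // validation rows
        return parts;
    }

    public static void main(String[] args) {
        List<List<Integer>> parts = split(10, 0.8, 42L);
        // Each grid candidate would be fit on parts.get(0) and scored on parts.get(1).
        System.out.println(parts.get(0).size() + " train / " + parts.get(1).size() + " validation"); // 8 train / 2 validation
    }
}
```

Compared with k-fold cross-validation this needs fewer fits per candidate, at the cost of a noisier score because only one validation subset is used.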

Example 3 with GbdtClassifier

Use of com.alibaba.alink.pipeline.classification.GbdtClassifier in project Alink by alibaba.

From the class GBDTExample, method main.

public static void main(String[] args) throws Exception {
    String schema = "age bigint, workclass string, fnlwgt bigint, education string, "
        + "education_num bigint, marital_status string, occupation string, "
        + "relationship string, race string, sex string, capital_gain bigint, "
        + "capital_loss bigint, hours_per_week bigint, native_country string, label string";
    BatchOperator trainData = new CsvSourceBatchOp().setFilePath("https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/adult_train.csv").setSchemaStr(schema);
    BatchOperator testData = new CsvSourceBatchOp().setFilePath("https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/adult_test.csv").setSchemaStr(schema);
    // 20 trees over eight features; the four string columns are declared categorical.
    GbdtClassifier gbdt = new GbdtClassifier()
        .setFeatureCols(new String[] { "age", "capital_gain", "capital_loss", "hours_per_week", "workclass", "education", "marital_status", "occupation" })
        .setCategoricalCols(new String[] { "workclass", "education", "marital_status", "occupation" })
        .setLabelCol("label").setNumTrees(20).setPredictionCol("prediction_result");
    // Train on the training file, predict on the test file, and print the predictions.
    gbdt.fit(trainData).transform(testData).print();
}
Also used: GbdtClassifier(com.alibaba.alink.pipeline.classification.GbdtClassifier) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) CsvSourceBatchOp(com.alibaba.alink.operator.batch.source.CsvSourceBatchOp)
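GBDT stands for gradient boosted decision trees: the model is a sum of trees (20 of them with setNumTrees(20) above), each fit to the errors left by the previous ones. To make that idea concrete, here is a deliberately tiny, library-independent sketch of gradient boosting for a 0/1 label on a single numeric feature, using depth-1 trees (stumps) and logistic loss. It is not Alink's GBDT: it ignores categorical columns, multiple features and deeper trees, uses the mean residual as the leaf value (a plain gradient step; many implementations use a Newton step instead), and the class name and learningRate parameter are hypothetical.

```java
import java.util.Arrays;

// Hypothetical sketch: gradient boosting with regression stumps and logistic loss
// for 0/1 labels on a single numeric feature. Not Alink's GBDT implementation.
final class TinyGbdtSketch {
    private final double init;          // initial score: log-odds of the positive class
    private final double learningRate;  // shrinkage applied to every stump's output
    private final double[] thresholds;  // split point of each stump (x <= threshold goes left)
    private final double[] leftValues;  // leaf value on the left side of each stump
    private final double[] rightValues; // leaf value on the right side of each stump

    private TinyGbdtSketch(double init, double learningRate,
                           double[] thresholds, double[] leftValues, double[] rightValues) {
        this.init = init;
        this.learningRate = learningRate;
        this.thresholds = thresholds;
        this.leftValues = leftValues;
        this.rightValues = rightValues;
    }

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    static TinyGbdtSketch fit(double[] x, int[] y, int numTrees, double learningRate) {
        int n = x.length;
        double positiveRate = Arrays.stream(y).average().orElse(0.5);
        if (positiveRate <= 0.0 || positiveRate >= 1.0) {
            throw new IllegalArgumentException("both classes must be present");
        }
        double init = Math.log(positiveRate / (1.0 - positiveRate));
        double[] score = new double[n];
        Arrays.fill(score, init);
        double[] thresholds = new double[numTrees];
        double[] leftValues = new double[numTrees];
        double[] rightValues = new double[numTrees];
        for (int m = 0; m < numTrees; m++) {
            // Negative gradient of the logistic loss: label minus predicted probability.
            double[] residual = new double[n];
            for (int i = 0; i < n; i++) {
                residual[i] = y[i] - sigmoid(score[i]);
            }
            // Fit a stump: try each observed x as the threshold and keep the one whose
            // leaf means give the smallest squared error on the residuals.
            double bestSse = Double.POSITIVE_INFINITY;
            for (double candidate : x) {
                double sumLeft = 0, sumRight = 0;
                int nLeft = 0, nRight = 0;
                for (int i = 0; i < n; i++) {
                    if (x[i] <= candidate) { sumLeft += residual[i]; nLeft++; }
                    else { sumRight += residual[i]; nRight++; }
                }
                double meanLeft = nLeft > 0 ? sumLeft / nLeft : 0.0;
                double meanRight = nRight > 0 ? sumRight / nRight : 0.0;
                double sse = 0;
                for (int i = 0; i < n; i++) {
                    double d = residual[i] - (x[i] <= candidate ? meanLeft : meanRight);
                    sse += d * d;
                }
                if (sse < bestSse) {
                    bestSse = sse;
                    thresholds[m] = candidate;
                    leftValues[m] = meanLeft;
                    rightValues[m] = meanRight;
                }
            }
            // Add the shrunken stump output to every training score.
            for (int i = 0; i < n; i++) {
                score[i] += learningRate * (x[i] <= thresholds[m] ? leftValues[m] : rightValues[m]);
            }
        }
        return new TinyGbdtSketch(init, learningRate, thresholds, leftValues, rightValues);
    }

    // Probability of label 1: sigmoid of the initial score plus all shrunken stump outputs.
    double predictProbability(double xValue) {
        double z = init;
        for (int m = 0; m < thresholds.length; m++) {
            z += learningRate * (xValue <= thresholds[m] ? leftValues[m] : rightValues[m]);
        }
        return sigmoid(z);
    }

    public static void main(String[] args) {
        double[] x = {1, 2, 3, 4, 5, 6};
        int[] y = {0, 0, 0, 1, 1, 1};
        TinyGbdtSketch model = fit(x, y, 20, 0.5);
        System.out.printf("P(y=1 | x=1.5) = %.3f, P(y=1 | x=5.5) = %.3f%n",
            model.predictProbability(1.5), model.predictProbability(5.5));
    }
}
```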

Example 4 with GbdtClassifier

Use of com.alibaba.alink.pipeline.classification.GbdtClassifier in project Alink by alibaba.

From the class GridSearchCVTest, method findBest.

@Test
public void findBest() {
    // colNames, memSourceBatchOp and testArray are fields of the enclosing test class (not shown).
    GbdtClassifier gbdtClassifier = new GbdtClassifier()
        .setFeatureCols(colNames[0], colNames[1]).setLabelCol(colNames[2])
        .setMinSamplesPerLeaf(1).setPredictionCol("pred").setPredictionDetailCol("pred_detail");
    ParamGrid grid = new ParamGrid().addGrid(gbdtClassifier, GbdtClassifier.NUM_TREES, new Integer[] { 1, 2 })
        .addGrid(gbdtClassifier, GbdtClassifier.MAX_DEPTH, new Integer[] { 3, -1 });
    // Two-fold cross-validation, scored by binary accuracy with "1" as the positive label.
    GridSearchCV gridSearchCV = new GridSearchCV().setEstimator(gbdtClassifier).setParamGrid(grid).setNumFolds(2).enableLazyPrintTrainInfo()
        .setTuningEvaluator(new BinaryClassificationTuningEvaluator()
            .setTuningBinaryClassMetric(TuningBinaryClassMetric.ACCURACY)
            .setLabelCol(colNames[2]).setPositiveLabelValueString("1").setPredictionDetailCol("pred_detail"));
    GridSearchCVModel model = gridSearchCV.fit(memSourceBatchOp);
    Assert.assertEquals(testArray.length, model.transform(memSourceBatchOp).collect().size());
}
Also used: GbdtClassifier(com.alibaba.alink.pipeline.classification.GbdtClassifier) Test(org.junit.Test)
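This example adds two parameters to the grid. A grid search covers every combination of the listed values, so 2 values of NUM_TREES times 2 values of MAX_DEPTH gives 4 candidates, and with setNumFolds(2) that means 8 cross-validation fits. The plain-Java sketch below expands a grid into its Cartesian product; the class name and the string keys "numTrees" and "maxDepth" are hypothetical stand-ins for the ParamGrid entries, not Alink's API.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical illustration of grid expansion (not Alink's ParamGrid implementation).
final class ParamGridSketch {
    // Expands named value lists into every combination (the Cartesian product).
    // Parameter order follows the map's iteration order.
    static List<Map<String, Object>> expand(Map<String, List<?>> grid) {
        List<Map<String, Object>> combos = new ArrayList<>();
        combos.add(new LinkedHashMap<>()); // start from one empty combination
        for (Map.Entry<String, List<?>> entry : grid.entrySet()) {
            List<Map<String, Object>> next = new ArrayList<>();
            for (Map<String, Object> partial : combos) {
                for (Object value : entry.getValue()) {
                    Map<String, Object> extended = new LinkedHashMap<>(partial);
                    extended.put(entry.getKey(), value);
                    next.add(extended);
                }
            }
            combos = next;
        }
        return combos;
    }

    public static void main(String[] args) {
        LinkedHashMap<String, List<?>> grid = new LinkedHashMap<>();
        grid.put("numTrees", Arrays.asList(1, 2));
        grid.put("maxDepth", Arrays.asList(3, -1));
        List<Map<String, Object>> candidates = expand(grid);
        // Prints {numTrees=1, maxDepth=3}, {numTrees=1, maxDepth=-1},
        // {numTrees=2, maxDepth=3}, {numTrees=2, maxDepth=-1}, one per line.
        candidates.forEach(System.out::println);
        int numFolds = 2;
        System.out.println(candidates.size() * numFolds + " cross-validation fits"); // 8 cross-validation fits
    }
}
```

The number of fits grows multiplicatively with every parameter added to the grid, which is why random search variants such as RandomSearchTVSplit exist for larger spaces.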

Example 5 with GbdtClassifier

Use of com.alibaba.alink.pipeline.classification.GbdtClassifier in project Alink by alibaba.

From the class GridSearchTVSplitTest, method findBestAndGet.

@Test
public void findBestAndGet() {
    // colNames, memSourceBatchOp and testArray are fields of the enclosing test class (not shown).
    GbdtClassifier gbdtClassifier = new GbdtClassifier()
        .setFeatureCols(colNames[0], colNames[1]).setLabelCol(colNames[2])
        .setMinSamplesPerLeaf(1).setPredictionCol("pred").setPredictionDetailCol("pred_detail");
    ParamGrid grid = new ParamGrid().addGrid(gbdtClassifier, GbdtClassifier.NUM_TREES, new Integer[] { 1, 2 });
    // Train/validation split, scored by binary accuracy with "1" as the positive label.
    GridSearchTVSplit gridSearchTVSplit = new GridSearchTVSplit().setEstimator(gbdtClassifier).setParamGrid(grid)
        .setTuningEvaluator(new BinaryClassificationTuningEvaluator()
            .setTuningBinaryClassMetric(TuningBinaryClassMetric.ACCURACY)
            .setLabelCol(colNames[2]).setPositiveLabelValueString("1").setPredictionDetailCol("pred_detail"));
    GridSearchTVSplitModel model = gridSearchTVSplit.fit(memSourceBatchOp);
    // Save the best pipeline model found by the search, then load it back before using it.
    PipelineModel loaded = PipelineModel.collectLoad(model.getBestPipelineModel().save());
    Assert.assertEquals(testArray.length, loaded.transform(memSourceBatchOp).collect().size());
}
Also used: GbdtClassifier(com.alibaba.alink.pipeline.classification.GbdtClassifier) PipelineModel(com.alibaba.alink.pipeline.PipelineModel) Test(org.junit.Test)
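This example retrieves the winner with getBestPipelineModel(). Conceptually, the tuning evaluator scores each candidate on the validation data (here by accuracy, with "1" as the positive label) and the candidate with the highest score is kept. The plain-Java sketch below shows only that selection step; the class name, the toy labels and predictions, and the earliest-wins tie rule are illustrative assumptions, not Alink's code.

```java
// Hypothetical illustration of choosing the best candidate by accuracy (not Alink code).
final class BestModelSketch {
    // Fraction of rows whose predicted label equals the true label.
    static double accuracy(String[] labels, String[] predictions) {
        if (labels.length != predictions.length || labels.length == 0) {
            throw new IllegalArgumentException("need equal-length, non-empty arrays");
        }
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (labels[i].equals(predictions[i])) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }

    // Index of the highest score; on ties the earliest candidate wins.
    static int argMax(double[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        String[] labels = {"1", "0", "1", "1"};
        // Toy validation predictions of two hypothetical candidates (e.g. numTrees = 1 and numTrees = 2).
        String[][] candidatePredictions = {
            {"1", "1", "0", "1"}, // accuracy 0.5
            {"1", "0", "1", "0"}  // accuracy 0.75
        };
        double[] scores = new double[candidatePredictions.length];
        for (int c = 0; c < scores.length; c++) {
            scores[c] = accuracy(labels, candidatePredictions[c]);
        }
        System.out.println("best candidate: " + argMax(scores)); // best candidate: 1
    }
}
```

For a metric where lower is better (a loss, for instance) the comparison in argMax would be reversed.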

Aggregations

GbdtClassifier (com.alibaba.alink.pipeline.classification.GbdtClassifier): 9
Test (org.junit.Test): 5
AkSourceBatchOp (com.alibaba.alink.operator.batch.source.AkSourceBatchOp): 3
BatchOperator (com.alibaba.alink.operator.batch.BatchOperator): 2
EvalBinaryClassBatchOp (com.alibaba.alink.operator.batch.evaluation.EvalBinaryClassBatchOp): 2
Stopwatch (com.alibaba.alink.common.utils.Stopwatch): 1
EvalMultiClassBatchOp (com.alibaba.alink.operator.batch.evaluation.EvalMultiClassBatchOp): 1
CsvSourceBatchOp (com.alibaba.alink.operator.batch.source.CsvSourceBatchOp): 1
PipelineModel (com.alibaba.alink.pipeline.PipelineModel): 1
LinearSvm (com.alibaba.alink.pipeline.classification.LinearSvm): 1
LogisticRegression (com.alibaba.alink.pipeline.classification.LogisticRegression): 1
OneVsRest (com.alibaba.alink.pipeline.classification.OneVsRest): 1
BinaryClassificationTuningEvaluator (com.alibaba.alink.pipeline.tuning.BinaryClassificationTuningEvaluator): 1
ParamDist (com.alibaba.alink.pipeline.tuning.ParamDist): 1
RandomSearchTVSplit (com.alibaba.alink.pipeline.tuning.RandomSearchTVSplit): 1
RandomSearchTVSplitModel (com.alibaba.alink.pipeline.tuning.RandomSearchTVSplitModel): 1