Example use of com.alibaba.alink.pipeline.classification.GbdtClassifier in the Alink project by Alibaba.
From the class GridSearchCVTest, method findBestMulti.
/**
 * Runs a 2-fold cross-validated grid search over {@code NUM_TREES} in {1, 2} for a
 * {@link GbdtClassifier}, scored with the multi-class ACCURACY metric, and checks that
 * transforming the training source yields one row per entry of {@code testArray}.
 */
@Test
public void findBestMulti() {
    GbdtClassifier classifier = new GbdtClassifier()
        .setFeatureCols(colNames[0], colNames[1])
        .setLabelCol(colNames[2])
        .setMinSamplesPerLeaf(1)
        .setPredictionCol("pred")
        .setPredictionDetailCol("pred_detail");

    // Candidate values for the only hyper-parameter being tuned.
    ParamGrid paramGrid = new ParamGrid()
        .addGrid(classifier, GbdtClassifier.NUM_TREES, new Integer[] { 1, 2 });

    MultiClassClassificationTuningEvaluator evaluator = new MultiClassClassificationTuningEvaluator()
        .setTuningMultiClassMetric(TuningMultiClassMetric.ACCURACY)
        .setLabelCol(colNames[2])
        .setPredictionDetailCol("pred_detail");

    GridSearchCV search = new GridSearchCV()
        .setEstimator(classifier)
        .setParamGrid(paramGrid)
        .setNumFolds(2)
        .setTuningEvaluator(evaluator);

    GridSearchCVModel bestModel = search.fit(memSourceBatchOp);

    int expectedRows = testArray.length;
    int actualRows = bestModel.transform(memSourceBatchOp).collect().size();
    Assert.assertEquals(expectedRows, actualRows);
}
Example use of com.alibaba.alink.pipeline.classification.GbdtClassifier in the Alink project by Alibaba.
From the class GridSearchTVSplitTest, method findBestMulti.
/**
 * Runs a train/validation-split grid search over {@code NUM_TREES} in {1, 2} for a
 * {@link GbdtClassifier}, scored with the multi-class ACCURACY metric, and checks that
 * transforming the training source yields one row per entry of {@code testArray}.
 */
@Test
public void findBestMulti() {
    GbdtClassifier classifier = new GbdtClassifier()
        .setFeatureCols(colNames[0], colNames[1])
        .setLabelCol(colNames[2])
        .setMinSamplesPerLeaf(1)
        .setPredictionCol("pred")
        .setPredictionDetailCol("pred_detail");

    // Candidate values for the only hyper-parameter being tuned.
    ParamGrid paramGrid = new ParamGrid()
        .addGrid(classifier, GbdtClassifier.NUM_TREES, new Integer[] { 1, 2 });

    MultiClassClassificationTuningEvaluator evaluator = new MultiClassClassificationTuningEvaluator()
        .setTuningMultiClassMetric(TuningMultiClassMetric.ACCURACY)
        .setLabelCol(colNames[2])
        .setPredictionDetailCol("pred_detail");

    GridSearchTVSplit search = new GridSearchTVSplit()
        .setEstimator(classifier)
        .setParamGrid(paramGrid)
        .setTuningEvaluator(evaluator);

    GridSearchTVSplitModel bestModel = search.fit(memSourceBatchOp);

    int expectedRows = testArray.length;
    int actualRows = bestModel.transform(memSourceBatchOp).collect().size();
    Assert.assertEquals(expectedRows, actualRows);
}
Example use of com.alibaba.alink.pipeline.classification.GbdtClassifier in the Alink project by Alibaba.
From the class GBDTExample, method main.
/**
 * Trains a {@link GbdtClassifier} with 20 trees on the adult training CSV and calls
 * {@code print()} on the result of transforming the adult test CSV.
 *
 * <p>Both CSV sources are read from remote URLs and share the same schema string.
 *
 * @param args command-line arguments; not used
 * @throws Exception propagated from the fit / transform / print calls
 */
public static void main(String[] args) throws Exception {
    String schema = "age bigint, workclass string, fnlwgt bigint, education string, "
        + "education_num bigint, marital_status string, occupation string, "
        + "relationship string, race string, sex string, capital_gain bigint, "
        + "capital_loss bigint, hours_per_week bigint, native_country string, label string";

    // Wildcard-parameterized instead of the raw BatchOperator type to avoid unchecked warnings.
    BatchOperator<?> trainData = new CsvSourceBatchOp()
        .setFilePath("https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/adult_train.csv")
        .setSchemaStr(schema);
    BatchOperator<?> testData = new CsvSourceBatchOp()
        .setFilePath("https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/adult_test.csv")
        .setSchemaStr(schema);

    GbdtClassifier gbdt = new GbdtClassifier()
        .setFeatureCols(new String[] { "age", "capital_gain", "capital_loss", "hours_per_week", "workclass", "education", "marital_status", "occupation" })
        // The string-typed feature columns are declared as categorical.
        .setCategoricalCols(new String[] { "workclass", "education", "marital_status", "occupation" })
        .setLabelCol("label")
        .setNumTrees(20)
        .setPredictionCol("prediction_result");

    gbdt.fit(trainData).transform(testData).print();
}
Example use of com.alibaba.alink.pipeline.classification.GbdtClassifier in the Alink project by Alibaba.
From the class GridSearchCVTest, method findBest.
/**
 * Runs a 2-fold cross-validated grid search over {@code NUM_TREES} in {1, 2} and
 * {@code MAX_DEPTH} in {3, -1} for a {@link GbdtClassifier}, scored with the binary
 * ACCURACY metric (positive label "1"), and checks that transforming the training
 * source yields one row per entry of {@code testArray}.
 */
@Test
public void findBest() {
    GbdtClassifier classifier = new GbdtClassifier()
        .setFeatureCols(colNames[0], colNames[1])
        .setLabelCol(colNames[2])
        .setMinSamplesPerLeaf(1)
        .setPredictionCol("pred")
        .setPredictionDetailCol("pred_detail");

    // Two hyper-parameters are tuned, each with two candidate values.
    ParamGrid paramGrid = new ParamGrid()
        .addGrid(classifier, GbdtClassifier.NUM_TREES, new Integer[] { 1, 2 })
        .addGrid(classifier, GbdtClassifier.MAX_DEPTH, new Integer[] { 3, -1 });

    BinaryClassificationTuningEvaluator evaluator = new BinaryClassificationTuningEvaluator()
        .setTuningBinaryClassMetric(TuningBinaryClassMetric.ACCURACY)
        .setLabelCol(colNames[2])
        .setPositiveLabelValueString("1")
        .setPredictionDetailCol("pred_detail");

    GridSearchCV search = new GridSearchCV()
        .setEstimator(classifier)
        .setParamGrid(paramGrid)
        .setNumFolds(2)
        .enableLazyPrintTrainInfo()
        .setTuningEvaluator(evaluator);

    GridSearchCVModel bestModel = search.fit(memSourceBatchOp);

    int expectedRows = testArray.length;
    int actualRows = bestModel.transform(memSourceBatchOp).collect().size();
    Assert.assertEquals(expectedRows, actualRows);
}
Example use of com.alibaba.alink.pipeline.classification.GbdtClassifier in the Alink project by Alibaba.
From the class GridSearchTVSplitTest, method findBestAndGet.
/**
 * Runs a train/validation-split grid search over {@code NUM_TREES} in {1, 2} for a
 * {@link GbdtClassifier}, scored with the binary ACCURACY metric (positive label "1"),
 * then saves the best pipeline model, reloads it via {@code PipelineModel.collectLoad},
 * and checks that the reloaded model yields one row per entry of {@code testArray}.
 */
@Test
public void findBestAndGet() {
    GbdtClassifier classifier = new GbdtClassifier()
        .setFeatureCols(colNames[0], colNames[1])
        .setLabelCol(colNames[2])
        .setMinSamplesPerLeaf(1)
        .setPredictionCol("pred")
        .setPredictionDetailCol("pred_detail");

    // Candidate values for the only hyper-parameter being tuned.
    ParamGrid paramGrid = new ParamGrid()
        .addGrid(classifier, GbdtClassifier.NUM_TREES, new Integer[] { 1, 2 });

    BinaryClassificationTuningEvaluator evaluator = new BinaryClassificationTuningEvaluator()
        .setTuningBinaryClassMetric(TuningBinaryClassMetric.ACCURACY)
        .setLabelCol(colNames[2])
        .setPositiveLabelValueString("1")
        .setPredictionDetailCol("pred_detail");

    GridSearchTVSplit search = new GridSearchTVSplit()
        .setEstimator(classifier)
        .setParamGrid(paramGrid)
        .setTuningEvaluator(evaluator);

    GridSearchTVSplitModel bestModel = search.fit(memSourceBatchOp);

    // Round-trip the best pipeline through save + collectLoad before using it.
    PipelineModel reloaded = PipelineModel.collectLoad(bestModel.getBestPipelineModel().save());

    int expectedRows = testArray.length;
    int actualRows = reloaded.transform(memSourceBatchOp).collect().size();
    Assert.assertEquals(expectedRows, actualRows);
}
Aggregations