use of com.alibaba.alink.pipeline.classification.RandomForestClassifier in project Alink by alibaba.
the class Chap11 method c_7.
/**
 * Compares tree-based binary classifiers: each model is fitted on the data read from
 * {@code DATA_DIR + TRAIN_SAMPLE_FILE} and evaluated on {@code DATA_DIR + TEST_FILE}.
 *
 * <p>A {@code DecisionTreeClassifier} is trained once per {@code TreeType}
 * (GINI, INFOGAIN, INFOGAINRATIO), followed by a {@code RandomForestClassifier}
 * configured with 20 trees, max depth 4 and 512 bins. Every model's predictions are
 * linked to an {@code EvalBinaryClassBatchOp} whose metrics are printed under a
 * per-model title.
 *
 * <p>Feature columns are all columns of the test set except {@code LABEL_COL_NAME}.
 *
 * @throws Exception if building or executing the batch jobs fails
 */
static void c_7() throws Exception {
    AkSourceBatchOp train_sample = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_SAMPLE_FILE);
    AkSourceBatchOp test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE);

    String[] featureColNames = ArrayUtils.removeElement(test_data.getColNames(), LABEL_COL_NAME);

    // One decision tree per split criterion; the metrics title is the criterion name.
    for (TreeType treeType : new TreeType[] { TreeType.GINI, TreeType.INFOGAIN, TreeType.INFOGAINRATIO }) {
        new DecisionTreeClassifier()
            .setTreeType(treeType)
            .setFeatureCols(featureColNames)
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME)
            .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
            .fit(train_sample)
            .transform(test_data)
            .link(
                new EvalBinaryClassBatchOp()
                    .setPositiveLabelValueString("1")
                    .setLabelCol(LABEL_COL_NAME)
                    .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
                    .lazyPrintMetrics(treeType.toString())
            );
    }
    BatchOperator.execute();

    new RandomForestClassifier()
        .setNumTrees(20)
        .setMaxDepth(4)
        .setMaxBins(512)
        .setFeatureCols(featureColNames)
        .setLabelCol(LABEL_COL_NAME)
        .setPredictionCol(PREDICTION_COL_NAME)
        .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
        .fit(train_sample)
        .transform(test_data)
        .link(
            new EvalBinaryClassBatchOp()
                // Pin the positive class explicitly, exactly as the decision-tree
                // evaluation above does, so both sets of metrics are computed against
                // the same class instead of relying on the evaluator's default choice.
                .setPositiveLabelValueString("1")
                .setLabelCol(LABEL_COL_NAME)
                .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
                .lazyPrintMetrics("RandomForest with Stratified Sample")
        );
    BatchOperator.execute();
}
use of com.alibaba.alink.pipeline.classification.RandomForestClassifier in project Alink by alibaba.
the class GridSearchCVTest method testSplit.
/**
 * Runs a {@code GridSearchTVSplit} (train ratio 0.8) over a single-stage random-forest
 * pipeline on a small in-memory data set and prints the fitted model's predictions
 * for that same data set.
 *
 * <p>The grid holds exactly one candidate: subsampling ratio 1.0 with 3 trees, scored
 * by the "Accuracy" binary-classification metric.
 *
 * @throws Exception if fitting, transforming or printing fails
 */
@Test
public void testSplit() throws Exception {
    // Column layout: four features (f1 is categorical) followed by the label.
    final String[] columns = { "f0", "f1", "f2", "f3", "label" };
    final String labelCol = columns[4];
    final String[] features = Arrays.copyOfRange(columns, 0, 4);
    final String[] categoricals = { columns[1] };

    List<Row> samples = Arrays.asList(
        Row.of(1.0, "A", 0, 0, 0),
        Row.of(2.0, "B", 1, 1, 0),
        Row.of(3.0, "C", 2, 2, 1),
        Row.of(4.0, "D", 3, 3, 1),
        Row.of(1.0, "A", 0, 0, 0),
        Row.of(2.0, "B", 1, 1, 0),
        Row.of(3.0, "C", 2, 2, 1),
        Row.of(4.0, "D", 3, 3, 1),
        Row.of(1.0, "A", 0, 0, 0),
        Row.of(2.0, "B", 1, 1, 0),
        Row.of(3.0, "C", 2, 2, 1)
    );
    MemSourceBatchOp source = new MemSourceBatchOp(samples, columns);

    RandomForestClassifier forest = new RandomForestClassifier()
        .setFeatureCols(features)
        .setCategoricalCols(categoricals)
        .setLabelCol(labelCol)
        .setPredictionCol("pred_result")
        .setPredictionDetailCol("pred_detail")
        .setSubsamplingRatio(1.0);

    // The grid is keyed on the same estimator instance that is placed in the pipeline.
    ParamGrid grid = new ParamGrid()
        .addGrid(forest, "SUBSAMPLING_RATIO", new Double[] { 1.0 })
        .addGrid(forest, "NUM_TREES", new Integer[] { 3 });

    BinaryClassificationTuningEvaluator evaluator = new BinaryClassificationTuningEvaluator()
        .setLabelCol(labelCol)
        .setPredictionDetailCol("pred_detail")
        .setTuningBinaryClassMetric("Accuracy");

    GridSearchTVSplit search = new GridSearchTVSplit()
        .setEstimator(new Pipeline(forest))
        .setParamGrid(grid)
        .setTuningEvaluator(evaluator)
        .setTrainRatio(0.8);

    ModelBase fitted = search.fit(source);
    fitted.transform(source).print();
}
use of com.alibaba.alink.pipeline.classification.RandomForestClassifier in project Alink by alibaba.
the class Chap13 method c_5.
/**
 * Trains and evaluates multi-class tree models on table-format data, timing each run.
 *
 * <p>If the file at {@code DATA_DIR + TABLE_TRAIN_FILE} does not exist, the sparse
 * train and test sets are first expanded from their vector column into 784 double
 * columns named {@code c_0} .. {@code c_783} (keeping the label column) and written
 * to the table-format train and test files.
 *
 * <p>Afterwards a {@code DecisionTreeClassifier} is run once per {@code TreeType}
 * (GINI, INFOGAIN, INFOGAINRATIO), then a {@code RandomForestClassifier} with
 * subsampling ratio 0.6 is run for 2, 4, 8, 16, 32, 64 and 128 info-gain trees.
 * The elapsed time of every run is printed to standard output.
 *
 * @throws Exception if building or executing the batch jobs fails
 */
static void c_5() throws Exception {
    BatchOperator.setParallelism(4);

    final String trainTablePath = DATA_DIR + TABLE_TRAIN_FILE;
    final String testTablePath = DATA_DIR + TABLE_TEST_FILE;

    // Only the train table is probed; both tables are generated together when it is missing.
    if (!new File(trainTablePath).exists()) {
        // Schema string: "c_0 double, c_1 double, ..., c_783 double".
        StringBuilder schema = new StringBuilder();
        for (int col = 0; col < 784; col++) {
            if (col > 0) {
                schema.append(", ");
            }
            schema.append("c_").append(col).append(" double");
        }

        AkSourceBatchOp sparseTrain = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TRAIN_FILE);
        AkSourceBatchOp sparseTest = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TEST_FILE);

        new VectorToColumns()
            .setVectorCol(VECTOR_COL_NAME)
            .setSchemaStr(schema.toString())
            .setReservedCols(LABEL_COL_NAME)
            .transform(sparseTrain)
            .link(new AkSinkBatchOp().setFilePath(trainTablePath));

        new VectorToColumns()
            .setVectorCol(VECTOR_COL_NAME)
            .setSchemaStr(schema.toString())
            .setReservedCols(LABEL_COL_NAME)
            .transform(sparseTest)
            .link(new AkSinkBatchOp().setFilePath(testTablePath));

        BatchOperator.execute();
    }

    AkSourceBatchOp train_data = new AkSourceBatchOp().setFilePath(trainTablePath);
    AkSourceBatchOp test_data = new AkSourceBatchOp().setFilePath(testTablePath);

    // Every column of the training table except the label is used as a feature.
    final String[] featureColNames = ArrayUtils.removeElement(train_data.getColNames(), LABEL_COL_NAME);

    train_data.lazyPrint(5);

    Stopwatch timer = new Stopwatch();

    TreeType[] criteria = { TreeType.GINI, TreeType.INFOGAIN, TreeType.INFOGAINRATIO };
    for (TreeType criterion : criteria) {
        timer.reset();
        timer.start();

        new DecisionTreeClassifier()
            .setTreeType(criterion)
            .setFeatureCols(featureColNames)
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME)
            .enableLazyPrintModelInfo()
            .fit(train_data)
            .transform(test_data)
            .link(
                new EvalMultiClassBatchOp()
                    .setLabelCol(LABEL_COL_NAME)
                    .setPredictionCol(PREDICTION_COL_NAME)
                    .lazyPrintMetrics("DecisionTreeClassifier " + criterion.toString())
            );
        BatchOperator.execute();

        timer.stop();
        System.out.println(timer.getElapsedTimeSpan());
    }

    // Forest sizes double from 2 up to and including 128.
    for (int numTrees = 2; numTrees <= 128; numTrees *= 2) {
        timer.reset();
        timer.start();

        new RandomForestClassifier()
            .setSubsamplingRatio(0.6)
            .setNumTreesOfInfoGain(numTrees)
            .setFeatureCols(featureColNames)
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME)
            .enableLazyPrintModelInfo()
            .fit(train_data)
            .transform(test_data)
            .link(
                new EvalMultiClassBatchOp()
                    .setLabelCol(LABEL_COL_NAME)
                    .setPredictionCol(PREDICTION_COL_NAME)
                    .lazyPrintMetrics("RandomForestClassifier : " + numTrees)
            );
        BatchOperator.execute();

        timer.stop();
        System.out.println(timer.getElapsedTimeSpan());
    }
}
Aggregations