Usage of com.alibaba.alink.operator.batch.classification.DecisionTreeTrainBatchOp in the Alink project by Alibaba.
Example taken from class Chap09, method c_5.
/**
 * Trains and evaluates one decision tree per split criterion on the same train/test data.
 *
 * <p>For each of {@code TreeType.GINI}, {@code TreeType.INFOGAIN} and
 * {@code TreeType.INFOGAINRATIO}, this method does the following:
 * <ul>
 *   <li>links a {@code DecisionTreeTrainBatchOp} to the data read from
 *       {@code DATA_DIR + TRAIN_FILE};</li>
 *   <li>registers a model-info printout and a callback that writes the tree image to
 *       {@code DATA_DIR + "tree_" + treeType + ".jpg"};</li>
 *   <li>predicts on {@code DATA_DIR + TEST_FILE};</li>
 *   <li>registers binary-classification metrics with {@code "p"} as the positive label.</li>
 * </ul>
 * The print/collect steps are registered through the {@code lazy*} methods. A single
 * {@code BatchOperator.execute()} call at the end runs the assembled plan.
 *
 * @throws Exception if building or executing the Alink job fails
 */
static void c_5() throws Exception {
    // Wildcard-typed instead of raw so that link/linkFrom calls are type-checked.
    BatchOperator<?> trainData = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE);
    BatchOperator<?> testData = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE);

    for (TreeType treeType : new TreeType[] {TreeType.GINI, TreeType.INFOGAIN, TreeType.INFOGAINRATIO}) {
        // The same title tags both the model printout and the metrics printout for this tree type.
        String title = "< " + treeType.toString() + " >";

        BatchOperator<?> model = trainData.link(
            new DecisionTreeTrainBatchOp()
                .setTreeType(treeType)
                .setFeatureCols(FEATURE_COL_NAMES)
                .setCategoricalCols(FEATURE_COL_NAMES)
                .setLabelCol(LABEL_COL_NAME)
                .lazyPrintModelInfo(title)
                .lazyCollectModelInfo((DecisionTreeModelInfo modelInfo) -> {
                    String imagePath = DATA_DIR + "tree_" + treeType.toString() + ".jpg";
                    try {
                        modelInfo.saveTreeAsImage(imagePath, true);
                    } catch (IOException e) {
                        // The tree image is a best-effort side output. Report which file failed
                        // instead of propagating the error out of the callback.
                        System.err.println("Failed to save decision tree image: " + imagePath);
                        e.printStackTrace();
                    }
                }));

        DecisionTreePredictBatchOp predictor = new DecisionTreePredictBatchOp()
            .setPredictionCol(PREDICTION_COL_NAME)
            .setPredictionDetailCol(PRED_DETAIL_COL_NAME);
        predictor.linkFrom(model, testData);

        predictor.link(
            new EvalBinaryClassBatchOp()
                .setPositiveLabelValueString("p")
                .setLabelCol(LABEL_COL_NAME)
                .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
                .lazyPrintMetrics(title));
    }

    // Runs the plan built above, including every lazily registered print/collect step.
    BatchOperator.execute();
}
Aggregations