Usage of com.alibaba.alink.pipeline.tuning.RandomSearchTVSplitModel in the Alink project by Alibaba.
Source: class Chap20, method c_2.
/**
 * Tunes a GBDT binary classifier with a random hyper-parameter search on a single
 * train/validation split, then evaluates the best model found on the test data.
 *
 * <p>Search configuration: {@code setNumIter(20)}, {@code setTrainRatio(0.8)}, candidates
 * ranked by {@link TuningBinaryClassMetric#F1}. Sampled hyper-parameters:
 * <ul>
 *   <li>{@code NUM_TREES}: one of 50, 100</li>
 *   <li>{@code MAX_DEPTH}: {@code ValueDist.randInteger(4, 10)}</li>
 *   <li>{@code MAX_BINS}: one of 64, 128, 256, 512</li>
 *   <li>{@code LEARNING_RATE}: one of 0.3, 0.1, 0.01</li>
 * </ul>
 *
 * <p>Side effects: enables Alink's global process-info printing via
 * {@code AlinkGlobalConfiguration.setPrintProcessInfo(true)} (it is not reset afterwards),
 * and prints the total elapsed time to stdout at the end.
 *
 * @throws Exception propagated from building or executing the Alink batch job
 */
static void c_2() throws Exception {
    Stopwatch sw = new Stopwatch();
    sw.start();

    AlinkGlobalConfiguration.setPrintProcessInfo(true);

    // Wildcard type instead of the raw BatchOperator type keeps generic type checking enabled.
    BatchOperator<?> trainSample =
        new AkSourceBatchOp().setFilePath(Chap11.DATA_DIR + Chap11.TRAIN_SAMPLE_FILE);
    BatchOperator<?> testData =
        new AkSourceBatchOp().setFilePath(Chap11.DATA_DIR + Chap11.TEST_FILE);

    // Every column of the training sample except the label is used as a feature.
    final String[] featureColNames =
        ArrayUtils.removeElement(trainSample.getColNames(), Chap11.LABEL_COL_NAME);

    GbdtClassifier gbdt = new GbdtClassifier()
        .setFeatureCols(featureColNames)
        .setLabelCol(Chap11.LABEL_COL_NAME)
        .setPredictionCol(Chap11.PREDICTION_COL_NAME)
        .setPredictionDetailCol(Chap11.PRED_DETAIL_COL_NAME);

    // Hyper-parameter search space for the GBDT estimator above.
    ParamDist paramDist = new ParamDist()
        .addDist(gbdt, GbdtClassifier.NUM_TREES, ValueDist.randArray(new Integer[] {50, 100}))
        .addDist(gbdt, GbdtClassifier.MAX_DEPTH, ValueDist.randInteger(4, 10))
        .addDist(gbdt, GbdtClassifier.MAX_BINS, ValueDist.randArray(new Integer[] {64, 128, 256, 512}))
        .addDist(gbdt, GbdtClassifier.LEARNING_RATE, ValueDist.randArray(new Double[] {0.3, 0.1, 0.01}));

    // Candidates are scored by F1 computed from the prediction-detail column.
    BinaryClassificationTuningEvaluator tuningEvaluator = new BinaryClassificationTuningEvaluator()
        .setLabelCol(Chap11.LABEL_COL_NAME)
        .setPredictionDetailCol(Chap11.PRED_DETAIL_COL_NAME)
        .setTuningBinaryClassMetric(TuningBinaryClassMetric.F1);

    RandomSearchTVSplit randomSearch = new RandomSearchTVSplit()
        .setNumIter(20)
        .setTrainRatio(0.8)
        .setEstimator(gbdt)
        .setParamDist(paramDist)
        .setTuningEvaluator(tuningEvaluator)
        .enableLazyPrintTrainInfo();

    RandomSearchTVSplitModel bestModel = randomSearch.fit(trainSample);

    // Evaluate the selected model on the test data, treating label value "1" as positive.
    EvalBinaryClassBatchOp evaluation = new EvalBinaryClassBatchOp()
        .setPositiveLabelValueString("1")
        .setLabelCol(Chap11.LABEL_COL_NAME)
        .setPredictionDetailCol(Chap11.PRED_DETAIL_COL_NAME)
        .lazyPrintMetrics();
    bestModel.transform(testData).link(evaluation);

    // Run the batch job; the lazy prints registered above are expected to be emitted here.
    BatchOperator.execute();

    sw.stop();
    System.out.println(sw.getElapsedTimeSpan());
}
Aggregations