Search in sources:

Example 1 with BoostingStrategy

Use of org.apache.spark.mllib.tree.configuration.BoostingStrategy in project java_study by aloyschen.

Method train_gbdt of class GbdtAndLr.

/**
 * Trains a binary-classification GBDT model and writes it to a dated JSON file.
 *
 * <p>The input is split randomly 70/30 (no fixed seed, so the split differs between runs).
 * Only the 70% portion is used for training; the 30% hold-out is currently not evaluated.
 *
 * <p>The boosting strategy starts from Spark's {@code "Classification"} defaults (log loss).
 * The number of iterations comes from {@code this.maxIter} and the maximum tree depth from
 * {@code this.maxDepth}. {@code maxBins} and categorical-feature info are NOT configured,
 * so Spark's defaults apply to them.
 *
 * <p>The model is saved to {@code this.modelPath + "gbdt_model" + <yyyy-MM-dd> + ".json"}.
 *
 * @param jsc  Spark context; not used by the current implementation (kept for interface
 *             compatibility)
 * @param data labeled samples (label plus feature vector) to train on
 * @return the trained GBDT model
 */
private GradientBoostedTreesModel train_gbdt(JavaSparkContext jsc, JavaRDD<LabeledPoint> data) {
    // Use ISO-8601 (yyyy-MM-dd) rather than the locale-dependent DateFormat default.
    // In some locales that default yields names like "Jan 5, 2024", with spaces and
    // commas that are awkward in file paths.
    String date = java.time.LocalDate.now()
            .format(java.time.format.DateTimeFormatter.ISO_LOCAL_DATE);

    // splits[1] is a 30% hold-out that is currently unused (no evaluation is performed here).
    JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[] { 0.7, 0.3 });
    JavaRDD<LabeledPoint> trainingData = splits[0];

    BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Classification");
    boostingStrategy.setNumIterations(this.maxIter);
    boostingStrategy.getTreeStrategy().setNumClasses(2);
    boostingStrategy.getTreeStrategy().setMaxDepth(this.maxDepth);

    System.out.println("Start train GBDT");
    GradientBoostedTreesModel model = GradientBoostedTrees.train(trainingData, boostingStrategy);
    System.out.println("model: " + model.toDebugString());

    GradientBoostedTreesModelUtil modelUtil =
            new GradientBoostedTreesModelUtil(model.algo(), model.trees(), model.treeWeights());
    modelUtil.saveGradientBoostedTreesModelToFile(
            model, this.modelPath + "gbdt_model" + date + ".json");
    return model;
}
Also used: BoostingStrategy (org.apache.spark.mllib.tree.configuration.BoostingStrategy), DateFormat (java.text.DateFormat), LabeledPoint (org.apache.spark.mllib.regression.LabeledPoint), JavaRDD (org.apache.spark.api.java.JavaRDD)

Aggregations

DateFormat (java.text.DateFormat): 1 · JavaRDD (org.apache.spark.api.java.JavaRDD): 1 · LabeledPoint (org.apache.spark.mllib.regression.LabeledPoint): 1 · BoostingStrategy (org.apache.spark.mllib.tree.configuration.BoostingStrategy): 1