Use of org.apache.spark.mllib.tree.configuration.BoostingStrategy in project java_study by aloyschen.
Class GbdtAndLr, method train_gbdt.
/**
 * Trains a two-class GBDT model on a random 70% partition of {@code data} and writes the
 * trained model to a JSON file whose name is built from {@code this.modelPath}.
 *
 * <p>The boosting strategy starts from Spark's "Classification" defaults (log loss, per the
 * Spark MLlib documentation) and overrides only the number of iterations, the number of
 * classes and the maximum tree depth. maxBins and the categorical-feature map are not
 * overridden in this method.
 *
 * @param jsc  Spark context; not used by this method at present, kept so the signature
 *             stays compatible with existing callers
 * @param data labeled samples (label plus features) to draw the training partition from
 * @return the trained GBDT model
 */
private GradientBoostedTreesModel train_gbdt(JavaSparkContext jsc, JavaRDD<LabeledPoint> data) {
    // NOTE(review): getDateInstance() formats with the JVM's default locale, so this stamp
    // (and therefore the output file name) differs between machines and may contain spaces
    // or '/' in some locales. Left unchanged because code outside this view may rebuild the
    // same name; consider a fixed pattern such as yyyy-MM-dd once that is confirmed.
    String date = DateFormat.getDateInstance().format(new Date());

    // Unseeded 70/30 split: only the first partition is trained on. The 30% partition is
    // not used anywhere in this method.
    JavaRDD<LabeledPoint> trainingData = data.randomSplit(new double[] { 0.7, 0.3 })[0];

    BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Classification");
    boostingStrategy.setNumIterations(this.maxIter);
    boostingStrategy.getTreeStrategy().setNumClasses(2);
    boostingStrategy.getTreeStrategy().setMaxDepth(this.maxDepth);

    System.out.println("Start train GBDT");
    GradientBoostedTreesModel model = GradientBoostedTrees.train(trainingData, boostingStrategy);
    System.out.println("model: " + model.toDebugString());

    // Persist through the project's own JSON exporter.
    // NOTE(review): modelPath is concatenated directly with the file name, so it presumably
    // has to end with a path separator - confirm against where modelPath is set.
    GradientBoostedTreesModelUtil modelUtil =
            new GradientBoostedTreesModelUtil(model.algo(), model.trees(), model.treeWeights());
    modelUtil.saveGradientBoostedTreesModelToFile(
            model, this.modelPath + "gbdt_model" + date + ".json");

    return model;
}
Aggregations