
Example 1 with BinaryClassificationMetrics

Use of org.apache.spark.mllib.evaluation.BinaryClassificationMetrics in project mmtf-spark by sbl-sdsc.

The class SparkMultiClassClassifier, method fit.

/**
 * Fits the model and evaluates it on a held-out test set.
 * The dataset must contain at least the following two columns:
 *   label: the class labels
 *   features: the feature vector
 * @param data input dataset with label and features columns
 * @return map of metric names to metric values
 */
public Map<String, String> fit(Dataset<Row> data) {
    int classCount = (int) data.select(label).distinct().count();
    StringIndexerModel labelIndexer = new StringIndexer().setInputCol(label).setOutputCol("indexedLabel").fit(data);
    // Split the data into training and test sets (30% held out for testing)
    Dataset<Row>[] splits = data.randomSplit(new double[] { 1.0 - testFraction, testFraction }, seed);
    Dataset<Row> trainingData = splits[0];
    Dataset<Row> testData = splits[1];
    String[] labels = labelIndexer.labels();
    System.out.println();
    System.out.println("Class\tTrain\tTest");
    for (String l : labels) {
        long trainCount = trainingData.select(label).filter(label + " = '" + l + "'").count();
        long testCount = testData.select(label).filter(label + " = '" + l + "'").count();
        System.out.println(l + "\t" + trainCount + "\t" + testCount);
    }
    // Set input columns
    predictor.setLabelCol("indexedLabel").setFeaturesCol("features");
    // Convert indexed labels back to original labels.
    IndexToString labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels());
    // Chain the label indexer, the predictor, and the label converter in a Pipeline
    Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] { labelIndexer, predictor, labelConverter });
    // Train model. This also runs the indexers.
    PipelineModel model = pipeline.fit(trainingData);
    // Make predictions.
    Dataset<Row> predictions = model.transform(testData).cache();
    // Display some sample predictions
    System.out.println();
    System.out.println("Sample predictions: " + predictor.getClass().getSimpleName());
    predictions.sample(false, 0.1, seed).show(25);
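    // The mllib metrics classes below expect numeric labels, so swap columns:
    // keep the original string labels under "stringLabel" and expose the
    // indexed (numeric) labels under the original label column name.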
    predictions = predictions.withColumnRenamed(label, "stringLabel");
    predictions = predictions.withColumnRenamed("indexedLabel", label);
    // collect metrics
    Dataset<Row> pred = predictions.select("prediction", label);
    Map<String, String> metrics = new LinkedHashMap<>();
    metrics.put("Method", predictor.getClass().getSimpleName());
    if (classCount == 2) {
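        // BinaryClassificationMetrics consumes (score, label) pairs; here the
        // hard class prediction (0.0 or 1.0) stands in for a probability score,
        // so the reported AUC is computed from discrete predictions.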
        BinaryClassificationMetrics b = new BinaryClassificationMetrics(pred);
        metrics.put("AUC", Float.toString((float) b.areaUnderROC()));
    }
    MulticlassMetrics m = new MulticlassMetrics(pred);
    metrics.put("F", Float.toString((float) m.weightedFMeasure()));
    metrics.put("Accuracy", Float.toString((float) m.accuracy()));
    metrics.put("Precision", Float.toString((float) m.weightedPrecision()));
    metrics.put("Recall", Float.toString((float) m.weightedRecall()));
    metrics.put("False Positive Rate", Float.toString((float) m.weightedFalsePositiveRate()));
    metrics.put("True Positive Rate", Float.toString((float) m.weightedTruePositiveRate()));
    metrics.put("", "\nConfusion Matrix\n" + Arrays.toString(labels) + "\n" + m.confusionMatrix().toString());
    return metrics;
}
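
A minimal usage sketch (hypothetical, not from the project sources): it assumes SparkMultiClassClassifier exposes a constructor taking the ml predictor, the label column name, the held-out test fraction, and a random seed, as in the mmtf-spark project; the wrapper class FitExample and the choice of LogisticRegression are illustrative only.

import java.util.Map;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

public class FitExample {
    // Trains on 70% of the rows, evaluates on the remaining 30%, and prints
    // the returned metrics map.
    public static void printMetrics(Dataset<Row> data) {
        SparkMultiClassClassifier classifier =
                new SparkMultiClassClassifier(new LogisticRegression(), "label", 0.3, 123L);
        Map<String, String> metrics = classifier.fit(data);
        metrics.forEach((k, v) -> System.out.println(k + "\t" + v));
    }
}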
Also used :
Dataset (org.apache.spark.sql.Dataset)
Row (org.apache.spark.sql.Row)
Pipeline (org.apache.spark.ml.Pipeline)
PipelineModel (org.apache.spark.ml.PipelineModel)
StringIndexer (org.apache.spark.ml.feature.StringIndexer)
StringIndexerModel (org.apache.spark.ml.feature.StringIndexerModel)
IndexToString (org.apache.spark.ml.feature.IndexToString)
BinaryClassificationMetrics (org.apache.spark.mllib.evaluation.BinaryClassificationMetrics)
MulticlassMetrics (org.apache.spark.mllib.evaluation.MulticlassMetrics)
LinkedHashMap (java.util.LinkedHashMap)

Example 2 with BinaryClassificationMetrics

Use of org.apache.spark.mllib.evaluation.BinaryClassificationMetrics in project java_study by aloyschen.

The class GbdtAndLr, method train.

/*
 * Combine the leaf-index features produced by the GBDT model, feed them into
 * the LR model, and train the LR model.
 * @param Path: path to the training data
 */
public void train(String Path) {
    JavaSparkContext jsc = getSc();
    ArrayList<ArrayList<Integer>> treeLeafArray = new ArrayList<>();
    Dataset<Row> all_data = Preprocessing(jsc, Path);
    JavaRDD<LabeledPoint> gbdt_data_labelpoint = load_gbdt_data(all_data);
    GradientBoostedTreesModel gbdt = train_gbdt(jsc, gbdt_data_labelpoint);
    DecisionTreeModel[] decisionTreeModels = gbdt.trees();
    // Collect the leaf-node indices of each tree in the GBDT
    for (int i = 0; i < this.maxIter; i++) {
        treeLeafArray.add(getLeafNodes(decisionTreeModels[i].topNode()));
    // System.out.println("叶子索引");
    // System.out.println(treeLeafArray.get(i));
    }
    JavaRDD<LabeledPoint> CombineFeatures = all_data.toJavaRDD().map(line -> {
        double[] newvaluesDouble;
        double[] features = new double[24];
        // Copy the feature value of each numerical column into the feature
        // array, unwrapping ml DenseVector cells where present
        for (int i = 6; i < 18; i++) {
            org.apache.spark.mllib.linalg.DenseVector den = null;
            if (line.get(i) instanceof org.apache.spark.ml.linalg.Vector) {
                den = (DenseVector) Vectors.fromML((org.apache.spark.ml.linalg.DenseVector) line.get(i));
                features[i - 6] = den.toArray()[0];
            } else {
                features[i - 6] = Double.parseDouble(line.get(i).toString());
            }
        }
        DenseVector numerical_vector = new DenseVector(features);
        ArrayList<Double> newvaluesArray = new ArrayList<>();
        for (int i = 0; i < this.maxIter; i++) {
            int treePredict = predictModify(decisionTreeModels[i].topNode(), numerical_vector);
            int len = treeLeafArray.get(i).size();
            ArrayList<Double> treeArray = new ArrayList<>(len);
            // Initialize all entries to 0, then set the entry of the leaf
            // this sample lands in to 1
            for (int j = 0; j < len; j++) treeArray.add(j, 0d);
            treeArray.set(treeLeafArray.get(i).indexOf(treePredict), 1d);
            newvaluesArray.addAll(treeArray);
        }
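        // Columns 18-28 hold one-hot encoded categorical features as ml
        // SparseVectors; densify each and append its entries to the
        // combined feature list.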
        for (int i = 18; i < 29; i++) {
            SparseVector onehot_data = (SparseVector) Vectors.fromML((org.apache.spark.ml.linalg.SparseVector) line.get(i));
            DenseVector cat_data = onehot_data.toDense();
            for (int j = 0; j < cat_data.size(); j++) {
                newvaluesArray.add(cat_data.apply(j));
            }
        }
        newvaluesDouble = newvaluesArray.stream().mapToDouble(Double::doubleValue).toArray();
        DenseVector newdenseVector = new DenseVector(newvaluesDouble);
        return (new LabeledPoint(Double.valueOf(line.get(1).toString()), newdenseVector));
    });
    JavaRDD<LabeledPoint>[] splitsLR = CombineFeatures.randomSplit(new double[] { 0.7, 0.3 });
    JavaRDD<LabeledPoint> trainingDataLR = splitsLR[0];
    JavaRDD<LabeledPoint> testDataLR = splitsLR[1];
    System.out.println("Start train LR");
    LogisticRegressionModel LR = new LogisticRegressionWithLBFGS().setNumClasses(2).run(trainingDataLR.rdd()).clearThreshold();
    System.out.println("modelLR.weights().size():" + LR.weights().size());
    JavaPairRDD<Object, Object> test_LR = testDataLR.mapToPair((PairFunction<LabeledPoint, Object, Object>) labeledPoint -> {
        Tuple2<Object, Object> tuple2 = new Tuple2<>(LR.predict(labeledPoint.features()), labeledPoint.label());
        return tuple2;
    });
    BinaryClassificationMetrics test_metrics = new BinaryClassificationMetrics(test_LR.rdd());
    double test_auc = test_metrics.areaUnderROC();
    System.out.println("test data auc_score:" + test_auc);
    JavaPairRDD<Object, Object> train_LR = trainingDataLR.mapToPair((PairFunction<LabeledPoint, Object, Object>) labeledPoint -> {
        Tuple2<Object, Object> tuple2 = new Tuple2<>(LR.predict(labeledPoint.features()), labeledPoint.label());
        return tuple2;
    });
    BinaryClassificationMetrics train_metrics = new BinaryClassificationMetrics(train_LR.rdd());
    double train_auc = train_metrics.areaUnderROC();
    System.out.println("train data auc_score:" + train_auc);
    // Sort precision across thresholds in descending order and print the top ten
    JavaRDD<Tuple2<Object, Object>> precision = train_metrics.precisionByThreshold().toJavaRDD();
    JavaPairRDD<Object, Object> temp = JavaPairRDD.fromJavaRDD(precision);
    JavaPairRDD<Object, Object> swap = temp.mapToPair(Tuple2::swap);
    JavaPairRDD<Object, Object> precision_sort = swap.sortByKey(false);
    System.out.println("Precision by threshold: (Precision, Threshold)");
    for (Tuple2<Object, Object> p : precision_sort.take(10)) {
        System.out.println(p);
    }
}
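
BinaryClassificationMetrics exposes further threshold-indexed curves beyond precision; a short sketch (continuing from the train_metrics instance built above) of reading recall and F-measure the same way:

// recallByThreshold() and fMeasureByThreshold() return (threshold, value)
// pairs, mirroring the precisionByThreshold() call used above.
JavaRDD<Tuple2<Object, Object>> recall = train_metrics.recallByThreshold().toJavaRDD();
JavaRDD<Tuple2<Object, Object>> f1 = train_metrics.fMeasureByThreshold().toJavaRDD();
System.out.println("Area under PR curve: " + train_metrics.areaUnderPR());
recall.take(10).forEach(t -> System.out.println("(Threshold, Recall): " + t));
f1.take(10).forEach(t -> System.out.println("(Threshold, F1): " + t));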
Also used :
Vectors (org.apache.spark.mllib.linalg.Vectors)
DenseVector (org.apache.spark.mllib.linalg.DenseVector)
SparseVector (org.apache.spark.mllib.linalg.SparseVector)
LabeledPoint (org.apache.spark.mllib.regression.LabeledPoint)
BinaryClassificationMetrics (org.apache.spark.mllib.evaluation.BinaryClassificationMetrics)
LogisticRegressionModel (org.apache.spark.mllib.classification.LogisticRegressionModel)
LogisticRegressionWithLBFGS (org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS)
GradientBoostedTrees (org.apache.spark.mllib.tree.GradientBoostedTrees)
BoostingStrategy (org.apache.spark.mllib.tree.configuration.BoostingStrategy)
FeatureType (org.apache.spark.mllib.tree.configuration.FeatureType)
org.apache.spark.mllib.tree.model (org.apache.spark.mllib.tree.model)
JavaRDD (org.apache.spark.api.java.JavaRDD)
JavaPairRDD (org.apache.spark.api.java.JavaPairRDD)
JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)
PairFunction (org.apache.spark.api.java.function.PairFunction)
SparkConf (org.apache.spark.SparkConf)
org.apache.spark.sql (org.apache.spark.sql)
DataTypes (org.apache.spark.sql.types.DataTypes)
StructField (org.apache.spark.sql.types.StructField)
StructType (org.apache.spark.sql.types.StructType)
org.apache.spark.ml.feature (org.apache.spark.ml.feature)
Tuple2 (scala.Tuple2)
Option (scala.Option)
Serializable (scala.Serializable)
JavaConverters (scala.collection.JavaConverters)
java.util (java.util)
DateFormat (java.text.DateFormat)
CmdlineParser (de.tototec.cmdoption.CmdlineParser)
utils (utils)

Aggregations

BinaryClassificationMetrics (org.apache.spark.mllib.evaluation.BinaryClassificationMetrics) 2
CmdlineParser (de.tototec.cmdoption.CmdlineParser) 1
DateFormat (java.text.DateFormat) 1
java.util (java.util) 1
LinkedHashMap (java.util.LinkedHashMap) 1
SparkConf (org.apache.spark.SparkConf) 1
JavaPairRDD (org.apache.spark.api.java.JavaPairRDD) 1
JavaRDD (org.apache.spark.api.java.JavaRDD) 1
JavaSparkContext (org.apache.spark.api.java.JavaSparkContext) 1
PairFunction (org.apache.spark.api.java.function.PairFunction) 1
Pipeline (org.apache.spark.ml.Pipeline) 1
PipelineModel (org.apache.spark.ml.PipelineModel) 1
org.apache.spark.ml.feature (org.apache.spark.ml.feature) 1
IndexToString (org.apache.spark.ml.feature.IndexToString) 1
StringIndexer (org.apache.spark.ml.feature.StringIndexer) 1
StringIndexerModel (org.apache.spark.ml.feature.StringIndexerModel) 1
LogisticRegressionModel (org.apache.spark.mllib.classification.LogisticRegressionModel) 1
LogisticRegressionWithLBFGS (org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS) 1
MulticlassMetrics (org.apache.spark.mllib.evaluation.MulticlassMetrics) 1
DenseVector (org.apache.spark.mllib.linalg.DenseVector) 1