Example 1 with PipelineModel

Use of org.apache.spark.ml.PipelineModel in project jpmml-sparkml by jpmml.

The class Main, method run:

private void run() throws Exception {
    StructType schema;
    try (InputStream is = new FileInputStream(this.schemaInput)) {
        String json = CharStreams.toString(new InputStreamReader(is, "UTF-8"));
        schema = (StructType) DataType.fromJson(json);
    }
    File pipelineDir = this.pipelineInput;
    // If the pipeline input is a ZIP archive, uncompress it to a temporary
    // directory; otherwise assume it is already a directory and load it as-is.
    zipFile: {
        ZipFile zipFile;
        try {
            zipFile = new ZipFile(pipelineDir);
        } catch (IOException ioe) {
            // Not a ZIP archive; fall through and load the directory directly.
            break zipFile;
        }
        try {
            pipelineDir = File.createTempFile("PipelineModel", "");
            if (!pipelineDir.delete()) {
                throw new IOException("Failed to delete temporary file " + pipelineDir);
            }
            if (!pipelineDir.mkdirs()) {
                throw new IOException("Failed to create temporary directory " + pipelineDir);
            }
            ZipUtil.uncompress(zipFile, pipelineDir);
        } finally {
            zipFile.close();
        }
    }
    PipelineModel pipelineModel = PipelineModel.load(pipelineDir.getAbsolutePath());
    PMML pmml = ConverterUtil.toPMML(schema, pipelineModel);
    try (OutputStream os = new FileOutputStream(this.output)) {
        MetroJAXBUtil.marshalPMML(pmml, os);
    }
}
Also used: StructType(org.apache.spark.sql.types.StructType) InputStreamReader(java.io.InputStreamReader) ZipFile(java.util.zip.ZipFile) FileInputStream(java.io.FileInputStream) InputStream(java.io.InputStream) OutputStream(java.io.OutputStream) FileOutputStream(java.io.FileOutputStream) PMML(org.dmg.pmml.PMML) IOException(java.io.IOException) File(java.io.File) PipelineModel(org.apache.spark.ml.PipelineModel)
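
For context, the two inputs that run() reads back are typically produced on the Spark side by saving a fitted PipelineModel together with the schema of its training data. A minimal producer sketch, assuming only the standard Spark ML persistence API (the class name, method name, and paths are illustrative):

import java.io.FileWriter;
import java.io.IOException;

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

public class PipelineExporter {

    // Persists the two artifacts that Main.run() expects: the fitted
    // PipelineModel directory and the dataset schema as JSON.
    public static void export(Pipeline pipeline, Dataset<Row> trainingData,
            String modelPath, String schemaPath) throws IOException {
        PipelineModel pipelineModel = pipeline.fit(trainingData);
        pipelineModel.write().overwrite().save(modelPath);
        try (FileWriter writer = new FileWriter(schemaPath)) {
            writer.write(trainingData.schema().json());
        }
    }
}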

Example 2 with PipelineModel

Use of org.apache.spark.ml.PipelineModel in project jpmml-sparkml by jpmml.

The class RFormulaModelConverter, method registerFeatures:

@Override
public void registerFeatures(SparkMLEncoder encoder) {
    RFormulaModel transformer = getTransformer();
    ResolvedRFormula resolvedFormula = transformer.resolvedFormula();
    String targetCol = resolvedFormula.label();
    String labelCol = transformer.getLabelCol();
    if (!targetCol.equals(labelCol)) {
        // The label column differs from the formula's target; reuse the
        // target column's features under the label column name.
        List<Feature> features = encoder.getFeatures(targetCol);
        encoder.putFeatures(labelCol, features);
    }
    PipelineModel pipelineModel = transformer.pipelineModel();
    Transformer[] stages = pipelineModel.stages();
    for (Transformer stage : stages) {
        FeatureConverter<?> featureConverter = ConverterUtil.createFeatureConverter(stage);
        featureConverter.registerFeatures(encoder);
    }
}
Also used: Transformer(org.apache.spark.ml.Transformer) ResolvedRFormula(org.apache.spark.ml.feature.ResolvedRFormula) RFormulaModel(org.apache.spark.ml.feature.RFormulaModel) Feature(org.jpmml.converter.Feature) PipelineModel(org.apache.spark.ml.PipelineModel)
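
This converter handles pipelines that contain a fitted RFormulaModel. For reference, a minimal sketch of producing such a stage with the standard Spark ML RFormula API (the formula string and column names are illustrative):

import org.apache.spark.ml.feature.RFormula;
import org.apache.spark.ml.feature.RFormulaModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

public class RFormulaExample {

    // Fits an RFormulaModel, the transformer type whose internal pipeline
    // stages RFormulaModelConverter.registerFeatures() walks through.
    public static RFormulaModel fitFormula(Dataset<Row> data) {
        RFormula formula = new RFormula()
            .setFormula("y ~ .")           // model y as a function of all other columns
            .setFeaturesCol("features")
            .setLabelCol("label");
        return formula.fit(data);
    }
}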

Example 3 with PipelineModel

Use of org.apache.spark.ml.PipelineModel in project mmtf-spark by sbl-sdsc.

The class SparkMultiClassClassifier, method fit:

/**
 * Fits the classifier. The dataset must contain at least the following two columns:
 * label: the class labels
 * features: the feature vector
 * @param data input dataset
 * @return map with evaluation metrics
 */
public Map<String, String> fit(Dataset<Row> data) {
    int classCount = (int) data.select(label).distinct().count();
    StringIndexerModel labelIndexer = new StringIndexer().setInputCol(label).setOutputCol("indexedLabel").fit(data);
    // Split the data into training and test sets (30% held out for testing)
    Dataset<Row>[] splits = data.randomSplit(new double[] { 1.0 - testFraction, testFraction }, seed);
    Dataset<Row> trainingData = splits[0];
    Dataset<Row> testData = splits[1];
    String[] labels = labelIndexer.labels();
    System.out.println();
    System.out.println("Class\tTrain\tTest");
    for (String l : labels) {
        long trainCount = trainingData.select(label).filter(label + " = '" + l + "'").count();
        long testCount = testData.select(label).filter(label + " = '" + l + "'").count();
        System.out.println(l + "\t" + trainCount + "\t" + testCount);
    }
    // Set input columns
    predictor.setLabelCol("indexedLabel").setFeaturesCol("features");
    // Convert indexed labels back to original labels.
    IndexToString labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels());
    // Chain the label indexer, predictor, and label converter in a Pipeline
    Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] { labelIndexer, predictor, labelConverter });
    // Train the model. This also runs the label indexer.
    PipelineModel model = pipeline.fit(trainingData);
    // Make predictions.
    Dataset<Row> predictions = model.transform(testData).cache();
    // Display some sample predictions
    System.out.println();
    System.out.println("Sample predictions: " + predictor.getClass().getSimpleName());
    predictions.sample(false, 0.1, seed).show(25);
    predictions = predictions.withColumnRenamed(label, "stringLabel");
    predictions = predictions.withColumnRenamed("indexedLabel", label);
    // collect metrics
    Dataset<Row> pred = predictions.select("prediction", label);
    Map<String, String> metrics = new LinkedHashMap<>();
    metrics.put("Method", predictor.getClass().getSimpleName());
    if (classCount == 2) {
        BinaryClassificationMetrics b = new BinaryClassificationMetrics(pred);
        metrics.put("AUC", Float.toString((float) b.areaUnderROC()));
    }
    MulticlassMetrics m = new MulticlassMetrics(pred);
    metrics.put("F", Float.toString((float) m.weightedFMeasure()));
    metrics.put("Accuracy", Float.toString((float) m.accuracy()));
    metrics.put("Precision", Float.toString((float) m.weightedPrecision()));
    metrics.put("Recall", Float.toString((float) m.weightedRecall()));
    metrics.put("False Positive Rate", Float.toString((float) m.weightedFalsePositiveRate()));
    metrics.put("True Positive Rate", Float.toString((float) m.weightedTruePositiveRate()));
    metrics.put("", "\nConfusion Matrix\n" + Arrays.toString(labels) + "\n" + m.confusionMatrix().toString());
    return metrics;
}
Also used: Dataset(org.apache.spark.sql.Dataset) BinaryClassificationMetrics(org.apache.spark.mllib.evaluation.BinaryClassificationMetrics) IndexToString(org.apache.spark.ml.feature.IndexToString) MulticlassMetrics(org.apache.spark.mllib.evaluation.MulticlassMetrics) StringIndexerModel(org.apache.spark.ml.feature.StringIndexerModel) Pipeline(org.apache.spark.ml.Pipeline) PipelineModel(org.apache.spark.ml.PipelineModel) LinkedHashMap(java.util.LinkedHashMap) StringIndexer(org.apache.spark.ml.feature.StringIndexer) Row(org.apache.spark.sql.Row)
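
A hypothetical driver for this class: the constructor arguments below (predictor, label column, test fraction, seed) mirror the fields that fit() reads, but the exact signature is an assumption to verify against the mmtf-spark source.

import java.util.Map;

import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

public class ClassifierDriver {

    // data must contain "label" and "features" columns, as fit() requires.
    public static Map<String, String> run(Dataset<Row> data) {
        LogisticRegression predictor = new LogisticRegression();
        // Constructor signature assumed: (predictor, labelCol, testFraction, seed)
        SparkMultiClassClassifier classifier =
                new SparkMultiClassClassifier(predictor, "label", 0.3, 123L);
        Map<String, String> metrics = classifier.fit(data);
        metrics.forEach((k, v) -> System.out.println(k + ": " + v));
        return metrics;
    }
}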

Example 4 with PipelineModel

Use of org.apache.spark.ml.PipelineModel in project mmtf-spark by sbl-sdsc.

The class SparkRegressor, method fit:

/**
 * Fits the regressor. The dataset must contain at least the following two columns:
 * label: the numeric target values
 * features: the feature vector
 * @param data input dataset
 * @return map with evaluation metrics
 */
public Map<String, String> fit(Dataset<Row> data) {
    // Split the data into training and test sets (30% held out for testing)
    Dataset<Row>[] splits = data.randomSplit(new double[] { 1.0 - testFraction, testFraction }, seed);
    Dataset<Row> trainingData = splits[0];
    Dataset<Row> testData = splits[1];
    // Set the input columns for the predictor
    predictor.setLabelCol(label).setFeaturesCol("features");
    // Wrap the predictor in a single-stage Pipeline
    Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] { predictor });
    // Train the model
    PipelineModel model = pipeline.fit(trainingData);
    // Make predictions.
    Dataset<Row> predictions = model.transform(testData);
    // Display some sample predictions
    System.out.println("Sample predictions: " + predictor.getClass().getSimpleName());
    String primaryKey = predictions.columns()[0];
    predictions.select(primaryKey, label, "prediction").sample(false, 0.1, seed).show(50);
    Map<String, String> metrics = new LinkedHashMap<>();
    metrics.put("Method", predictor.getClass().getSimpleName());
    // Select (prediction, true label) and compute test error
    RegressionEvaluator evaluator = new RegressionEvaluator().setLabelCol(label).setPredictionCol("prediction").setMetricName("rmse");
    metrics.put("rmse", Double.toString(evaluator.evaluate(predictions)));
    return metrics;
}
Also used: Dataset(org.apache.spark.sql.Dataset) Row(org.apache.spark.sql.Row) RegressionEvaluator(org.apache.spark.ml.evaluation.RegressionEvaluator) Pipeline(org.apache.spark.ml.Pipeline) PipelineModel(org.apache.spark.ml.PipelineModel) LinkedHashMap(java.util.LinkedHashMap)
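
The method above reports only RMSE. RegressionEvaluator also accepts the "mse", "r2", and "mae" metric names, so a small helper can collect all four; a minimal sketch (the helper class name is illustrative):

import java.util.LinkedHashMap;
import java.util.Map;

import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

public class RegressionMetricsHelper {

    // predictions must contain labelCol and a "prediction" column,
    // as produced by model.transform(testData) above.
    public static Map<String, String> evaluate(Dataset<Row> predictions, String labelCol) {
        Map<String, String> metrics = new LinkedHashMap<>();
        for (String name : new String[] { "rmse", "mse", "r2", "mae" }) {
            RegressionEvaluator evaluator = new RegressionEvaluator()
                    .setLabelCol(labelCol)
                    .setPredictionCol("prediction")
                    .setMetricName(name);
            metrics.put(name, Double.toString(evaluator.evaluate(predictions)));
        }
        return metrics;
    }
}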

Example 5 with PipelineModel

Use of org.apache.spark.ml.PipelineModel in project mm-dev by sbl-sdsc.

The class CathClassificationDataset, method sequenceToFeatureVector:

private static Dataset<Row> sequenceToFeatureVector(Dataset<Row> data, int n, int windowSize, int vectorSize) {
    // split sequence into an array of one-letter codes (1-grams)
    // e.g. IDCGHVDSL => [i, d, c, g, h, v...
    RegexTokenizer tokenizer = new RegexTokenizer().setInputCol("sequence").setOutputCol("1gram").setPattern("(?!^)");
    // create n-grams out of the sequence
    // e.g., 2-gram [i, d, c, g, h, v... => [i d, d c, c g, g...
    NGram ngrammer = new NGram().setN(n).setInputCol("1gram").setOutputCol("ngram");
    // convert n-grams to W2V feature vector
    // [i d, d c, c g, g... => [0.1234, 0.23948, ...]
    Word2Vec word2Vec = new Word2Vec().setInputCol("ngram").setOutputCol("features").setWindowSize(windowSize).setVectorSize(vectorSize).setMinCount(0);
    Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] { tokenizer, ngrammer, word2Vec });
    PipelineModel model = pipeline.fit(data);
    data = model.transform(data);
    return data;
}
Also used: Word2Vec(org.apache.spark.ml.feature.Word2Vec) RegexTokenizer(org.apache.spark.ml.feature.RegexTokenizer) NGram(org.apache.spark.ml.feature.NGram) Pipeline(org.apache.spark.ml.Pipeline) PipelineModel(org.apache.spark.ml.PipelineModel)
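
To see what this featurizer produces end to end, here is a hypothetical driver with a toy dataset. It assumes sequenceToFeatureVector has been made accessible (it is private in the original class), and the n-gram size, window size, and vector size are illustrative values:

import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class SequenceFeaturizerDemo {

    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .master("local[*]")
                .appName("SequenceFeaturizerDemo")
                .getOrCreate();

        // A toy dataset with the "sequence" column the featurizer expects.
        StructType schema = new StructType().add("sequence", DataTypes.StringType);
        List<Row> rows = Arrays.asList(
                RowFactory.create("IDCGHVDSL"),
                RowFactory.create("MKTAYIAKQR"));
        Dataset<Row> data = spark.createDataFrame(rows, schema);

        // 2-grams, window size 5, 50-dimensional vectors; assumes the
        // private method has been exposed for testing.
        Dataset<Row> features = CathClassificationDataset.sequenceToFeatureVector(data, 2, 5, 50);
        features.select("ngram", "features").show(false);

        spark.stop();
    }
}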

Aggregations

PipelineModel (org.apache.spark.ml.PipelineModel)8 Pipeline (org.apache.spark.ml.Pipeline)3 Transformer (org.apache.spark.ml.Transformer)3 PMML (org.dmg.pmml.PMML)3 File (java.io.File)2 FileOutputStream (java.io.FileOutputStream)2 IOException (java.io.IOException)2 InputStream (java.io.InputStream)2 InputStreamReader (java.io.InputStreamReader)2 OutputStream (java.io.OutputStream)2 ArrayList (java.util.ArrayList)2 LinkedHashMap (java.util.LinkedHashMap)2 List (java.util.List)2 CrossValidatorModel (org.apache.spark.ml.tuning.CrossValidatorModel)2 TrainValidationSplitModel (org.apache.spark.ml.tuning.TrainValidationSplitModel)2 Dataset (org.apache.spark.sql.Dataset)2 Row (org.apache.spark.sql.Row)2 StructType (org.apache.spark.sql.types.StructType)2 Feature (org.jpmml.converter.Feature)2 Function (com.google.common.base.Function)1