Example 1 with NaiveBayesModel

Use of org.apache.spark.mllib.classification.NaiveBayesModel in project cdap by caskdata.

The class NaiveBayesTrainer, method run:

@Override
public void run(SparkExecutionPluginContext sparkContext, JavaRDD<StructuredRecord> input) throws Exception {
    Preconditions.checkArgument(input.count() != 0, "Input RDD is empty.");
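    // hash each text into a fixed-size, 100-dimension term-frequency feature vector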
    final HashingTF tf = new HashingTF(100);
    JavaRDD<LabeledPoint> trainingData = input.map(new Function<StructuredRecord, LabeledPoint>() {

        @Override
        public LabeledPoint call(StructuredRecord record) throws Exception {
            // should never happen, here to test app correctness in unit tests
            if (inputSchema != null && !inputSchema.equals(record.getSchema())) {
                throw new IllegalStateException("runtime schema does not match what was set at configure time.");
            }
            String text = record.get(config.fieldToClassify);
            return new LabeledPoint((Double) record.get(config.predictionField), tf.transform(Lists.newArrayList(text.split(" "))));
        }
    });
    trainingData.cache();
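    // train a Naive Bayes model with additive (Laplace) smoothing, lambda = 1.0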
    final NaiveBayesModel model = NaiveBayes.train(trainingData.rdd(), 1.0);
    // save the model to a file in the output FileSet
    JavaSparkContext javaSparkContext = sparkContext.getSparkContext();
    FileSet outputFS = sparkContext.getDataset(config.fileSetName);
    model.save(JavaSparkContext.toSparkContext(javaSparkContext), outputFS.getBaseLocation().append(config.path).toURI().getPath());
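    // read the texts to classify from the TEXTS_TO_CLASSIFY stream as (event timestamp, body) pairs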
    JavaPairRDD<Long, String> textsToClassify = sparkContext.fromStream(TEXTS_TO_CLASSIFY, String.class);
    JavaRDD<Vector> featuresToClassify = textsToClassify.map(new Function<Tuple2<Long, String>, Vector>() {

        @Override
        public Vector call(Tuple2<Long, String> timestampAndText) throws Exception {
            String text = timestampAndText._2();
            return tf.transform(Lists.newArrayList(text.split(" ")));
        }
    });
    JavaRDD<Double> predict = model.predict(featuresToClassify);
    LOG.info("Predictions: {}", predict.collect());
    // key the predictions with the message
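    // (zip is safe here: predict was derived from textsToClassify via map,
    // so both RDDs have the same partitioning and per-partition element counts)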
    JavaPairRDD<String, Double> keyedPredictions = textsToClassify.values().zip(predict);
    // convert to (byte[], byte[]) pairs so they can be written to the output dataset
    JavaPairRDD<byte[], byte[]> bytesRDD = keyedPredictions.mapToPair(new PairFunction<Tuple2<String, Double>, byte[], byte[]>() {

        @Override
        public Tuple2<byte[], byte[]> call(Tuple2<String, Double> tuple) throws Exception {
            return new Tuple2<>(Bytes.toBytes(tuple._1()), Bytes.toBytes(tuple._2()));
        }
    });
    sparkContext.saveAsDataset(bytesRDD, CLASSIFIED_TEXTS);
}
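For context, here is a minimal sketch of how a downstream job might load the model that run saves and classify a single text. ModelLoader, classify, and modelPath are hypothetical names, not part of the plugin; the one constraint carried over from the code above is that the HashingTF dimensionality (100) must match the value used at training time.

import com.google.common.collect.Lists;
import org.apache.spark.SparkContext;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.feature.HashingTF;
import org.apache.spark.mllib.linalg.Vector;

public class ModelLoader {

    // Hypothetical helper: load a model saved via model.save(...) and classify one text.
    public static double classify(SparkContext sc, String modelPath, String text) {
        NaiveBayesModel model = NaiveBayesModel.load(sc, modelPath);
        // Reuse the training-time feature hashing setup: 100 features, whitespace tokens.
        // A different dimensionality would map terms to indices the model has never seen.
        HashingTF tf = new HashingTF(100);
        Vector features = tf.transform(Lists.newArrayList(text.split(" ")));
        return model.predict(features);
    }
}

In the plugin above, the equivalent of modelPath is outputFS.getBaseLocation().append(config.path).toURI().getPath().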
Also used:

- LabeledPoint (org.apache.spark.mllib.regression.LabeledPoint)
- NaiveBayesModel (org.apache.spark.mllib.classification.NaiveBayesModel)
- StructuredRecord (co.cask.cdap.api.data.format.StructuredRecord)
- HashingTF (org.apache.spark.mllib.feature.HashingTF)
- JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)
- Vector (org.apache.spark.mllib.linalg.Vector)
- FileSet (co.cask.cdap.api.dataset.lib.FileSet)
- Tuple2 (scala.Tuple2)

Aggregations

- StructuredRecord (co.cask.cdap.api.data.format.StructuredRecord): 1 usage
- FileSet (co.cask.cdap.api.dataset.lib.FileSet): 1 usage
- JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 1 usage
- NaiveBayesModel (org.apache.spark.mllib.classification.NaiveBayesModel): 1 usage
- HashingTF (org.apache.spark.mllib.feature.HashingTF): 1 usage
- Vector (org.apache.spark.mllib.linalg.Vector): 1 usage
- LabeledPoint (org.apache.spark.mllib.regression.LabeledPoint): 1 usage
- Tuple2 (scala.Tuple2): 1 usage