Search in sources :

Example 1 with DecisionTreeClassifier

Use of org.apache.spark.ml.classification.DecisionTreeClassifier in project mmtf-spark by sbl-sdsc.

The main method of the class DatasetClassifier:

/**
 * @param args args[0] path to the parquet file, args[1] name of the classification column
 * @throws IOException
 */
public static void main(String[] args) throws IOException {
    if (args.length != 2) {
        System.err.println("Usage: " + DatasetClassifier.class.getSimpleName() + " <parquet file> <classification column name>");
        System.exit(1);
    }
    // name of the class label
    String label = args[1];
    long start = System.nanoTime();
    SparkSession spark = SparkSession.builder().master("local[*]").appName(DatasetClassifier.class.getSimpleName()).getOrCreate();
    Dataset<Row> data = spark.read().parquet(args[0]).cache();
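    // determine the feature vector length from the first row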
    int featureCount = 0;
    Object vector = data.first().getAs("features");
    // use the full vector length as the feature count (numActives() would undercount sparse vectors)
    if (vector instanceof DenseVector) {
        featureCount = ((DenseVector) vector).size();
    } else if (vector instanceof SparseVector) {
        featureCount = ((SparseVector) vector).size();
    }
    System.out.println("Feature count            : " + featureCount);
    int classCount = (int) data.select(label).distinct().count();
    System.out.println("Class count              : " + classCount);
    System.out.println("Dataset size (unbalanced): " + data.count());
    data.groupBy(label).count().show(classCount);
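    // balance the class distribution by downsampling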
    data = DatasetBalancer.downsample(data, label, 1);
    System.out.println("Dataset size (balanced)  : " + data.count());
    data.groupBy(label).count().show(classCount);
    double testFraction = 0.3;
    long seed = 123;
    SparkMultiClassClassifier mcc;
    Map<String, String> metrics;
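    // train and evaluate a decision tree classifier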
    DecisionTreeClassifier dtc = new DecisionTreeClassifier();
    mcc = new SparkMultiClassClassifier(dtc, label, testFraction, seed);
    metrics = mcc.fit(data);
    System.out.println(metrics);
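    // train and evaluate a random forest classifier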
    RandomForestClassifier rfc = new RandomForestClassifier();
    mcc = new SparkMultiClassClassifier(rfc, label, testFraction, seed);
    metrics = mcc.fit(data);
    System.out.println(metrics);
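    // train and evaluate a logistic regression classifier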
    LogisticRegression lr = new LogisticRegression();
    mcc = new SparkMultiClassClassifier(lr, label, testFraction, seed);
    metrics = mcc.fit(data);
    System.out.println(metrics);
    // specify layers for the neural network:
    // input layer: dimension of the feature vector
    // one hidden layer with 10 nodes
    // output layer: number of classes
    int[] layers = new int[] { featureCount, 10, classCount };
    MultilayerPerceptronClassifier mpc = new MultilayerPerceptronClassifier().setLayers(layers).setBlockSize(128).setSeed(1234L).setMaxIter(200);
    mcc = new SparkMultiClassClassifier(mpc, label, testFraction, seed);
    metrics = mcc.fit(data);
    System.out.println(metrics);
    long end = System.nanoTime();
    System.out.println((end - start) / 1E9 + " sec");
}
Also used : SparkSession(org.apache.spark.sql.SparkSession) SparseVector(org.apache.spark.ml.linalg.SparseVector) RandomForestClassifier(org.apache.spark.ml.classification.RandomForestClassifier) MultilayerPerceptronClassifier(org.apache.spark.ml.classification.MultilayerPerceptronClassifier) DecisionTreeClassifier(org.apache.spark.ml.classification.DecisionTreeClassifier) Row(org.apache.spark.sql.Row) LogisticRegression(org.apache.spark.ml.classification.LogisticRegression) DenseVector(org.apache.spark.ml.linalg.DenseVector)

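For comparison, the sketch below shows how a DecisionTreeClassifier can be trained and evaluated directly with the Spark ML Pipeline API, presumably similar to what SparkMultiClassClassifier does internally. This is a hypothetical, minimal example and not part of mmtf-spark; the class name DecisionTreePipelineSketch, the 70/30 split, the indexedLabel column and the F1 metric are assumptions.

// Hypothetical sketch (not mmtf-spark code): fit and evaluate a decision tree on a
// parquet file that already contains a "features" vector column and a string label column.
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class DecisionTreePipelineSketch {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().master("local[*]")
                .appName(DecisionTreePipelineSketch.class.getSimpleName()).getOrCreate();
        Dataset<Row> data = spark.read().parquet(args[0]).cache();
        String label = args[1];
        // 70/30 train/test split with a fixed seed for reproducibility
        Dataset<Row>[] splits = data.randomSplit(new double[] { 0.7, 0.3 }, 123L);
        // map the string label onto the numeric index expected by the classifier
        StringIndexerModel labelIndexer = new StringIndexer()
                .setInputCol(label).setOutputCol("indexedLabel").fit(data);
        DecisionTreeClassifier dtc = new DecisionTreeClassifier()
                .setLabelCol("indexedLabel").setFeaturesCol("features");
        // chain label indexing and classification into one pipeline and fit on the training set
        PipelineModel model = new Pipeline()
                .setStages(new PipelineStage[] { labelIndexer, dtc }).fit(splits[0]);
        // score the held-out test set with the weighted F1 metric
        MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
                .setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("f1");
        System.out.println("F1: " + evaluator.evaluate(model.transform(splits[1])));
        spark.stop();
    }
}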