Search in sources :

Example 1 with LinearRegression

Use of org.apache.spark.ml.regression.LinearRegression in the project mmtf-spark by sbl-sdsc.

The main method of the class DatasetRegressor.

/**
 * Trains and evaluates three regressors (LinearRegression, GBTRegressor,
 * GeneralizedLinearRegression) on a feature dataset read from a Parquet file,
 * printing each model's evaluation and the total elapsed time.
 *
 * @param args args[0] path to parquet file, args[1] name of the prediction column
 * @throws IOException if the parquet file cannot be read
 */
public static void main(String[] args) throws IOException {
    if (args.length != 2) {
        System.err.println("Usage: " + DatasetRegressor.class.getSimpleName() + " <parquet file> <prediction column name>");
        System.exit(1);
    }
    // name of the prediction column
    String label = args[1];
    long start = System.nanoTime();
    SparkSession spark = SparkSession.builder().master("local[*]").appName(DatasetRegressor.class.getSimpleName()).getOrCreate();
    // cache: the dataset is reused by three independent training runs below
    Dataset<Row> data = spark.read().parquet(args[0]).cache();
    int featureCount = ((DenseVector) data.first().getAs("features")).numActives();
    System.out.println("Feature count: " + featureCount);
    System.out.println("Dataset size : " + data.count());
    double testFraction = 0.3;
    // fixed seed so the train/test split is reproducible across runs
    long seed = 123;
    LinearRegression lr = new LinearRegression().setLabelCol(label).setFeaturesCol("features");
    SparkRegressor reg = new SparkRegressor(lr, label, testFraction, seed);
    System.out.println(reg.fit(data));
    GBTRegressor gbt = new GBTRegressor().setLabelCol(label).setFeaturesCol("features");
    reg = new SparkRegressor(gbt, label, testFraction, seed);
    System.out.println(reg.fit(data));
    GeneralizedLinearRegression glr = new GeneralizedLinearRegression().setLabelCol(label).setFeaturesCol("features").setFamily("gaussian").setLink("identity").setMaxIter(10).setRegParam(0.3);
    reg = new SparkRegressor(glr, label, testFraction, seed);
    System.out.println(reg.fit(data));
    long end = System.nanoTime();
    System.out.println((end - start) / 1E9 + " sec");
    // release local Spark resources before the JVM exits
    spark.stop();
}
Also used : SparkSession(org.apache.spark.sql.SparkSession) GeneralizedLinearRegression(org.apache.spark.ml.regression.GeneralizedLinearRegression) GBTRegressor(org.apache.spark.ml.regression.GBTRegressor) Row(org.apache.spark.sql.Row) LinearRegression(org.apache.spark.ml.regression.LinearRegression) GeneralizedLinearRegression(org.apache.spark.ml.regression.GeneralizedLinearRegression) DenseVector(org.apache.spark.ml.linalg.DenseVector)

Example 2 with LinearRegression

Use of org.apache.spark.ml.regression.LinearRegression in the project net.jgp.labs.spark by jgperrin.

The start method of the class SimplePredictionFromTextFile.

/**
 * Loads (x, y) pairs from a local CSV file, fits a linear regression model,
 * prints the training summary, and demonstrates both the built-in predict()
 * and the equivalent manual slope/intercept computation.
 */
private void start() {
    SparkSession spark = SparkSession.builder().appName("Simple prediction from Text File").master("local").getOrCreate();
    // UDF that wraps a plain double column into an ML Vector, as required by the regressor
    spark.udf().register("vectorBuilder", new VectorBuilder(), new VectorUDT());
    String filename = "data/tuple-data-file.csv";
    // _c0 = raw feature value, _c1 = label; the nullable "features" column is filled in by the UDF below
    StructType schema = new StructType(new StructField[] { new StructField("_c0", DataTypes.DoubleType, false, Metadata.empty()), new StructField("_c1", DataTypes.DoubleType, false, Metadata.empty()), new StructField("features", new VectorUDT(), true, Metadata.empty()) });
    Dataset<Row> df = spark.read().format("csv").schema(schema).option("header", "false").load(filename);
    // rename the raw CSV columns to the names the regression expects
    df = df.withColumn("valuefeatures", df.col("_c0")).drop("_c0");
    df = df.withColumn("label", df.col("_c1")).drop("_c1");
    df.printSchema();
    df = df.withColumn("features", callUDF("vectorBuilder", df.col("valuefeatures")));
    df.printSchema();
    df.show();
    // .setRegParam(1).setElasticNetParam(1);
    LinearRegression lr = new LinearRegression().setMaxIter(20);
    // Fit the model to the data.
    LinearRegressionModel model = lr.fit(df);
    // Given a dataset, predict each point's label, and show the results.
    model.transform(df).show();
    LinearRegressionTrainingSummary trainingSummary = model.summary();
    System.out.println("numIterations: " + trainingSummary.totalIterations());
    System.out.println("objectiveHistory: " + Vectors.dense(trainingSummary.objectiveHistory()));
    trainingSummary.residuals().show();
    System.out.println("RMSE: " + trainingSummary.rootMeanSquaredError());
    System.out.println("r2: " + trainingSummary.r2());
    double intercept = model.intercept();
    // fixed: message previously read "Interesection" — the value is the model's intercept
    System.out.println("Intercept: " + intercept);
    double regParam = model.getRegParam();
    System.out.println("Regression parameter: " + regParam);
    double tol = model.getTol();
    System.out.println("Tol: " + tol);
    Double feature = 7.0;
    Vector features = Vectors.dense(feature);
    double p = model.predict(features);
    System.out.println("Prediction for feature " + feature + " is " + p);
    // fixed: the manual prediction for x = 8 must use the learned slope
    // (first fitted coefficient), not getRegParam() — the regularization
    // hyper-parameter is unrelated to the fitted line
    double slope = model.coefficients().apply(0);
    System.out.println(8 * slope + intercept);
    // release local Spark resources
    spark.stop();
}
Also used : VectorUDT(org.apache.spark.ml.linalg.VectorUDT) SparkSession(org.apache.spark.sql.SparkSession) StructType(org.apache.spark.sql.types.StructType) LinearRegressionModel(org.apache.spark.ml.regression.LinearRegressionModel) StructField(org.apache.spark.sql.types.StructField) VectorBuilder(net.jgp.labs.spark.x.udf.VectorBuilder) Row(org.apache.spark.sql.Row) LinearRegression(org.apache.spark.ml.regression.LinearRegression) Vector(org.apache.spark.ml.linalg.Vector) LinearRegressionTrainingSummary(org.apache.spark.ml.regression.LinearRegressionTrainingSummary)

Aggregations

LinearRegression (org.apache.spark.ml.regression.LinearRegression)2 Row (org.apache.spark.sql.Row)2 SparkSession (org.apache.spark.sql.SparkSession)2 VectorBuilder (net.jgp.labs.spark.x.udf.VectorBuilder)1 DenseVector (org.apache.spark.ml.linalg.DenseVector)1 Vector (org.apache.spark.ml.linalg.Vector)1 VectorUDT (org.apache.spark.ml.linalg.VectorUDT)1 GBTRegressor (org.apache.spark.ml.regression.GBTRegressor)1 GeneralizedLinearRegression (org.apache.spark.ml.regression.GeneralizedLinearRegression)1 LinearRegressionModel (org.apache.spark.ml.regression.LinearRegressionModel)1 LinearRegressionTrainingSummary (org.apache.spark.ml.regression.LinearRegressionTrainingSummary)1 StructField (org.apache.spark.sql.types.StructField)1 StructType (org.apache.spark.sql.types.StructType)1