Example 1 with GeneralizedLinearRegressionTrainingSummary

Use of org.apache.spark.ml.regression.GeneralizedLinearRegressionTrainingSummary in the project net.jgp.labs.spark by jgperrin.

From the class GeneralizedLinearRegressionApp, method main.

import java.util.Arrays;

import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.ml.regression.GeneralizedLinearRegression;
import org.apache.spark.ml.regression.GeneralizedLinearRegressionModel;
import org.apache.spark.ml.regression.GeneralizedLinearRegressionTrainingSummary;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class GeneralizedLinearRegressionApp {

    public static void main(String[] args) {
        // Start a local Spark session that uses all available cores
        SparkSession spark = SparkSession.builder()
            .appName("GeneralizedLinearRegressionApp")
            .master("local[*]")
            .getOrCreate();

        // Load training data in libsvm format
        // Dataset<Row> dataset = spark.read().format("libsvm")
        // .load("data/mllib/sample_linear_regression_data.txt");
        Dataset<Row> dataset = spark.read().format("libsvm").load("data/sample-ml/simplegauss.txt");
        dataset.show(20, false);
        System.out.println("Records: " + dataset.count());

        // Configure a gamma-family GLM with a log link
        GeneralizedLinearRegression glr = new GeneralizedLinearRegression()
            .setFamily("gamma")
            .setLink("log")
            .setMaxIter(10)
            .setRegParam(0.3);

        // Fit the model
        GeneralizedLinearRegressionModel model = glr.fit(dataset);

        // Print the coefficients and intercept for generalized linear regression model
        System.out.println("Coefficients: " + model.coefficients());
        System.out.println("Intercept: " + model.intercept());

        // Summarize the model over the training set and print out some metrics
        GeneralizedLinearRegressionTrainingSummary summary = model.summary();
        System.out.println("Coefficient Standard Errors: " + Arrays.toString(summary.coefficientStandardErrors()));
        System.out.println("T Values: " + Arrays.toString(summary.tValues()));
        System.out.println("P Values: " + Arrays.toString(summary.pValues()));
        System.out.println("Dispersion: " + summary.dispersion());
        System.out.println("Null Deviance: " + summary.nullDeviance());
        System.out.println("Residual Degree Of Freedom Null: " + summary.residualDegreeOfFreedomNull());
        System.out.println("Deviance: " + summary.deviance());
        System.out.println("Residual Degree Of Freedom: " + summary.residualDegreeOfFreedom());
        System.out.println("AIC: " + summary.aic());
        System.out.println("Deviance Residuals: ");
        summary.residuals().show();

        // Predict the label for two single-feature vectors
        double feature = 2.0;
        Vector features = Vectors.dense(feature);
        double p = model.predict(features);
        System.out.println("Prediction for feature " + feature + " is " + p);
        feature = 11.0;
        features = Vectors.dense(feature);
        p = model.predict(features);
        System.out.println("Prediction for feature " + feature + " is " + p);

        spark.stop();
    }
}
Also used: SparkSession (org.apache.spark.sql.SparkSession), GeneralizedLinearRegressionModel (org.apache.spark.ml.regression.GeneralizedLinearRegressionModel), GeneralizedLinearRegressionTrainingSummary (org.apache.spark.ml.regression.GeneralizedLinearRegressionTrainingSummary), GeneralizedLinearRegression (org.apache.spark.ml.regression.GeneralizedLinearRegression), Row (org.apache.spark.sql.Row), Vector (org.apache.spark.ml.linalg.Vector)
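
The example configures the model with setLink("log"), so the value returned by model.predict(features) is the exponential of the linear predictor: exp(intercept + coefficients · features). The following is a minimal sketch of that arithmetic in plain Java, without Spark. The class name LogLinkPrediction and the coefficient and intercept values are hypothetical and are not taken from the original project or from a fitted model.

```java
final class LogLinkPrediction {

    private LogLinkPrediction() {
    }

    // Mean response under a log link: exp(intercept + sum_i coefficients[i] * features[i]).
    static double predict(double[] coefficients, double intercept, double[] features) {
        if (coefficients.length != features.length) {
            throw new IllegalArgumentException("coefficients and features must have the same length");
        }
        double eta = intercept; // linear predictor
        for (int i = 0; i < coefficients.length; i++) {
            eta += coefficients[i] * features[i];
        }
        return Math.exp(eta); // inverse of the log link
    }

    public static void main(String[] args) {
        // Hypothetical values for illustration only, not output of a fitted model
        double[] coefficients = {0.5};
        double intercept = 1.0;
        for (double feature : new double[] {2.0, 11.0}) {
            double p = predict(coefficients, intercept, new double[] {feature});
            System.out.println("Prediction for feature " + feature + " is " + p);
        }
    }
}
```

Substituting the values printed by model.coefficients() and model.intercept() into such a helper should give a figure comparable to model.predict for the same feature vector, assuming no offset column is set.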